2020-09-21 16:53:31 +03:00
|
|
|
package kscience.kmath.nd4j
|
2020-06-11 10:36:19 +03:00
|
|
|
|
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray
|
|
|
|
import org.nd4j.linalg.api.shape.Shape
|
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Iterator over every element position of [iterateOver], yielding each position as an [IntArray] of indices.
 *
 * Positions come from a running linear index:
 * - when the array's ordering is `'c'`, the index is converted with [Shape.ind2subC];
 * - otherwise it is converted with [Shape.ind2sub].
 *
 * [next] throws [NoSuchElementException] once every position has been produced.
 */
private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
    // Long rather than Int: INDArray.length() is a Long, so an Int counter could overflow on huge arrays.
    private var i: Long = 0L

    override fun hasNext(): Boolean = i < iterateOver.length()

    override fun next(): IntArray {
        // Honour the Iterator contract instead of requesting an out-of-range position from ND4J.
        if (!hasNext()) throw NoSuchElementException()
        val linearIndex = i++

        val la = if (iterateOver.ordering() == 'c')
            Shape.ind2subC(iterateOver, linearIndex)
        else
            Shape.ind2sub(iterateOver, linearIndex)

        // ND4J is a Java API, so the result is a platform type; fail with a clear message rather than a bare NPE.
        return checkNotNull(la) { "ND4J returned no indices for linear index $linearIndex" }.toIntArray()
    }
}
|
|
|
|
|
2020-09-21 16:53:31 +03:00
|
|
|
/**
 * Creates an iterator that visits every index position of this array.
 * Each position is delivered as an [IntArray].
 */
internal fun INDArray.indicesIterator(): Iterator<IntArray> {
    return INDArrayIndicesIterator(this)
}
|
2020-08-15 14:35:16 +03:00
|
|
|
|
|
|
|
/**
 * Base for iterators that walk every element of [iterateOver] and yield `(indices, value)` pairs.
 *
 * Positions come from a running linear index:
 * - when the array's ordering is `'c'`, the index is converted with [Shape.ind2subC];
 * - otherwise it is converted with [Shape.ind2sub].
 *
 * The value at each position is read through [getSingle]. [next] throws [NoSuchElementException]
 * once every element has been produced.
 */
private sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
    // Long rather than Int: INDArray.length() is a Long, so an Int counter could overflow on huge arrays.
    private var i: Long = 0L

    final override fun hasNext(): Boolean = i < iterateOver.length()

    /** Reads the element of [iterateOver] located at the full index vector [indices]. */
    abstract fun getSingle(indices: LongArray): T

    final override fun next(): Pair<IntArray, T> {
        // Honour the Iterator contract instead of requesting an out-of-range position from ND4J.
        if (!hasNext()) throw NoSuchElementException()
        val linearIndex = i++

        val la = if (iterateOver.ordering() == 'c')
            Shape.ind2subC(iterateOver, linearIndex)
        else
            Shape.ind2sub(iterateOver, linearIndex)

        // ND4J is a Java API, so the result is a platform type; fail with a clear message rather than a bare NPE.
        val indices = checkNotNull(la) { "ND4J returned no indices for linear index $linearIndex" }
        return indices.toIntArray() to getSingle(indices)
    }
}
|
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Element iterator whose values are read as [Double] via [INDArray.getDouble].
 */
private class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
    override fun getSingle(indices: LongArray): Double {
        return iterateOver.getDouble(*indices)
    }
}
|
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Creates an iterator over `(indices, value)` pairs of this array.
 * Values are read as [Double].
 */
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> {
    return INDArrayRealIterator(this)
}
|
2020-06-28 22:50:34 +03:00
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Element iterator whose values are read as [Long] via [INDArray.getLong].
 */
private class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
    override fun getSingle(indices: LongArray): Long {
        return iterateOver.getLong(*indices)
    }
}
|
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Creates an iterator over `(indices, value)` pairs of this array.
 * Values are read as [Long].
 */
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> {
    return INDArrayLongIterator(this)
}
|
2020-06-28 22:50:34 +03:00
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Element iterator whose values are read as [Int] via [INDArray.getInt].
 * The index vector is converted to an [IntArray] before the lookup.
 */
private class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
    override fun getSingle(indices: LongArray): Int {
        val intIndices = indices.toIntArray()
        return iterateOver.getInt(*intIndices)
    }
}
|
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Creates an iterator over `(indices, value)` pairs of this array.
 * Values are read as [Int].
 */
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> {
    return INDArrayIntIterator(this)
}
|
2020-06-29 17:31:08 +03:00
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Element iterator whose values are read as [Float] via [INDArray.getFloat].
 */
private class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
    override fun getSingle(indices: LongArray): Float {
        return iterateOver.getFloat(*indices)
    }
}
|
2020-06-29 17:31:08 +03:00
|
|
|
|
2020-08-15 14:35:16 +03:00
|
|
|
/**
 * Creates an iterator over `(indices, value)` pairs of this array.
 * Values are read as [Float].
 */
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> {
    return INDArrayFloatIterator(this)
}
|