package kscience.kmath.nd4j

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.shape.Shape

/**
 * Converts the linear element [index] of this array into per-dimension subscripts.
 *
 * Arrays whose [INDArray.ordering] is `'c'` are decoded with [Shape.ind2subC]; any other
 * ordering is decoded with [Shape.ind2sub].
 */
private fun INDArray.subscriptsOf(index: Long): LongArray {
    val subscripts = if (ordering() == 'c') Shape.ind2subC(this, index) else Shape.ind2sub(this, index)
    return checkNotNull(subscripts) { "ND4J returned null subscripts for linear index $index" }
}

/**
 * Yields the multidimensional index of every element of [iterateOver], visiting linear
 * positions `0 until iterateOver.length()` in increasing order.
 */
private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
    // Long rather than Int: INDArray.length() is a Long, so an Int counter could overflow.
    private var position: Long = 0L

    override fun hasNext(): Boolean = position < iterateOver.length()

    override fun next(): IntArray {
        // Iterator contract: signal exhaustion instead of decoding an out-of-range index.
        if (!hasNext()) throw NoSuchElementException("Iterated past the last element (length ${iterateOver.length()})")
        return iterateOver.subscriptsOf(position++).toIntArray()
    }
}

/** Returns an iterator over the multidimensional indices of all elements of this array. */
internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this)

/**
 * Base for iterators that pair each element's multidimensional index with its value.
 *
 * Linear positions `0 until iterateOver.length()` are visited in increasing order; subclasses
 * decide how a single value is read via [getSingle].
 */
private sealed class Nd4jArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
    // Long rather than Int: INDArray.length() is a Long, so an Int counter could overflow.
    private var position: Long = 0L

    final override fun hasNext(): Boolean = position < iterateOver.length()

    /** Reads the value stored at [indices] in [iterateOver]. */
    abstract fun getSingle(indices: LongArray): T

    final override fun next(): Pair<IntArray, T> {
        // Iterator contract: signal exhaustion instead of decoding an out-of-range index.
        if (!hasNext()) throw NoSuchElementException("Iterated past the last element (length ${iterateOver.length()})")
        val indices = iterateOver.subscriptsOf(position++)
        return indices.toIntArray() to getSingle(indices)
    }
}

/** Reads elements with [INDArray.getDouble]. */
private class Nd4jArrayRealIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Double>(iterateOver) {
    override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
}

/** Returns an iterator of (index, value) pairs with values read as [Double]. */
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = Nd4jArrayRealIterator(this)

/** Reads elements with [INDArray.getLong]. */
private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Long>(iterateOver) {
    override fun getSingle(indices: LongArray): Long = iterateOver.getLong(*indices)
}

/** Returns an iterator of (index, value) pairs with values read as [Long]. */
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = Nd4jArrayLongIterator(this)

/** Reads elements with [INDArray.getInt], which takes `Int` subscripts. */
private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Int>(iterateOver) {
    override fun getSingle(indices: LongArray): Int = iterateOver.getInt(*indices.toIntArray())
}

/** Returns an iterator of (index, value) pairs with values read as [Int]. */
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> = Nd4jArrayIntIterator(this)

/** Reads elements with [INDArray.getFloat]. */
private class Nd4jArrayFloatIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Float>(iterateOver) {
    override fun getSingle(indices: LongArray): Float = iterateOver.getFloat(*indices)
}

/** Returns an iterator of (index, value) pairs with values read as [Float]. */
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> = Nd4jArrayFloatIterator(this)