Rework with specialized NDStructure implementations
This commit is contained in:
parent
9a4dd31507
commit
d0cc75098b
@ -1,3 +1,4 @@
|
|||||||
package scientifik.kmath.nd4j
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() }
|
internal fun widenToLongArray(ia: IntArray): LongArray = LongArray(ia.size) { ia[it].toLong() }
|
||||||
|
internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() }
|
||||||
|
@ -1,19 +0,0 @@
|
|||||||
package scientifik.kmath.nd4j
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
|
||||||
import org.nd4j.linalg.api.shape.Shape
|
|
||||||
|
|
||||||
internal class INDArrayScalarsIterator(private val iterateOver: INDArray) : Iterator<Pair<IntArray, INDArray>> {
|
|
||||||
private var i: Int = 0
|
|
||||||
|
|
||||||
override fun hasNext(): Boolean = i < iterateOver.length()
|
|
||||||
|
|
||||||
override fun next(): Pair<IntArray, INDArray> {
|
|
||||||
val idx = if (iterateOver.ordering() == 'c')
|
|
||||||
Shape.ind2subC(iterateOver, i++.toLong())!!
|
|
||||||
else
|
|
||||||
Shape.ind2sub(iterateOver, i++.toLong())!!
|
|
||||||
|
|
||||||
return narrowToIntArray(idx) to iterateOver.getScalar(*idx)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
package scientifik.kmath.nd4j
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
|
||||||
import scientifik.kmath.structures.NDStructure
|
|
||||||
|
|
||||||
data class ND4JStructure<T>(val ndArray: INDArray) : NDStructure<INDArray> {
|
|
||||||
override val shape: IntArray
|
|
||||||
get() = narrowToIntArray(ndArray.shape())
|
|
||||||
|
|
||||||
override fun get(index: IntArray): INDArray = ndArray.getScalar(*index)
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, INDArray>> = Sequence { INDArrayScalarsIterator(ndArray) }
|
|
||||||
}
|
|
@ -0,0 +1,37 @@
|
|||||||
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import org.nd4j.linalg.api.shape.Shape
|
||||||
|
|
||||||
|
internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
|
||||||
|
private var i: Int = 0
|
||||||
|
|
||||||
|
override fun hasNext(): Boolean = i < iterateOver.length()
|
||||||
|
|
||||||
|
abstract fun getSingle(indices: LongArray): T
|
||||||
|
|
||||||
|
final override fun next(): Pair<IntArray, T> {
|
||||||
|
val la = if (iterateOver.ordering() == 'c')
|
||||||
|
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||||
|
else
|
||||||
|
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||||
|
|
||||||
|
return narrowToIntArray(la) to getSingle(la)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class INDArrayDoubleIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices)
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
|
interface INDArrayStructureBase<T> : NDStructure<T> {
|
||||||
|
val ndArray: INDArray
|
||||||
|
|
||||||
|
override val shape: IntArray
|
||||||
|
get() = narrowToIntArray(ndArray.shape())
|
||||||
|
|
||||||
|
fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence { elementsIterator() }
|
||||||
|
}
|
||||||
|
|
||||||
|
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase<Int> {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
||||||
|
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||||
|
}
|
||||||
|
|
||||||
|
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructureBase<Long> {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
||||||
|
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
||||||
|
}
|
||||||
|
|
||||||
|
data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructureBase<Double> {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
||||||
|
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||||
|
}
|
||||||
|
|
||||||
|
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructureBase<Float> {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
||||||
|
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user