2020-06-11 14:36:19 +07:00
|
|
|
package scientifik.kmath.nd4j
|
|
|
|
|
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray
|
2020-06-29 00:14:01 +07:00
|
|
|
import scientifik.kmath.structures.MutableNDStructure
|
2020-06-11 14:36:19 +07:00
|
|
|
import scientifik.kmath.structures.NDStructure
|
|
|
|
|
2020-06-28 17:30:09 +07:00
|
|
|
/**
 * An [NDStructure] view over an ND4J [INDArray].
 *
 * @param T the element type exposed by the structure.
 */
interface INDArrayStructure<T> : NDStructure<T> {
    /**
     * The wrapped [INDArray] that this structure reads from.
     */
    val ndArray: INDArray

    /**
     * The shape reported by [ndArray], converted to an [IntArray] on every access.
     */
    override val shape: IntArray
        get() {
            val dims = ndArray.shape()
            return dims.toIntArray()
        }

    /**
     * Creates an iterator over pairs of an [IntArray] and an element of type [T]
     * (presumably the multi-dimensional index and the value stored at it).
     */
    fun elementsIterator(): Iterator<Pair<IntArray, T>>

    /**
     * Exposes [elementsIterator] as a [Sequence]; each traversal of the sequence asks for a new iterator.
     */
    override fun elements(): Sequence<Pair<IntArray, T>> = Sequence { elementsIterator() }
}
|
|
|
|
|
2020-06-29 02:50:34 +07:00
|
|
|
/**
 * Presents an [INDArray] as a mutable structure of [Int] elements.
 *
 * @property ndArray the wrapped array that all reads and writes are delegated to.
 */
data class INDArrayIntStructure(
    override val ndArray: INDArray
) : INDArrayStructure<Int>, MutableNDStructure<Int> {
    /** Delegates to the `intIterator()` of [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Int>> {
        return ndArray.intIterator()
    }

    /** Reads the element at [index] through `getInt`. */
    override fun get(index: IntArray): Int {
        return ndArray.getInt(*index)
    }

    /** Stores [value] at [index] through `putScalar`; the result of `putScalar` is discarded. */
    override fun set(index: IntArray, value: Int) {
        ndArray.putScalar(index, value)
    }
}
|
|
|
|
|
2020-06-29 02:50:34 +07:00
|
|
|
/**
 * Wraps this [INDArray] in an [INDArrayIntStructure]; the structure holds this same array instance.
 */
fun INDArray.asIntStructure(): INDArrayIntStructure {
    return INDArrayIntStructure(this)
}
|
|
|
|
|
|
|
|
/**
 * Presents an [INDArray] as a read-only structure of [Long] elements.
 *
 * @property ndArray the wrapped array that all reads are delegated to.
 */
data class INDArrayLongStructure(
    override val ndArray: INDArray
) : INDArrayStructure<Long> {
    /** Delegates to the `longIterator()` of [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Long>> {
        return ndArray.longIterator()
    }

    /** Reads the element at [index] through `getLong`, after converting the index with `toLongArray()`. */
    override fun get(index: IntArray): Long {
        val longIndex = index.toLongArray()
        return ndArray.getLong(*longIndex)
    }
}
|
|
|
|
|
2020-06-29 02:50:34 +07:00
|
|
|
/**
 * Wraps this [INDArray] in an [INDArrayLongStructure]; the structure holds this same array instance.
 */
fun INDArray.asLongStructure(): INDArrayLongStructure {
    return INDArrayLongStructure(this)
}
|
|
|
|
|
|
|
|
/**
 * Presents an [INDArray] as a mutable structure of [Double] elements.
 *
 * @property ndArray the wrapped array that all reads and writes are delegated to.
 */
data class INDArrayRealStructure(
    override val ndArray: INDArray
) : INDArrayStructure<Double>, MutableNDStructure<Double> {
    /** Delegates to the `realIterator()` of [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Double>> {
        return ndArray.realIterator()
    }

    /** Reads the element at [index] through `getDouble`. */
    override fun get(index: IntArray): Double {
        return ndArray.getDouble(*index)
    }

    /** Stores [value] at [index] through `putScalar`; the result of `putScalar` is discarded. */
    override fun set(index: IntArray, value: Double) {
        ndArray.putScalar(index, value)
    }
}
|
|
|
|
|
2020-06-29 02:50:34 +07:00
|
|
|
/**
 * Wraps this [INDArray] in an [INDArrayRealStructure]; the structure holds this same array instance.
 */
fun INDArray.asRealStructure(): INDArrayRealStructure {
    return INDArrayRealStructure(this)
}
|
|
|
|
|
|
|
|
/**
 * Presents an [INDArray] as a mutable structure of [Float] elements.
 *
 * @property ndArray the wrapped array that all reads and writes are delegated to.
 */
data class INDArrayFloatStructure(
    override val ndArray: INDArray
) : INDArrayStructure<Float>, MutableNDStructure<Float> {
    /** Delegates to the `floatIterator()` of [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Float>> {
        return ndArray.floatIterator()
    }

    /** Reads the element at [index] through `getFloat`. */
    override fun get(index: IntArray): Float {
        return ndArray.getFloat(*index)
    }

    /** Stores [value] at [index] through `putScalar`; the result of `putScalar` is discarded. */
    override fun set(index: IntArray, value: Float) {
        ndArray.putScalar(index, value)
    }
}
|
2020-06-29 02:50:34 +07:00
|
|
|
|
|
|
|
/**
 * Wraps this [INDArray] in an [INDArrayFloatStructure]; the structure holds this same array instance.
 */
fun INDArray.asFloatStructure(): INDArrayFloatStructure {
    return INDArrayFloatStructure(this)
}
|