Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116
@ -3,7 +3,7 @@ package scientifik.kmath.nd4j
|
|||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import scientifik.kmath.structures.NDStructure
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
interface INDArrayStructureBase<T> : NDStructure<T> {
|
interface INDArrayStructure<T> : NDStructure<T> {
|
||||||
val ndArray: INDArray
|
val ndArray: INDArray
|
||||||
|
|
||||||
override val shape: IntArray
|
override val shape: IntArray
|
||||||
@ -13,22 +13,22 @@ interface INDArrayStructureBase<T> : NDStructure<T> {
|
|||||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase<Int> {
|
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
||||||
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructureBase<Long> {
|
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
||||||
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructureBase<Double> {
|
data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure<Double> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
||||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructureBase<Float> {
|
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
||||||
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user