forked from kscience/kmath
Make several NDStructures mutable
This commit is contained in:
parent
f55b2c7a40
commit
3b18000f1e
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.nd4j
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import scientifik.kmath.structures.MutableNDStructure
|
||||||
import scientifik.kmath.structures.NDStructure
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
interface INDArrayStructure<T> : NDStructure<T> {
|
interface INDArrayStructure<T> : NDStructure<T> {
|
||||||
@ -13,9 +14,10 @@ interface INDArrayStructure<T> : NDStructure<T> {
|
|||||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int> {
|
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
||||||
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||||
|
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
||||||
@ -23,12 +25,14 @@ data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStruc
|
|||||||
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure<Double> {
|
data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure<Double>, MutableNDStructure<Double> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
||||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||||
|
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float> {
|
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>, MutableNDStructure<Float> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
||||||
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||||
|
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.nd4j
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import scientifik.kmath.structures.get
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import scientifik.kmath.structures.get
|
|
||||||
|
|
||||||
internal class INDArrayStructureTest {
|
internal class INDArrayStructureTest {
|
||||||
@Test
|
@Test
|
||||||
@ -17,7 +17,7 @@ internal class INDArrayStructureTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testShape() {
|
fun testShape() {
|
||||||
val nd = Nd4j.rand(10, 2, 3, 6)!!
|
val nd = Nd4j.rand(10, 2, 3, 6)!!
|
||||||
val struct = INDArrayIntStructure(nd)
|
val struct = INDArrayLongStructure(nd)
|
||||||
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
|
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ internal class INDArrayStructureTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testDimension() {
|
fun testDimension() {
|
||||||
val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
|
val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
|
||||||
val struct = INDArrayIntStructure(nd)
|
val struct = INDArrayFloatStructure(nd)
|
||||||
assertEquals(5, struct.dimension)
|
assertEquals(5, struct.dimension)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,4 +43,12 @@ internal class INDArrayStructureTest {
|
|||||||
val struct = INDArrayIntStructure(nd)
|
val struct = INDArrayIntStructure(nd)
|
||||||
assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0])
|
assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSet() {
|
||||||
|
val nd = Nd4j.rand(17, 12, 4, 8)!!
|
||||||
|
val struct = INDArrayIntStructure(nd)
|
||||||
|
struct[intArrayOf(1, 2, 3, 4)] = 777
|
||||||
|
assertEquals(777, struct[1, 2, 3, 4])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user