Convert INDArray NDStructures implementations to inline classes, add tests to verify equals and hashCode
This commit is contained in:
parent
3b18000f1e
commit
eb9d40fd2a
@ -14,25 +14,30 @@ interface INDArrayStructure<T> : NDStructure<T> {
|
|||||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
|
inline class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
||||||
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||||
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
override fun toString(): String = "INDArrayIntStructure(ndArray=$ndArray)"
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
||||||
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
||||||
|
override fun toString(): String = "INDArrayLongStructure(ndArray=$ndArray)"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure<Double>, MutableNDStructure<Double> {
|
inline class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure<Double>, MutableNDStructure<Double> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
||||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||||
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
override fun toString(): String = "INDArrayDoubleStructure(ndArray=$ndArray)"
|
||||||
}
|
}
|
||||||
|
|
||||||
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>, MutableNDStructure<Float> {
|
inline class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>, MutableNDStructure<Float> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
||||||
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||||
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
override fun toString(): String = "INDArrayFloatStructure(ndArray=$ndArray)"
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import org.nd4j.linalg.factory.Nd4j
|
|||||||
import scientifik.kmath.structures.get
|
import scientifik.kmath.structures.get
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertNotEquals
|
||||||
|
|
||||||
internal class INDArrayStructureTest {
|
internal class INDArrayStructureTest {
|
||||||
@Test
|
@Test
|
||||||
@ -25,9 +26,25 @@ internal class INDArrayStructureTest {
|
|||||||
fun testEquals() {
|
fun testEquals() {
|
||||||
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
val struct1 = INDArrayDoubleStructure(nd1)
|
val struct1 = INDArrayDoubleStructure(nd1)
|
||||||
|
assertEquals(struct1, struct1)
|
||||||
|
assertNotEquals(struct1, null as INDArrayDoubleStructure?)
|
||||||
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
val struct2 = INDArrayDoubleStructure(nd2)
|
val struct2 = INDArrayDoubleStructure(nd2)
|
||||||
assertEquals(struct1, struct2)
|
assertEquals(struct1, struct2)
|
||||||
|
assertEquals(struct2, struct1)
|
||||||
|
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
|
val struct3 = INDArrayDoubleStructure(nd3)
|
||||||
|
assertEquals(struct2, struct3)
|
||||||
|
assertEquals(struct1, struct3)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testHashCode() {
|
||||||
|
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
|
val struct1 = INDArrayDoubleStructure(nd1)
|
||||||
|
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
|
val struct2 = INDArrayDoubleStructure(nd2)
|
||||||
|
assertEquals(struct1.hashCode(), struct2.hashCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
Reference in New Issue
Block a user