diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index f39d84716..66aa00fac 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -14,25 +14,30 @@ interface INDArrayStructure : NDStructure { override fun elements(): Sequence> = Sequence(::elementsIterator) } -data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +inline class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) override fun get(index: IntArray): Int = ndArray.getInt(*index) override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } + override fun toString(): String = "INDArrayIntStructure(ndArray=$ndArray)" } -data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { +inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) + override fun toString(): String = "INDArrayLongStructure(ndArray=$ndArray)" + } -data class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +inline class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayDoubleIterator(ndArray) override fun get(index: IntArray): Double = ndArray.getDouble(*index) override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } + override fun toString(): String = "INDArrayDoubleStructure(ndArray=$ndArray)" } -data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +inline class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) override fun get(index: IntArray): Float = ndArray.getFloat(*index) override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } + override fun toString(): String = "INDArrayFloatStructure(ndArray=$ndArray)" } diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt index 77565856a..ad1cbb585 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -4,6 +4,7 @@ import org.nd4j.linalg.factory.Nd4j import scientifik.kmath.structures.get import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertNotEquals internal class INDArrayStructureTest { @Test @@ -25,9 +26,25 @@ internal class INDArrayStructureTest { fun testEquals() { val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val struct1 = INDArrayDoubleStructure(nd1) + assertEquals(struct1, struct1) + assertNotEquals(struct1, null as INDArrayDoubleStructure?) val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val struct2 = INDArrayDoubleStructure(nd2) assertEquals(struct1, struct2) + assertEquals(struct2, struct1) + val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct3 = INDArrayDoubleStructure(nd3) + assertEquals(struct2, struct3) + assertEquals(struct1, struct3) + } + + @Test + fun testHashCode() { + val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct1 = INDArrayDoubleStructure(nd1) + val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct2 = INDArrayDoubleStructure(nd2) + assertEquals(struct1.hashCode(), struct2.hashCode()) } @Test