From 783087982fbe93f60a2c20e7a3d50627b69b8b58 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 29 Jun 2020 02:50:34 +0700 Subject: [PATCH] Rollback making Structures inline, implement Algebras for NDArrayStructure --- .../scientifik.kmath.nd4j/INDArrayAlgebra.kt | 89 +++++++++++++++++++ .../INDArrayIterators.kt | 7 +- .../INDArrayStructures.kt | 25 +++--- .../kmath/nd4j/INDArrayAlgebraTest.kt | 30 +++++++ .../kmath/nd4j/INDArrayStructureTest.kt | 14 +-- 5 files changed, 147 insertions(+), 18 deletions(-) create mode 100644 kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt create mode 100644 kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt new file mode 100644 index 000000000..760e3d03e --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayAlgebra.kt @@ -0,0 +1,89 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.factory.Nd4j +import scientifik.kmath.operations.* +import scientifik.kmath.structures.MutableNDStructure +import scientifik.kmath.structures.NDField +import scientifik.kmath.structures.NDRing + +interface INDArrayRing : + NDRing where F : Ring, N : INDArrayStructure, N : MutableNDStructure { + fun INDArray.wrap(): N + + override val zero: N + get() = Nd4j.zeros(*shape).wrap() + + override val one: N + get() = Nd4j.ones(*shape).wrap() + + override fun produce(initializer: F.(IntArray) -> T): N { + val struct = Nd4j.create(*shape).wrap() + struct.elements().map(Pair::first).forEach { struct[it] = elementContext.initializer(it) } + return struct + } + + override fun map(arg: N, transform: F.(T) -> T): N { + val new = Nd4j.create(*shape) + Nd4j.copy(arg.ndArray, new) + val newStruct = new.wrap() + newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } + return newStruct + } + + override fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N { + val new = Nd4j.create(*shape).wrap() + new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) } + return new + } + + override fun combine(a: N, b: N, transform: F.(T, T) -> T): N { + val new = Nd4j.create(*shape).wrap() + new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) } + return new + } + + override fun add(a: N, b: N): N = a.ndArray.addi(b.ndArray).wrap() + override fun N.minus(b: N): N = ndArray.subi(b.ndArray).wrap() + override fun N.unaryMinus(): N = ndArray.negi().wrap() + override fun multiply(a: N, b: N): N = a.ndArray.muli(b.ndArray).wrap() + override fun multiply(a: N, k: Number): N = a.ndArray.muli(k).wrap() + override fun N.div(k: Number): N = ndArray.divi(k).wrap() + override fun N.minus(b: Number): N = ndArray.subi(b).wrap() + override fun N.plus(b: Number): N = ndArray.addi(b).wrap() + override fun N.times(k: Number): N = ndArray.muli(k).wrap() +} + +interface INDArrayField : NDField, + INDArrayRing where F : Field, N : INDArrayStructure, N : MutableNDStructure { + override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap() +} + +class RealINDArrayField(override val shape: IntArray, override val elementContext: Field = RealField) : + INDArrayField, INDArrayRealStructure> { + override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure() + override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap() + override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap() + override fun INDArrayRealStructure.div(k: Number): INDArrayRealStructure = ndArray.divi(k).wrap() + override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap() + override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap() +} + +class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field = FloatField) : + INDArrayField, INDArrayFloatStructure> { + override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure() + override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap() + override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap() + override fun INDArrayFloatStructure.div(k: Number): INDArrayFloatStructure = ndArray.divi(k).wrap() + override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap() + override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap() +} + +class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring = IntRing) : + INDArrayRing, INDArrayIntStructure> { + override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure() + override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap() + override fun INDArrayIntStructure.div(k: Number): INDArrayIntStructure = ndArray.divi(k).wrap() + override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap() + override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap() +} diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt index f6efdc0ba..115c78cb9 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayIterators.kt @@ -20,14 +20,19 @@ internal sealed class INDArrayIteratorBase(protected val iterateOver: INDArra } } -internal class INDArrayDoubleIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { +internal class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices) } +internal fun INDArray.realIterator(): INDArrayRealIterator = INDArrayRealIterator(this) + internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) } +// TODO +//internal fun INDArray.longI + internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices)) } diff --git a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt index 66aa00fac..351110485 100644 --- a/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt +++ b/kmath-nd4j/src/main/kotlin/scientifik.kmath.nd4j/INDArrayStructures.kt @@ -14,30 +14,35 @@ interface INDArrayStructure : NDStructure { override fun elements(): Sequence> = Sequence(::elementsIterator) } -inline class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayIntIterator(ndArray) override fun get(index: IntArray): Int = ndArray.getInt(*index) override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } - override fun toString(): String = "INDArrayIntStructure(ndArray=$ndArray)" } -inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { +fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this) + +data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure { override fun elementsIterator(): Iterator> = INDArrayLongIterator(ndArray) override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index)) - override fun toString(): String = "INDArrayLongStructure(ndArray=$ndArray)" - } -inline class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { - override fun elementsIterator(): Iterator> = INDArrayDoubleIterator(ndArray) +fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this) + +data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure, + MutableNDStructure { + override fun elementsIterator(): Iterator> = INDArrayRealIterator(ndArray) override fun get(index: IntArray): Double = ndArray.getDouble(*index) override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } - override fun toString(): String = "INDArrayDoubleStructure(ndArray=$ndArray)" } -inline class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, MutableNDStructure { +fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this) + +data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure, + MutableNDStructure { override fun elementsIterator(): Iterator> = INDArrayFloatIterator(ndArray) override fun get(index: IntArray): Float = ndArray.getFloat(*index) override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } - override fun toString(): String = "INDArrayFloatStructure(ndArray=$ndArray)" } + +fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this) diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt new file mode 100644 index 000000000..f971e7871 --- /dev/null +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayAlgebraTest.kt @@ -0,0 +1,30 @@ +package scientifik.kmath.nd4j + +import org.nd4j.linalg.factory.Nd4j +import scientifik.kmath.operations.invoke +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class INDArrayAlgebraTest { + @Test + fun testProduce() { + val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } + val expected = Nd4j.create(2, 2)!!.asRealStructure() + expected[intArrayOf(0, 0)] = 0.0 + expected[intArrayOf(0, 1)] = 1.0 + expected[intArrayOf(1, 0)] = 1.0 + expected[intArrayOf(1, 1)] = 2.0 + assertEquals(expected, res) + } + + @Test + fun testMap() { + val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } } + val expected = Nd4j.create(2, 2)!!.asIntStructure() + expected[intArrayOf(0, 0)] = 3 + expected[intArrayOf(0, 1)] = 3 + expected[intArrayOf(1, 0)] = 3 + expected[intArrayOf(1, 1)] = 3 + assertEquals(expected, res) + } +} diff --git a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt index ad1cbb585..dfede6d32 100644 --- a/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/scientifik/kmath/nd4j/INDArrayStructureTest.kt @@ -10,7 +10,7 @@ internal class INDArrayStructureTest { @Test fun testElements() { val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct = INDArrayDoubleStructure(nd) + val struct = INDArrayRealStructure(nd) val res = struct.elements().map(Pair::second).toList() assertEquals(listOf(1.0, 2.0, 3.0), res) } @@ -25,15 +25,15 @@ internal class INDArrayStructureTest { @Test fun testEquals() { val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct1 = INDArrayDoubleStructure(nd1) + val struct1 = INDArrayRealStructure(nd1) assertEquals(struct1, struct1) - assertNotEquals(struct1, null as INDArrayDoubleStructure?) + assertNotEquals(struct1, null as INDArrayRealStructure?) val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct2 = INDArrayDoubleStructure(nd2) + val struct2 = INDArrayRealStructure(nd2) assertEquals(struct1, struct2) assertEquals(struct2, struct1) val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct3 = INDArrayDoubleStructure(nd3) + val struct3 = INDArrayRealStructure(nd3) assertEquals(struct2, struct3) assertEquals(struct1, struct3) } @@ -41,9 +41,9 @@ internal class INDArrayStructureTest { @Test fun testHashCode() { val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct1 = INDArrayDoubleStructure(nd1) + val struct1 = INDArrayRealStructure(nd1) val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! - val struct2 = INDArrayDoubleStructure(nd2) + val struct2 = INDArrayRealStructure(nd2) assertEquals(struct1.hashCode(), struct2.hashCode()) }