From 7a507cd281c1aba261a56776c08f102819901336 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 29 Apr 2018 09:15:30 +0300 Subject: [PATCH] Fixed multiple problems with NDArray. Added test for summation --- .../scientifik/kmath/operations/Algebra.kt | 12 ++--- .../scientifik/kmath/structures/NDArray.kt | 14 +++++- .../kmath/structures/RealNDArray.kt | 44 +++++++++++++------ .../kmath/structures/RealNDFieldTest.kt | 15 +++++++ 4 files changed, 63 insertions(+), 22 deletions(-) create mode 100644 jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt diff --git a/common/src/main/kotlin/scientifik/kmath/operations/Algebra.kt b/common/src/main/kotlin/scientifik/kmath/operations/Algebra.kt index 4b74cbc7a..0a1d6f221 100644 --- a/common/src/main/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/common/src/main/kotlin/scientifik/kmath/operations/Algebra.kt @@ -50,10 +50,10 @@ interface SpaceElement> { */ val self: S - operator fun plus(b: S): S = with(context) { self + b } - operator fun minus(b: S): S = with(context) { self - b } - operator fun times(k: Number): S = with(context) { self * k } - operator fun div(k: Number): S = with(context) { self / k } + operator fun plus(b: S): S = context.add(self, b) + operator fun minus(b: S): S = context.add(self, context.multiply(b, -1.0)) + operator fun times(k: Number): S = context.multiply(self, k.toDouble()) + operator fun div(k: Number): S = context.multiply(self, 1.0 / k.toDouble()) } /** @@ -80,7 +80,7 @@ interface Ring : Space { interface RingElement> : SpaceElement { override val context: Ring - operator fun times(b: S): S = with(context) { self * b } + operator fun times(b: S): S = context.multiply(self, b) } /** @@ -99,5 +99,5 @@ interface Field : Ring { interface FieldElement> : RingElement { override val context: Field - operator fun div(b: S): S = with(context) { self / b } + operator fun div(b: S): S = context.divide(self, b) } \ No newline at end of file diff --git a/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt b/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt index 4ab01da56..5a4e14a36 100644 --- a/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt +++ b/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt @@ -105,7 +105,7 @@ interface NDArray> : FieldElement>, Iterable> = iterateIndexes(tailShape).toList() (0 until shape[0]).asSequence().map { firstIndex -> //adding first element to each of provided index lists @@ -116,5 +116,15 @@ interface NDArray> : FieldElement>, Iterable, initializer: (List) -> Double = { 0.0 }): NDArray -expect fun RealNDArray(shape: List, initializer: (List) -> Double): NDArray \ No newline at end of file +fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray { + return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) } +} + +fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray { + return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } +} \ No newline at end of file diff --git a/jvm/src/main/kotlin/scientifik/kmath/structures/RealNDArray.kt b/jvm/src/main/kotlin/scientifik/kmath/structures/RealNDArray.kt index 3d00b79d5..c25dcde07 100644 --- a/jvm/src/main/kotlin/scientifik/kmath/structures/RealNDArray.kt +++ b/jvm/src/main/kotlin/scientifik/kmath/structures/RealNDArray.kt @@ -1,6 +1,5 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.Field import scientifik.kmath.operations.Real import scientifik.kmath.operations.RealField import java.nio.DoubleBuffer @@ -13,8 +12,9 @@ private class RealNDField(shape: List) : NDField(shape, RealField) { private val strides: List by lazy { ArrayList(shape.size).apply { var current = 1 - shape.forEach{ - current *=it + add(0) + shape.forEach { + current *= it add(current) } } @@ -30,7 +30,7 @@ private class RealNDField(shape: List) : NDField(shape, RealField) { } val capacity: Int - get() = strides[shape.size - 1] + get() = strides[shape.size] override fun produce(initializer: (List) -> Real): NDArray { @@ -39,26 +39,42 @@ private class RealNDField(shape: List) : NDField(shape, RealField) { NDArray.iterateIndexes(shape).forEach { buffer.put(offset(it), initializer(it).value) } - return RealNDArray(buffer) + return RealNDArray(this, buffer) } - inner class RealNDArray(val data: DoubleBuffer) : NDArray { - - override val context: Field> - get() = this@RealNDField + class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray { override fun get(vararg index: Int): Real { - return Real(data.get(offset(index.asList()))) + return Real(data.get(context.offset(index.asList()))) } - override val self: NDArray - get() = this - } + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + other as RealNDArray + + if (context.shape != other.context.shape) return false + if (data != other.data) return false + + return true + } + + override fun hashCode(): Int { + var result = context.shape.hashCode() + result = 31 * result + data.hashCode() + return result + } + + //TODO generate fixed hash code for quick comparison? + + + override val self: NDArray = this + } } -actual fun RealNDArray(shape: List, initializer: (List) -> Double): NDArray { +actual fun realNDArray(shape: List, initializer: (List) -> Double): NDArray { //TODO cache fields? return RealNDField(shape).produce { Real(initializer(it)) } } \ No newline at end of file diff --git a/jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt b/jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt new file mode 100644 index 000000000..f65642f63 --- /dev/null +++ b/jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.structures + +import org.junit.Assert.assertEquals +import kotlin.test.Test + +class RealNDFieldTest { + val array1 = real2DArray(3, 3) { i, j -> (i + j).toDouble() } + val array2 = real2DArray(3, 3) { i, j -> (i - j).toDouble() } + + @Test + fun testSum() { + val sum = array1 + array2 + assertEquals(4.0, sum[2, 2].value, 0.1) + } +} \ No newline at end of file