From 332c04b573e638a45401e334be5682f1fbaea03c Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 23 Jan 2021 19:19:13 +0300 Subject: [PATCH] [WIP] Refactor NDStructures --- .../kscience/kmath/operations/ComplexDemo.kt | 8 ++-- .../kscience/kmath/structures/NDField.kt | 13 ++++-- .../kscience/kmath/commons/linear/CMMatrix.kt | 5 +-- .../kscience/kmath/commons/linear/CMSolver.kt | 2 +- .../kscience/kmath/nd/BufferNDAlgebra.kt | 6 +-- .../kotlin/kscience/kmath/nd/RealNDField.kt | 7 ++- .../kotlin/kscience/kmath/nd/Structure2D.kt | 10 ++++- .../kscience/kmath/operations/BigInt.kt | 13 +++--- .../kscience/kmath/structures/Buffers.kt | 2 +- .../kmath/structures/FlaggedBuffer.kt | 2 +- .../kscience/kmath/structures/MemoryBuffer.kt | 2 +- .../kscience/kmath/structures/NDFieldTest.kt | 8 ++-- .../kmath/structures/NumberNDFieldTest.kt | 44 ++++++++++++------- .../kmath/structures/LazyNDStructure.kt | 2 +- .../kscience/kmath/ejml/EjmlMatrixContext.kt | 5 +-- .../kotlin/kscience/kmath/real/RealMatrix.kt | 6 +-- .../kaceince/kmath/real/RealMatrixTest.kt | 4 +- .../kscience/kmath/histogram/RealHistogram.kt | 6 +-- .../kscience.kmath.nd4j/Nd4jArrayAlgebra.kt | 18 +++++--- .../kmath/viktor/ViktorNDStructure.kt | 40 ++++++++++------- 20 files changed, 120 insertions(+), 83 deletions(-) diff --git a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt index f90221582..821618af5 100644 --- a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt +++ b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt @@ -1,17 +1,17 @@ package kscience.kmath.operations -import kscience.kmath.nd.NDField -import kscience.kmath.structures.NDElement +import kscience.kmath.nd.NDAlgebra +import kscience.kmath.nd.complex fun main() { // 2d element - val element = NDElement.complex(2, 2) { (i,j) -> + val element = NDAlgebra.complex(2, 2).produce { (i,j) -> Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble()) } println(element) // 1d element operation - val result = with(NDField.complex(8)) { + val result = with(NDAlgebra.complex(8)) { val a = produce { (it) -> i * it - it.toDouble() } val b = 3 val c = Complex(1.0, 1.0) diff --git a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt index e26b8cbce..3e84c0be0 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt @@ -1,7 +1,10 @@ package kscience.kmath.structures import kotlinx.coroutines.GlobalScope -import kscience.kmath.nd.* +import kscience.kmath.nd.NDAlgebra +import kscience.kmath.nd.NDStructure +import kscience.kmath.nd.field +import kscience.kmath.nd.real import kscience.kmath.nd4j.Nd4jArrayField import kscience.kmath.operations.RealField import kscience.kmath.operations.invoke @@ -39,8 +42,10 @@ fun main() { } measureAndPrint("Element addition") { - var res: NDStructure = boxingField.one - repeat(n) { res += 1.0 } + boxingField { + var res: NDStructure = one + repeat(n) { res += 1.0 } + } } measureAndPrint("Specialized addition") { @@ -52,7 +57,7 @@ fun main() { measureAndPrint("Nd4j specialized addition") { nd4jField { - var res = one + var res:NDStructure = one repeat(n) { res += 1.0 } } } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt index b9fc8b72a..e168b3d7c 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -1,10 +1,7 @@ package kscience.kmath.commons.linear -import kscience.kmath.linear.DiagonalFeature -import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.Point +import kscience.kmath.linear.* import kscience.kmath.misc.UnstableKMathAPI -import kscience.kmath.nd.Matrix import org.apache.commons.math3.linear.* import kotlin.reflect.KClass import kotlin.reflect.cast diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMSolver.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMSolver.kt index aa0bf4e1a..1c0896597 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMSolver.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMSolver.kt @@ -1,7 +1,7 @@ package kscience.kmath.commons.linear +import kscience.kmath.linear.Matrix import kscience.kmath.linear.Point -import kscience.kmath.nd.Matrix import org.apache.commons.math3.linear.* public enum class CMDecomposition { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/BufferNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/BufferNDAlgebra.kt index a92e08c60..cc1764643 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/BufferNDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/BufferNDAlgebra.kt @@ -86,7 +86,7 @@ public fun > NDAlgebra.Companion.space( vararg shape: Int, ): BufferedNDSpace = BufferedNDSpace(shape, space, bufferFactory) -public inline fun , R> A.nd( +public inline fun , R> A.ndSpace( noinline bufferFactory: BufferFactory, vararg shape: Int, action: BufferedNDSpace.() -> R, @@ -102,7 +102,7 @@ public fun > NDAlgebra.Companion.ring( vararg shape: Int, ): BufferedNDRing = BufferedNDRing(shape, ring, bufferFactory) -public inline fun , R> A.nd( +public inline fun , R> A.ndRing( noinline bufferFactory: BufferFactory, vararg shape: Int, action: BufferedNDRing.() -> R, @@ -118,7 +118,7 @@ public fun > NDAlgebra.Companion.field( vararg shape: Int, ): BufferedNDField = BufferedNDField(shape, field, bufferFactory) -public inline fun , R> A.nd( +public inline fun , R> A.ndField( noinline bufferFactory: BufferFactory, vararg shape: Int, action: BufferedNDField.() -> R, diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/RealNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/RealNDField.kt index feb3c509a..7a0d8f8f9 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/RealNDField.kt @@ -94,9 +94,12 @@ public open class RealNDField( /** * Fast element production using function inlining */ -public inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): NDBuffer { +public inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(IntArray) -> Double): NDBuffer { contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) } - val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } + val array = DoubleArray(strides.linearSize) { offset -> + val index = strides.index(offset) + RealField.initializer(index) + } return NDBuffer(strides, RealBuffer(array)) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/Structure2D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/Structure2D.kt index 4c730877c..b092d48e8 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/nd/Structure2D.kt @@ -52,7 +52,15 @@ public interface Structure2D : NDStructure { for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) } - public companion object + public companion object { + public inline fun real( + rows: Int, + columns: Int, + crossinline init: (i: Int, j: Int) -> Double, + ): NDBuffer = NDAlgebra.real(rows, columns).produceInline { (i, j) -> + init(i, j) + } + } } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt index 9599401bc..e0bf817ef 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt @@ -1,10 +1,12 @@ package kscience.kmath.operations import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.nd.BufferedNDRing import kscience.kmath.nd.NDAlgebra import kscience.kmath.operations.BigInt.Companion.BASE import kscience.kmath.operations.BigInt.Companion.BASE_SIZE -import kscience.kmath.structures.* +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.MutableBuffer import kotlin.math.log2 import kotlin.math.max import kotlin.math.min @@ -463,10 +465,5 @@ public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigIn public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer = boxing(size, initializer) -public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing = - BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) - -public fun NDElement.Companion.bigInt( - vararg shape: Int, - initializer: BigIntField.(IntArray) -> BigInt -): BufferedNDRingElement = NDAlgebra.bigInt(*shape).produce(initializer) +public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BufferedNDRing = + BufferedNDRing(shape, BigIntField, Buffer.Companion::bigInt) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt index 4e7c0d1b1..bfec6f871 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt @@ -43,7 +43,7 @@ public interface Buffer { * Checks content equality with another buffer. */ public fun contentEquals(other: Buffer<*>): Boolean = - kscience.kmath.nd.mapIndexed { index, value -> value == other[index] }.all { it } + asSequence().mapIndexed { index, value -> value == other[index] }.all { it } public companion object { /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt index 7828cea10..4965e37cf 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt @@ -60,7 +60,7 @@ public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null - override operator fun iterator(): Iterator = kscience.kmath.nd.map { + override operator fun iterator(): Iterator = values.indices.asSequence().map { if (isValid(it)) values[it] else null }.iterator() } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt index 22331ae81..66c9212cf 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt @@ -15,7 +15,7 @@ public open class MemoryBuffer(protected val memory: Memory, protected private val reader: MemoryReader = memory.reader() override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index) - override operator fun iterator(): Iterator = kscience.kmath.nd.map { get(it) }.iterator() + override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() public companion object { public fun create(spec: MemorySpec, size: Int): MemoryBuffer = diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt index 0f5f99b49..35d49e29d 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt @@ -1,6 +1,8 @@ package kscience.kmath.structures -import kscience.kmath.nd.NDField +import kscience.kmath.nd.NDAlgebra +import kscience.kmath.nd.get +import kscience.kmath.nd.real import kscience.kmath.operations.internal.FieldVerifier import kotlin.test.Test import kotlin.test.assertEquals @@ -8,12 +10,12 @@ import kotlin.test.assertEquals internal class NDFieldTest { @Test fun verify() { - NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } + NDAlgebra.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } } @Test fun testStrides() { - val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } + val ndArray = NDAlgebra.real(10, 10).produce { (it[0] + it[1]).toDouble() } assertEquals(ndArray[5, 5], 10.0) } } diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt index baf3656fc..b90e0f07f 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt @@ -1,9 +1,8 @@ package kscience.kmath.structures -import kscience.kmath.nd.NDField -import kscience.kmath.nd.NDStructure +import kscience.kmath.nd.* import kscience.kmath.operations.Norm -import kscience.kmath.structures.NDElement.Companion.real2D +import kscience.kmath.operations.invoke import kotlin.math.abs import kotlin.math.pow import kotlin.test.Test @@ -11,25 +10,30 @@ import kotlin.test.assertEquals @Suppress("UNUSED_VARIABLE") class NumberNDFieldTest { - val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() } - val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() } + val algebra = NDAlgebra.real(3,3) + val array1 = algebra.produceInline { (i, j) -> (i + j).toDouble() } + val array2 = algebra.produceInline { (i, j) -> (i - j).toDouble() } @Test fun testSum() { - val sum = array1 + array2 - assertEquals(4.0, sum[2, 2]) + algebra { + val sum = array1 + array2 + assertEquals(4.0, sum[2, 2]) + } } @Test fun testProduct() { - val product = array1 * array2 - assertEquals(0.0, product[2, 2]) + algebra { + val product = array1 * array2 + assertEquals(0.0, product[2, 2]) + } } @Test fun testGeneration() { - val array = real2D(3, 3) { i, j -> (i * 10 + j).toDouble() } + val array = Structure2D.real(3, 3) { i, j -> (i * 10 + j).toDouble() } for (i in 0..2) { for (j in 0..2) { @@ -41,16 +45,20 @@ class NumberNDFieldTest { @Test fun testExternalFunction() { - val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 } - val result = function(array1) + 1.0 - assertEquals(10.0, result[1, 1]) + algebra { + val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 } + val result = function(array1) + 1.0 + assertEquals(10.0, result[1, 1]) + } } @Test fun testLibraryFunction() { - val abs: (Double) -> Double = ::abs - val result = abs(array2) - assertEquals(2.0, result[0, 2]) + algebra { + val abs: (Double) -> Double = ::abs + val result = abs(array2) + assertEquals(2.0, result[0, 2]) + } } @Test @@ -65,6 +73,8 @@ class NumberNDFieldTest { @Test fun testInternalContext() { - (NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } } + algebra { + (NDAlgebra.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } } + } } } diff --git a/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt index cac692eb2..3933ef28b 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt @@ -21,7 +21,7 @@ public class LazyNDStructure( public override fun elements(): Sequence> { val strides = DefaultStrides(shape) - val res = runBlocking { kscience.kmath.nd.map { index -> index to await(index) } } + val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } } return res.asSequence() } diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index ebd4dd586..eb126e00a 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -1,10 +1,7 @@ package kscience.kmath.ejml -import kscience.kmath.linear.InverseMatrixFeature -import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.Point +import kscience.kmath.linear.* import kscience.kmath.misc.UnstableKMathAPI -import kscience.kmath.nd.Matrix import kscience.kmath.nd.getFeature import org.ejml.simple.SimpleMatrix diff --git a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt index defaba126..e7bde9980 100644 --- a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt @@ -1,11 +1,11 @@ package kscience.kmath.real -import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.real +import kscience.kmath.linear.* import kscience.kmath.misc.UnstableKMathAPI -import kscience.kmath.nd.Matrix import kscience.kmath.structures.Buffer import kscience.kmath.structures.RealBuffer +import kscience.kmath.structures.asIterable +import kotlin.math.pow /* * Functions for convenient "numpy-like" operations with Double matrices. diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt index 68aa8cf52..309997ae3 100644 --- a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt @@ -1,7 +1,9 @@ package kaceince.kmath.real -import kscience.kmath.nd.Matrix +import kscience.kmath.linear.Matrix +import kscience.kmath.linear.build import kscience.kmath.real.* +import kscience.kmath.structures.contentEquals import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue diff --git a/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/RealHistogram.kt index 42dbaabac..085641106 100644 --- a/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/RealHistogram.kt @@ -17,7 +17,7 @@ public data class BinDef>( require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" } val upper = space { center + sizes / 2.0 } val lower = space { center - sizes / 2.0 } - return kscience.kmath.nd.mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it } + return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it } } } @@ -70,7 +70,7 @@ public class RealHistogram( public fun getValue(point: Buffer): Long = getValue(getIndex(point)) private fun getDef(index: IntArray): BinDef { - val center = kscience.kmath.nd.mapIndexed { axis, i -> + val center = index.mapIndexed { axis, i -> when (i) { 0 -> Double.NEGATIVE_INFINITY strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY @@ -100,7 +100,7 @@ public class RealHistogram( } public override operator fun iterator(): Iterator> = - kscience.kmath.nd.map { (index, value) -> MultivariateBin(getDef(index), value.sum()) } + weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) } .iterator() /** diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt index b958d1ccf..91d45dccd 100644 --- a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt @@ -7,6 +7,13 @@ import kscience.kmath.structures.* import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.factory.Nd4j +internal fun NDAlgebra<*, *>.checkShape(array: INDArray): INDArray { + val arrayShape = array.shape().toIntArray() + if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape) + return array +} + + /** * Represents [NDAlgebra] over [Nd4jArrayAlgebra]. * @@ -18,8 +25,9 @@ public interface Nd4jArrayAlgebra : NDAlgebra { * Wraps [INDArray] to [N]. */ public fun INDArray.wrap(): Nd4jArrayStructure + public val NDStructure.ndArray: INDArray - get() = when { + get() = when { !shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException( this@Nd4jArrayAlgebra.shape, shape @@ -213,7 +221,7 @@ public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArray public override val elementContext: RealField get() = RealField - public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(asRealStructure()) + public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asRealStructure() public override operator fun NDStructure.div(arg: Double): Nd4jArrayStructure { return ndArray.div(arg).wrap() @@ -247,7 +255,7 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra public override val elementContext: FloatField get() = FloatField - public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(asFloatStructure()) + public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asFloatStructure() public override operator fun NDStructure.div(arg: Float): Nd4jArrayStructure { return ndArray.div(arg).wrap() @@ -281,7 +289,7 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi public override val elementContext: IntRing get() = IntRing - public override fun INDArray.wrap(): Nd4jArrayStructure = check(asIntStructure()) + public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asIntStructure() public override operator fun NDStructure.plus(arg: Int): Nd4jArrayStructure { return ndArray.add(arg).wrap() @@ -307,7 +315,7 @@ public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayR public override val elementContext: LongRing get() = LongRing - public override fun INDArray.wrap(): Nd4jArrayStructure = check(asLongStructure()) + public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asLongStructure() public override operator fun NDStructure.plus(arg: Long): Nd4jArrayStructure { return ndArray.add(arg).wrap() diff --git a/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorNDStructure.kt b/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorNDStructure.kt index 5dbc297b4..a6c4f3ce0 100644 --- a/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorNDStructure.kt +++ b/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorNDStructure.kt @@ -1,9 +1,6 @@ package kscience.kmath.viktor -import kscience.kmath.nd.DefaultStrides -import kscience.kmath.nd.MutableNDStructure -import kscience.kmath.nd.NDField -import kscience.kmath.nd.Strides +import kscience.kmath.nd.* import kscience.kmath.operations.RealField import org.jetbrains.bio.viktor.F64Array @@ -24,14 +21,25 @@ public inline class ViktorNDStructure(public val f64Buffer: F64Array) : MutableN public fun F64Array.asStructure(): ViktorNDStructure = ViktorNDStructure(this) @Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public class ViktorNDField(public override val shape: IntArray) : NDField { +public class ViktorNDField(public override val shape: IntArray) : NDField { + + public val NDStructure.f64Buffer: F64Array + get() = when { + !shape.contentEquals(this@ViktorNDField.shape) -> throw ShapeMismatchException( + this@ViktorNDField.shape, + shape + ) + this is ViktorNDStructure && this.f64Buffer.shape.contentEquals(this@ViktorNDField.shape) -> this.f64Buffer + else -> produce { this@f64Buffer[it] }.f64Buffer + } + public override val zero: ViktorNDStructure get() = F64Array.full(init = 0.0, shape = shape).asStructure() public override val one: ViktorNDStructure get() = F64Array.full(init = 1.0, shape = shape).asStructure() - public val strides: Strides = DefaultStrides(shape) + private val strides: Strides = DefaultStrides(shape) public override val elementContext: RealField get() = RealField @@ -42,7 +50,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField Double): ViktorNDStructure = + public override fun map(arg: NDStructure, transform: RealField.(Double) -> Double): ViktorNDStructure = F64Array(*shape).apply { this@ViktorNDField.strides.indices().forEach { index -> set(value = RealField.transform(arg[index]), indices = index) @@ -50,7 +58,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField, transform: RealField.(index: IntArray, Double) -> Double ): ViktorNDStructure = F64Array(*shape).apply { this@ViktorNDField.strides.indices().forEach { index -> @@ -59,8 +67,8 @@ public class ViktorNDField(public override val shape: IntArray) : NDField, + b: NDStructure, transform: RealField.(Double, Double) -> Double ): ViktorNDStructure = F64Array(*shape).apply { this@ViktorNDField.strides.indices().forEach { index -> @@ -68,21 +76,21 @@ public class ViktorNDField(public override val shape: IntArray) : NDField, b: NDStructure): ViktorNDStructure = (a.f64Buffer + b.f64Buffer).asStructure() - public override fun multiply(a: ViktorNDStructure, k: Number): ViktorNDStructure = + public override fun multiply(a: NDStructure, k: Number): ViktorNDStructure = (a.f64Buffer * k.toDouble()).asStructure() - public override inline fun ViktorNDStructure.plus(b: ViktorNDStructure): ViktorNDStructure = + public override inline fun NDStructure.plus(b: NDStructure): ViktorNDStructure = (f64Buffer + b.f64Buffer).asStructure() - public override inline fun ViktorNDStructure.minus(b: ViktorNDStructure): ViktorNDStructure = + public override inline fun NDStructure.minus(b: NDStructure): ViktorNDStructure = (f64Buffer - b.f64Buffer).asStructure() - public override inline fun ViktorNDStructure.times(k: Number): ViktorNDStructure = + public override inline fun NDStructure.times(k: Number): ViktorNDStructure = (f64Buffer * k.toDouble()).asStructure() - public override inline fun ViktorNDStructure.plus(arg: Double): ViktorNDStructure = + public override inline fun NDStructure.plus(arg: Double): ViktorNDStructure = (f64Buffer.plus(arg)).asStructure() } \ No newline at end of file