From 600d8a64b89b2aa37b44232f755151787039f668 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 4 Jan 2019 20:23:32 +0300 Subject: [PATCH] Fixes to tests and lazy structures --- benchmarks/build.gradle | 3 +- ...reBenchmark.kt => NDStructureBenchmark.kt} | 26 ++++- .../kmath/histogram/PhantomHistogram.kt | 2 +- .../kmath/linear/LUDecomposition.kt | 8 +- .../scientifik/kmath/linear/LinearAlgrebra.kt | 6 +- .../kotlin/scientifik/kmath/linear/Matrix.kt | 6 +- .../kotlin/scientifik/kmath/linear/Vector.kt | 12 +- .../scientifik/kmath/operations/Fields.kt | 6 +- .../kmath/structures/BufferNDField.kt | 42 +++---- .../kmath/structures/ExtendedNDField.kt | 4 +- .../scientifik/kmath/structures/NDElement.kt | 4 +- .../scientifik/kmath/structures/NDField.kt | 27 +++-- .../kmath/structures/RealNDField.kt | 60 +++++----- .../kmath/expressions/ExpressionFieldTest.kt | 8 +- .../kmath/operations/RealFieldTest.kt | 2 +- .../kmath/structures/LazyNDField.kt | 105 +++++++++++------- .../kmath/structures/LazyNDFieldTest.kt | 2 +- 17 files changed, 185 insertions(+), 138 deletions(-) rename benchmarks/src/main/kotlin/scientifik/kmath/structures/{BufferNDStructureBenchmark.kt => NDStructureBenchmark.kt} (68%) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index 1876d86d7..23f967102 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -5,6 +5,7 @@ plugins { } dependencies { - compile project(':kmath-core') + compile project(":kmath-core") + compile project(":kmath-coroutines") //jmh project(':kmath-core') } \ No newline at end of file diff --git a/benchmarks/src/main/kotlin/scientifik/kmath/structures/BufferNDStructureBenchmark.kt b/benchmarks/src/main/kotlin/scientifik/kmath/structures/NDStructureBenchmark.kt similarity index 68% rename from benchmarks/src/main/kotlin/scientifik/kmath/structures/BufferNDStructureBenchmark.kt rename to benchmarks/src/main/kotlin/scientifik/kmath/structures/NDStructureBenchmark.kt index 99111f6b3..77f0a2737 100644 --- a/benchmarks/src/main/kotlin/scientifik/kmath/structures/BufferNDStructureBenchmark.kt +++ b/benchmarks/src/main/kotlin/scientifik/kmath/structures/NDStructureBenchmark.kt @@ -1,15 +1,16 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.DoubleField +import scientifik.kmath.operations.RealField import kotlin.system.measureTimeMillis fun main(args: Array) { val dim = 1000 - val n = 10000 + val n = 100 - val bufferedField = NDField.buffered(intArrayOf(dim, dim), DoubleField) + val bufferedField = NDField.buffered(intArrayOf(dim, dim), RealField) val specializedField = NDField.real(intArrayOf(dim, dim)) - val genericField = NDField.generic(intArrayOf(dim, dim), DoubleField) + val genericField = NDField.generic(intArrayOf(dim, dim), RealField) + val lazyNDField = NDField.lazy(intArrayOf(dim, dim), RealField) // val action: NDField>.() -> Unit = { // var res = one @@ -55,6 +56,23 @@ fun main(args: Array) { println("Specialized addition completed in $specializedTime millis") + val lazyTime = measureTimeMillis { + val tr : RealField.(Double)->Double = {arg-> + var r = arg + repeat(n) { + r += 1.0 + } + r + } + lazyNDField.run { + val res = one.map(tr) + + res.elements().sumByDouble { it.second } + } + } + + println("Lazy addition completed in $lazyTime millis") + val genericTime = measureTimeMillis { //genericField.run(action) genericField.run { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt index 1d04a8c19..75d9ebd74 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt @@ -42,7 +42,7 @@ class PhantomBin>(val template: BinTemplate, override val v /** * Immutable histogram with explicit structure for content and additional external bin description. * Bin search is slow, but full histogram algebra is supported. - * @param bins map a template into structure index + * @param bins transform a template into structure index */ class PhantomHistogram>( val bins: Map, IntArray>, diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt index 3db4809be..5618b3100 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt @@ -1,7 +1,7 @@ package scientifik.kmath.linear -import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.Field +import scientifik.kmath.operations.RealField import scientifik.kmath.structures.* import kotlin.math.absoluteValue @@ -184,7 +184,7 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr } class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold: Double = DEFAULT_TOO_SMALL) : - LUDecomposition(matrix) { + LUDecomposition(matrix) { override fun isSingular(value: Double): Boolean { return value.absoluteValue < singularityThreshold } @@ -197,9 +197,9 @@ class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold: /** Specialized solver. */ -object RealLUSolver : LinearSolver { +object RealLUSolver : LinearSolver { - fun decompose(mat: Matrix, threshold: Double = 1e-11): RealLUDecomposition = + fun decompose(mat: Matrix, threshold: Double = 1e-11): RealLUDecomposition = RealLUDecomposition(mat, threshold) override fun solve(a: RealMatrix, b: RealMatrix): RealMatrix { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt index c2e54bf3b..11f1f7dd8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt @@ -1,8 +1,8 @@ package scientifik.kmath.linear -import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.Field import scientifik.kmath.operations.Norm +import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring import scientifik.kmath.structures.asSequence import scientifik.kmath.structures.boxingBuffer @@ -61,5 +61,5 @@ object VectorL2Norm : Norm, Double> { } } -typealias RealVector = Vector -typealias RealMatrix = Matrix +typealias RealVector = Vector +typealias RealMatrix = Matrix diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index e9a45f90a..63c78973a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -1,6 +1,6 @@ package scientifik.kmath.linear -import scientifik.kmath.operations.DoubleField +import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space import scientifik.kmath.operations.SpaceElement @@ -42,8 +42,8 @@ interface MatrixSpace> : Space> { /** * Non-boxing double matrix */ - fun real(rows: Int, columns: Int): MatrixSpace = - StructureMatrixSpace(rows, columns, DoubleField, DoubleBufferFactory) + fun real(rows: Int, columns: Int): MatrixSpace = + StructureMatrixSpace(rows, columns, RealField, DoubleBufferFactory) /** * A structured matrix with custom buffer diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt index 818a65733..9aaa28203 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt @@ -1,6 +1,6 @@ package scientifik.kmath.linear -import scientifik.kmath.operations.DoubleField +import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Space import scientifik.kmath.operations.SpaceElement import scientifik.kmath.structures.* @@ -34,13 +34,13 @@ interface VectorSpace> : Space> { companion object { - private val realSpaceCache = HashMap>() + private val realSpaceCache = HashMap>() /** * Non-boxing double vector space */ - fun real(size: Int): BufferVectorSpace { - return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) } + fun real(size: Int): BufferVectorSpace { + return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, RealField, DoubleBufferFactory) } } /** @@ -79,10 +79,10 @@ interface Vector> : SpaceElement, Vector, V fun > generic(size: Int, field: S, initializer: (Int) -> T): Vector = VectorSpace.buffered(size, field).produceElement(initializer) - fun real(size: Int, initializer: (Int) -> Double): Vector = + fun real(size: Int, initializer: (Int) -> Double): Vector = VectorSpace.real(size).produceElement(initializer) - fun ofReal(vararg elements: Double): Vector = + fun ofReal(vararg elements: Double): Vector = VectorSpace.real(elements.size).produceElement { elements[it] } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt index 97ff3c6f5..fd8bcef11 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt @@ -17,12 +17,12 @@ interface ExtendedField : * * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 */ -inline class Real(val value: Double) : FieldElement { +inline class Real(val value: Double) : FieldElement { override fun unwrap(): Double = value override fun Double.wrap(): Real = Real(value) - override val context get() = DoubleField + override val context get() = RealField companion object { @@ -32,7 +32,7 @@ inline class Real(val value: Double) : FieldElement { /** * A field for double without boxing. Does not produce appropriate field element */ -object DoubleField : AbstractField(),ExtendedField, Norm { +object RealField : AbstractField(),ExtendedField, Norm { override val zero: Double = 0.0 override fun add(a: Double, b: Double): Double = a + b override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt index 59fba75df..ec9268b21 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt @@ -5,18 +5,19 @@ import scientifik.kmath.operations.FieldElement abstract class StridedNDField>(shape: IntArray, elementField: F) : AbstractNDField>(shape, elementField) { - - abstract val bufferFactory: BufferFactory val strides = DefaultStrides(shape) + + abstract fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer } class BufferNDField>( shape: IntArray, elementField: F, - override val bufferFactory: BufferFactory -) : - StridedNDField(shape, elementField) { + val bufferFactory: BufferFactory +) : StridedNDField(shape, elementField) { + + override fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) override fun check(vararg elements: NDBuffer) { if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") @@ -32,22 +33,25 @@ class BufferNDField>( bufferFactory(strides.linearSize) { offset -> elementField.initializer(strides.index(offset)) }) @Suppress("OVERRIDE_BY_INLINE") - override inline fun NDBuffer.map(crossinline transform: F.(T) -> T): BufferNDElement { - check(this) + override inline fun map(arg: NDBuffer, crossinline transform: F.(T) -> T): BufferNDElement { + check(arg) return BufferNDElement( - this@BufferNDField, - bufferFactory(strides.linearSize) { offset -> elementField.transform(buffer[offset]) }) + this, + bufferFactory(arg.strides.linearSize) { offset -> elementField.transform(arg.buffer[offset]) }) } @Suppress("OVERRIDE_BY_INLINE") - override inline fun NDBuffer.mapIndexed(crossinline transform: F.(index: IntArray, T) -> T): BufferNDElement { - check(this) + override inline fun mapIndexed( + arg: NDBuffer, + crossinline transform: F.(index: IntArray, T) -> T + ): BufferNDElement { + check(arg) return BufferNDElement( - this@BufferNDField, - bufferFactory(strides.linearSize) { offset -> + this, + bufferFactory(arg.strides.linearSize) { offset -> elementField.transform( - strides.index(offset), - buffer[offset] + arg.strides.index(offset), + arg.buffer[offset] ) }) } @@ -105,11 +109,11 @@ class BufferNDElement>(override val context: StridedNDField> = strides.indices().map { it to get(it) } - override fun map(action: F.(T) -> T): BufferNDElement = - context.run { map(action) } + override fun map(action: F.(T) -> T) = + context.run { map(this@BufferNDElement, action) }.wrap() - override fun mapIndexed(transform: F.(index: IntArray, T) -> T): BufferNDElement = - context.run { mapIndexed(transform) } + override fun mapIndexed(transform: F.(index: IntArray, T) -> T) = + context.run { mapIndexed(this@BufferNDElement, transform) }.wrap() } /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 50ecd3a57..7648efef8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -6,7 +6,7 @@ import scientifik.kmath.operations.PowerOperations import scientifik.kmath.operations.TrigonometricOperations -interface ExtendedNDField, N : NDStructure> : +interface ExtendedNDField, N : NDStructure> : NDField, TrigonometricOperations, PowerOperations, @@ -16,7 +16,7 @@ interface ExtendedNDField, N : NDStructure> /** * NDField that supports [ExtendedField] operations on its elements */ -class ExtendedNDFieldWrapper, N : NDStructure>(private val ndField: NDField) : +class ExtendedNDFieldWrapper, N : NDStructure>(private val ndField: NDField) : ExtendedNDField, NDField by ndField { override val shape: IntArray get() = ndField.shape diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt index 223e45b9d..4bed26510 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt @@ -1,8 +1,8 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.Field import scientifik.kmath.operations.FieldElement +import scientifik.kmath.operations.RealField interface NDElement> : NDStructure { @@ -17,7 +17,7 @@ object NDElements { /** * Create a optimized NDArray of doubles */ - fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }) = + fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }) = NDField.real(shape).produce(initializer) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt index 5cda16b25..dde406c51 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt @@ -23,9 +23,9 @@ interface NDField, N : NDStructure> : Field { fun produce(initializer: F.(IntArray) -> T): N - fun N.map(transform: F.(T) -> T): N + fun map(arg: N, transform: F.(T) -> T): N - fun N.mapIndexed(transform: F.(index: IntArray, T) -> T): N + fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N fun combine(a: N, b: N, transform: F.(T, T) -> T): N @@ -87,11 +87,11 @@ abstract class AbstractNDField, N : NDStructure>( override val one: N by lazy { produce { one } } - final override operator fun Function1.invoke(structure: N) = structure.map { value -> this@invoke(value) } - final override operator fun N.plus(arg: T) = this.map { value -> elementField.run { arg + value } } - final override operator fun N.minus(arg: T) = this.map { value -> elementField.run { arg - value } } - final override operator fun N.times(arg: T) = this.map { value -> elementField.run { arg * value } } - final override operator fun N.div(arg: T) = this.map { value -> elementField.run { arg / value } } + final override operator fun Function1.invoke(structure: N) = map(structure) { value -> this@invoke(value) } + final override operator fun N.plus(arg: T) = map(this) { value -> elementField.run { arg + value } } + final override operator fun N.minus(arg: T) = map(this) { value -> elementField.run { arg - value } } + final override operator fun N.times(arg: T) = map(this) { value -> elementField.run { arg * value } } + final override operator fun N.div(arg: T) = map(this) { value -> elementField.run { arg / value } } final override operator fun T.plus(arg: N) = arg + this final override operator fun T.minus(arg: N) = arg - this @@ -109,7 +109,7 @@ abstract class AbstractNDField, N : NDStructure>( * Multiply all elements by cinstant */ override fun multiply(a: N, k: Double): N = - a.map { it * k } + map(a) { it * k } /** @@ -140,17 +140,16 @@ class GenericNDField>( shape: IntArray, elementField: F, val bufferFactory: BufferFactory = ::boxingBuffer -) : - AbstractNDField>(shape, elementField) { +) : AbstractNDField>(shape, elementField) { override fun produce(initializer: F.(IntArray) -> T): NDStructure = ndStructure(shape, bufferFactory) { elementField.initializer(it) } - override fun NDStructure.map(transform: F.(T) -> T): NDStructure = - produce { index -> transform(get(index)) } + override fun map(arg: NDStructure, transform: F.(T) -> T): NDStructure = + produce { index -> transform(arg.get(index)) } - override fun NDStructure.mapIndexed(transform: F.(index: IntArray, T) -> T): NDStructure = - produce { index -> transform(index, get(index)) } + override fun mapIndexed(arg: NDStructure, transform: F.(index: IntArray, T) -> T): NDStructure = + produce { index -> transform(index, arg.get(index)) } override fun combine(a: NDStructure, b: NDStructure, transform: F.(T, T) -> T): NDStructure = produce { index -> transform(a[index], b[index]) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 344154ceb..5bd58eb6e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -1,41 +1,47 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.DoubleField +import scientifik.kmath.operations.RealField -typealias RealNDElement = BufferNDElement +typealias RealNDElement = BufferNDElement class RealNDField(shape: IntArray) : - StridedNDField(shape, DoubleField), - ExtendedNDField> { + StridedNDField(shape, RealField), + ExtendedNDField> { - override val bufferFactory: BufferFactory - get() = DoubleBufferFactory + override fun buildBuffer(size: Int, initializer: (Int) -> Double): Buffer = + DoubleBuffer(DoubleArray(size, initializer)) /** - * Inline map an NDStructure to + * Inline transform an NDStructure to */ @Suppress("OVERRIDE_BY_INLINE") - override inline fun NDBuffer.map(crossinline transform: DoubleField.(Double) -> Double): RealNDElement { - check(this) - val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) } - return BufferNDElement(this@RealNDField, DoubleBuffer(array)) + override inline fun map( + arg: NDBuffer, + crossinline transform: RealField.(Double) -> Double + ): RealNDElement { + check(arg) + val array = DoubleArray(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } + return BufferNDElement(this, DoubleBuffer(array)) } @Suppress("OVERRIDE_BY_INLINE") - override inline fun produce(crossinline initializer: DoubleField.(IntArray) -> Double): RealNDElement { + override inline fun produce(crossinline initializer: RealField.(IntArray) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> elementField.initializer(strides.index(offset)) } return BufferNDElement(this, DoubleBuffer(array)) } @Suppress("OVERRIDE_BY_INLINE") - override inline fun NDBuffer.mapIndexed(crossinline transform: DoubleField.(index: IntArray, Double) -> Double): BufferNDElement { - check(this) + override inline fun mapIndexed( + arg: NDBuffer, + crossinline transform: RealField.(index: IntArray, Double) -> Double + ): BufferNDElement { + check(arg) return BufferNDElement( - this@RealNDField, - bufferFactory(strides.linearSize) { offset -> + this, + buildBuffer(arg.strides.linearSize) { offset -> elementField.transform( - strides.index(offset), - buffer[offset] + arg.strides.index(offset), + arg.buffer[offset] ) }) } @@ -44,23 +50,23 @@ class RealNDField(shape: IntArray) : override inline fun combine( a: NDBuffer, b: NDBuffer, - crossinline transform: DoubleField.(Double, Double) -> Double - ): BufferNDElement { + crossinline transform: RealField.(Double, Double) -> Double + ): BufferNDElement { check(a, b) return BufferNDElement( this, - bufferFactory(strides.linearSize) { offset -> elementField.transform(a.buffer[offset], b.buffer[offset]) }) + buildBuffer(strides.linearSize) { offset -> elementField.transform(a.buffer[offset], b.buffer[offset]) }) } - override fun power(arg: NDBuffer, pow: Double) = arg.map { power(it, pow) } + override fun power(arg: NDBuffer, pow: Double) = map(arg) { power(it, pow) } - override fun exp(arg: NDBuffer) = arg.map { exp(it) } + override fun exp(arg: NDBuffer) = map(arg) { exp(it) } - override fun ln(arg: NDBuffer) = arg.map { ln(it) } + override fun ln(arg: NDBuffer) = map(arg) { ln(it) } - override fun sin(arg: NDBuffer) = arg.map { sin(it) } + override fun sin(arg: NDBuffer) = map(arg) { sin(it) } - override fun cos(arg: NDBuffer) = arg.map { cos(it) } + override fun cos(arg: NDBuffer) = map(arg) { cos(it) } // // override fun NDBuffer.times(k: Number) = mapInline { value -> value * k.toDouble() } // @@ -75,7 +81,7 @@ class RealNDField(shape: IntArray) : /** * Fast element production using function inlining */ -inline fun StridedNDField.produceInline(crossinline initializer: DoubleField.(Int) -> Double): RealNDElement { +inline fun StridedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> elementField.initializer(offset) } return BufferNDElement(this, DoubleBuffer(array)) } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index 9266b9394..5c5002e06 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -2,14 +2,14 @@ package scientifik.kmath.expressions import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField -import scientifik.kmath.operations.DoubleField +import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals class ExpressionFieldTest { @Test fun testExpression() { - val context = ExpressionField(DoubleField) + val context = ExpressionField(RealField) val expression = with(context) { val x = variable("x", 2.0) x * x + 2 * x + 1.0 @@ -36,7 +36,7 @@ class ExpressionFieldTest { return x * x + 2 * x + one } - val expression = ExpressionField(DoubleField).expression() + val expression = ExpressionField(RealField).expression() assertEquals(expression("x" to 1.0), 4.0) } @@ -47,7 +47,7 @@ class ExpressionFieldTest { x * x + 2 * x + 1.0 } - val expression = ExpressionField(DoubleField).expressionBuilder() + val expression = ExpressionField(RealField).expressionBuilder() assertEquals(expression("x" to 1.0), 4.0) } } \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt index df474c53b..aedbf6655 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt @@ -6,7 +6,7 @@ import kotlin.test.assertEquals class RealFieldTest { @Test fun testSqrt() { - val sqrt = with(DoubleField) { + val sqrt = with(RealField) { sqrt(25 * one) } assertEquals(5.0, sqrt) diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt index f1752b483..edda12888 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt @@ -2,54 +2,80 @@ package scientifik.kmath.structures import kotlinx.coroutines.* import scientifik.kmath.operations.Field +import scientifik.kmath.operations.FieldElement class LazyNDField>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) : - BufferNDField(shape, field, ::boxingBuffer) { + AbstractNDField>(shape, field) { - override fun add(a: NDStructure, b: NDStructure): NDElements { - return LazyNDStructure(this) { index -> - val aDeferred = a.deferred(index) - val bDeferred = b.deferred(index) - aDeferred.await() + bDeferred.await() + override val zero by lazy { produce { zero } } + + override val one by lazy { produce { one } } + + override fun produce(initializer: F.(IntArray) -> T) = + LazyNDStructure(this) { elementField.initializer(it) } + + override fun mapIndexed( + arg: NDStructure, + transform: F.(index: IntArray, T) -> T + ): LazyNDStructure { + check(arg) + return if (arg is LazyNDStructure) { + LazyNDStructure(this) { index -> + this.elementField.transform(index, arg.function(index)) + } + } else { + LazyNDStructure(this) { elementField.transform(it, arg.await(it)) } } } - override fun multiply(a: NDStructure, k: Double): NDElements { - return LazyNDStructure(this) { index -> a.await(index) * k } - } + override fun map(arg: NDStructure, transform: F.(T) -> T) = + mapIndexed(arg) { _, t -> transform(t) } - override fun multiply(a: NDStructure, b: NDStructure): NDElements { - return LazyNDStructure(this) { index -> - val aDeferred = a.deferred(index) - val bDeferred = b.deferred(index) - aDeferred.await() * bDeferred.await() + override fun combine(a: NDStructure, b: NDStructure, transform: F.(T, T) -> T): LazyNDStructure { + check(a, b) + return if (a is LazyNDStructure && b is LazyNDStructure) { + LazyNDStructure(this@LazyNDField) { index -> + elementField.transform( + a.function(index), + b.function(index) + ) + } + } else { + LazyNDStructure(this@LazyNDField) { elementField.transform(a.await(it), b.await(it)) } } } - override fun divide(a: NDStructure, b: NDStructure): NDElements { - return LazyNDStructure(this) { index -> - val aDeferred = a.deferred(index) - val bDeferred = b.deferred(index) - aDeferred.await() / bDeferred.await() + fun NDStructure.lazy(): LazyNDStructure { + check(this) + return if (this is LazyNDStructure) { + LazyNDStructure(this@LazyNDField, function) + } else { + LazyNDStructure(this@LazyNDField) { get(it) } } } } class LazyNDStructure>( override val context: LazyNDField, - val function: suspend F.(IntArray) -> T -) : NDElements, NDStructure { - override val self: NDElements get() = this + val function: suspend (IntArray) -> T +) : FieldElement, LazyNDStructure, LazyNDField>, NDElement { + + + override fun unwrap(): NDStructure = this + + override fun NDStructure.wrap(): LazyNDStructure = LazyNDStructure(context) { await(it) } + override val shape: IntArray get() = context.shape + override val elementField: F get() = context.elementField + + override fun mapIndexed(transform: F.(index: IntArray, T) -> T): NDElement = + context.run { mapIndexed(this@LazyNDStructure, transform) } private val cache = HashMap>() fun deferred(index: IntArray) = cache.getOrPut(index) { context.scope.async(context = Dispatchers.Math) { - function.invoke( - context.elementField, - index - ) + function(index) } } @@ -61,7 +87,10 @@ class LazyNDStructure>( override fun elements(): Sequence> { val strides = DefaultStrides(shape) - return strides.indices().map { index -> index to runBlocking { await(index) } } + val res = runBlocking { + strides.indices().toList().map { index -> index to await(index) } + } + return res.asSequence() } } @@ -71,21 +100,11 @@ fun NDStructure.deferred(index: IntArray) = suspend fun NDStructure.await(index: IntArray) = if (this is LazyNDStructure) this.await(index) else get(index) -fun > NDElements.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure { - return if (this is LazyNDStructure) { - this - } else { - val context = LazyNDField(context.shape, context.field, scope) - LazyNDStructure(context) { get(it) } - } -} -inline fun > LazyNDStructure.mapIndexed(crossinline action: suspend F.(IntArray, T) -> T) = - LazyNDStructure(context) { index -> - action.invoke(this, index, await(index)) - } +fun > NDField.Companion.lazy(shape: IntArray, field: F, scope: CoroutineScope = GlobalScope) = + LazyNDField(shape, field, scope) -inline fun > LazyNDStructure.map(crossinline action: suspend F.(T) -> T) = - LazyNDStructure(context) { index -> - action.invoke(this, await(index)) - } \ No newline at end of file +fun > NDStructure.lazy(field: F, scope: CoroutineScope = GlobalScope): LazyNDStructure { + val context: LazyNDField = LazyNDField(shape, field, scope) + return LazyNDStructure(context) { get(it) } +} \ No newline at end of file diff --git a/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt b/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt index 1eee26f29..51e1db5ae 100644 --- a/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt +++ b/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt @@ -10,7 +10,7 @@ class LazyNDFieldTest { fun testLazyStructure() { var counter = 0 val regularStructure = NDField.generic(intArrayOf(2, 2, 2), IntField).produce { it[0] + it[1] - it[2] } - val result = (regularStructure.lazy() + 2).transform { + val result = (regularStructure.lazy(IntField) + 2).map { counter++ it * it }