From 53d752dee8267f0b11b1ede2c1807c8634ffab15 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 3 Jan 2019 20:48:10 +0300 Subject: [PATCH] Major refactoring of types for algebra elements. Element type could differ from field type. --- .../structures/StructureWriteBenchmark.kt | 6 +- .../kmath/linear/LUDecomposition.kt | 2 +- .../kotlin/scientifik/kmath/linear/Matrix.kt | 4 +- .../kotlin/scientifik/kmath/linear/Vector.kt | 28 ++- .../scientifik/kmath/operations/Algebra.kt | 48 +---- .../kmath/operations/AlgebraElements.kt | 49 +++++ .../scientifik/kmath/operations/Complex.kt | 12 +- .../scientifik/kmath/operations/Fields.kt | 58 +++--- .../kmath/structures/BufferNDField.kt | 90 ++++---- .../kmath/structures/ExtendedNDField.kt | 24 +-- .../scientifik/kmath/structures/NDElement.kt | 126 +++++++++++ .../scientifik/kmath/structures/NDField.kt | 197 ++++-------------- .../kmath/structures/NDStructure.kt | 36 ++-- .../kmath/structures/RealNDField.kt | 33 ++- .../kmath/structures/NDFieldTest.kt | 2 +- .../kmath/structures/NumberNDFieldTest.kt | 1 - .../kmath/structures/LazyNDField.kt | 14 +- 17 files changed, 378 insertions(+), 352 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt diff --git a/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt index 538993d8d..4402264f8 100644 --- a/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt @@ -8,12 +8,12 @@ fun main(args: Array) { val n = 6000 - val structure = NdStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 } + val structure = ndStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 } - structure.map { it + 1 } // warm-up + structure.mapToBuffer { it + 1 } // warm-up val time1 = measureTimeMillis { - val res = structure.map { it + 1 } + val res = structure.mapToBuffer { it + 1 } } println("Structure mapping finished in $time1 millis") diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt index 3f14787f2..311e590b6 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt @@ -104,7 +104,7 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr val m = matrix.numCols val pivot = IntArray(matrix.numRows) //TODO fix performance - val lu: MutableNDStructure = MutableNdStructure(intArrayOf(matrix.numRows, matrix.numCols), ::boxingMutableBuffer) { index: IntArray -> matrix[index[0], index[1]] } + val lu: MutableNDStructure = mutableNdStructure(intArrayOf(matrix.numRows, matrix.numCols), ::boxingMutableBuffer) { index: IntArray -> matrix[index[0], index[1]] } with(matrix.context.ring) { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 13a54cb5d..f5269c74b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -122,11 +122,11 @@ data class StructureMatrixSpace>( override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix { return if (rows == rowNum && columns == colNum) { - val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) } + val structure = ndStructure(strides, bufferFactory) { initializer(it[0], it[1]) } StructureMatrix(this, structure) } else { val context = StructureMatrixSpace(rows, columns, ring, bufferFactory) - val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) } + val structure = ndStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) } StructureMatrix(context, structure) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt index 756f85959..12758d887 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt @@ -16,13 +16,18 @@ interface VectorSpace> : Space> { val space: S - fun produce(initializer: (Int) -> T): Vector + fun produce(initializer: (Int) -> T): Point - override val zero: Vector get() = produce { space.zero } + /** + * Produce a space-element of this vector space for expressions + */ + fun produceElement(initializer: (Int) -> T): Vector - override fun add(a: Point, b: Point): Vector = produce { with(space) { a[it] + b[it] } } + override val zero: Point get() = produce { space.zero } - override fun multiply(a: Point, k: Double): Vector = produce { with(space) { a[it] * k } } + override fun add(a: Point, b: Point): Point = produce { with(space) { a[it] + b[it] } } + + override fun multiply(a: Point, k: Double): Point = produce { with(space) { a[it] * k } } //TODO add basis @@ -53,7 +58,7 @@ interface VectorSpace> : Space> { /** * A point coupled to the linear space */ -interface Vector> : SpaceElement, VectorSpace>, Point { +interface Vector> : SpaceElement, VectorSpace>, Point { override val size: Int get() = context.size override operator fun plus(b: Point): Vector = context.add(self, b) @@ -65,11 +70,11 @@ interface Vector> : SpaceElement, VectorSpace> generic(size: Int, field: F, initializer: (Int) -> T) = - VectorSpace.buffered(size, field).produce(initializer) + fun > generic(size: Int, field: S, initializer: (Int) -> T): Vector = + VectorSpace.buffered(size, field).produceElement(initializer) - fun real(size: Int, initializer: (Int) -> Double) = VectorSpace.real(size).produce(initializer) - fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce{elements[it]} + fun real(size: Int, initializer: (Int) -> Double): Vector = VectorSpace.real(size).produceElement(initializer) + fun ofReal(vararg elements: Double): Vector = VectorSpace.real(elements.size).produceElement { elements[it] } } } @@ -79,7 +84,8 @@ data class BufferVectorSpace>( override val space: S, val bufferFactory: BufferFactory ) : VectorSpace { - override fun produce(initializer: (Int) -> T): Vector = BufferVector(this, bufferFactory(size, initializer)) + override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer) + override fun produceElement(initializer: (Int) -> T): Vector = BufferVector(this, produce(initializer)) } @@ -95,7 +101,7 @@ data class BufferVector>(override val context: VectorSpace return buffer[index] } - override val self: BufferVector get() = this + override fun getSelf(): BufferVector = (0 until size).map { buffer[it] }.iterator() diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 51f3e75b3..659ab9d07 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -1,23 +1,6 @@ package scientifik.kmath.operations -/** - * The generic mathematics elements which is able to store its context - * @param T the self type of the element - * @param S the type of mathematical context for this element - */ -interface MathElement { - /** - * Self value. Needed for static type checking. - */ - val self: T - - /** - * The context this element belongs to - */ - val context: S -} - /** * A general interface representing linear context of some kind. * The context defines sum operation for its elements and multiplication by real value. @@ -53,19 +36,8 @@ interface Space { //TODO move to external extensions when they are available fun Iterable.sum(): T = fold(zero) { left, right -> left + right } - fun Sequence.sum(): T = fold(zero) { left, right -> left + right } -} -/** - * The element of linear context - * @param T self type of the element. Needed for static type checking - * @param S the type of space - */ -interface SpaceElement> : MathElement { - operator fun plus(b: T): T = context.add(self, b) - operator fun minus(b: T): T = context.add(self, context.multiply(b, -1.0)) - operator fun times(k: Number): T = context.multiply(self, k.toDouble()) - operator fun div(k: Number): T = context.multiply(self, 1.0 / k.toDouble()) + fun Sequence.sum(): T = fold(zero) { left, right -> left + right } } /** @@ -86,15 +58,6 @@ interface Ring : Space { } -/** - * Ring element - */ -interface RingElement> : SpaceElement { - override val context: S - - operator fun times(b: T): T = context.multiply(self, b) -} - /** * Four operations algebra */ @@ -109,13 +72,4 @@ interface Field : Ring { operator fun T.minus(b: Number) = this.minus(b * one) operator fun Number.minus(b: T) = -b + this -} - -/** - * Field element - */ -interface FieldElement> : RingElement { - override val context: F - - operator fun div(b: T): T = context.divide(self, b) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt new file mode 100644 index 000000000..4e5854453 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt @@ -0,0 +1,49 @@ +package scientifik.kmath.operations + +/** + * The generic mathematics elements which is able to store its context + * @param S the type of mathematical context for this element + */ +interface MathElement { + /** + * The context this element belongs to + */ + val context: S +} + +/** + * The element of linear context + * @param T the type of space operation results + * @param I self type of the element. Needed for static type checking + * @param S the type of space + */ +interface SpaceElement, S : Space> : MathElement { + /** + * Self value. Needed for static type checking. + */ + fun unwrap(): T + + fun T.wrap(): I + + operator fun plus(b: T) = context.add(unwrap(), b).wrap() + operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap() + operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap() + operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap() +} + +/** + * Ring element + */ +interface RingElement, R : Ring> : SpaceElement { + override val context: R + operator fun times(b: T) = context.multiply(unwrap(), b).wrap() +} + +/** + * Field element + */ +interface FieldElement, F : Field> : RingElement { + override val context: F + + operator fun div(b: T) = context.divide(unwrap(), b).wrap() +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index e8a80a451..6f30e8cfc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -16,15 +16,18 @@ object ComplexField : Field { override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) - override fun divide(a: Complex, b: Complex): Complex = Complex(a.re * b.re + a.im * b.im, a.re * b.im - a.im * b.re) / b.square + override fun divide(a: Complex, b: Complex): Complex { + val norm = b.square + return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) + } - operator fun Double.plus(c: Complex) = this.toComplex() + c + operator fun Double.plus(c: Complex) = add(this.toComplex(), c) - operator fun Double.minus(c: Complex) = this.toComplex() - c + operator fun Double.minus(c: Complex) = add(this.toComplex(), -c) operator fun Complex.plus(d: Double) = d + this - operator fun Complex.minus(d: Double) = this - d.toComplex() + operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) } @@ -34,6 +37,7 @@ object ComplexField : Field { */ data class Complex(val re: Double, val im: Double) : FieldElement { override val self: Complex get() = this + override val context: ComplexField get() = ComplexField diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt index eabae8ea4..3cc6ed377 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt @@ -12,42 +12,18 @@ interface ExtendedField : ExponentialOperations -/** - * Field for real values - */ -object RealField : ExtendedField, Norm { - override val zero: Real = Real(0.0) - override fun add(a: Real, b: Real): Real = Real(a.value + b.value) - override val one: Real = Real(1.0) - override fun multiply(a: Real, b: Real): Real = Real(a.value * b.value) - override fun multiply(a: Real, k: Double): Real = Real(a.value * k) - override fun divide(a: Real, b: Real): Real = Real(a.value / b.value) - - override fun sin(arg: Real): Real = Real(kotlin.math.sin(arg.value)) - override fun cos(arg: Real): Real = Real(kotlin.math.cos(arg.value)) - - override fun power(arg: Real, pow: Double): Real = Real(arg.value.pow(pow)) - - override fun exp(arg: Real): Real = Real(kotlin.math.exp(arg.value)) - - override fun ln(arg: Real): Real = Real(kotlin.math.ln(arg.value)) - - override fun norm(arg: Real): Real = Real(kotlin.math.abs(arg.value)) -} /** * Real field element wrapping double. * * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 */ -inline class Real(val value: Double) : FieldElement { +inline class Real(val value: Double) : FieldElement { + override fun unwrap(): Double = value - //values are dynamically calculated to save memory - override val self - get() = this + override fun Double.wrap(): Real = Real(value) - override val context - get() = RealField + override val context get() = DoubleField companion object { @@ -79,11 +55,31 @@ object DoubleField : ExtendedField, Norm { /** * A field for double without boxing. Does not produce appropriate field element */ -object IntField : Field{ +object IntField : Field { override val zero: Int = 0 override fun add(a: Int, b: Int): Int = a + b override fun multiply(a: Int, b: Int): Int = a * b - override fun multiply(a: Int, k: Double): Int = (k*a).toInt() + override fun multiply(a: Int, k: Double): Int = (k * a).toInt() override val one: Int = 1 override fun divide(a: Int, b: Int): Int = a / b -} \ No newline at end of file +} + +//interface FieldAdapter : Field { +// +// val field: Field +// +// abstract fun T.evolve(): R +// abstract fun R.devolve(): T +// +// override val zero get() = field.zero.evolve() +// override val one get() = field.zero.evolve() +// +// override fun add(a: R, b: R): R = field.add(a.devolve(), b.devolve()).evolve() +// +// override fun multiply(a: R, k: Double): R = field.multiply(a.devolve(), k).evolve() +// +// +// override fun multiply(a: R, b: R): R = field.multiply(a.devolve(), b.devolve()).evolve() +// +// override fun divide(a: R, b: R): R = field.divide(a.devolve(), b.devolve()).evolve() +//} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt index ebe2a67cf..87035b9fa 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt @@ -1,45 +1,65 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Field +import scientifik.kmath.operations.FieldElement -open class BufferNDField>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory) : NDField { +open class BufferNDField>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory) : NDField> { val strides = DefaultStrides(shape) - override fun produce(initializer: F.(IntArray) -> T): BufferNDElement { - return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) }) - } + override fun produce(initializer: F.(IntArray) -> T) = + BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) }) - open fun produceBuffered(initializer: F.(Int) -> T) = - BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) }) + open fun NDBuffer.map(transform: F.(T) -> T) = + BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(buffer[offset]) }) -// override fun add(a: NDStructure, b: NDStructure): NDElement { -// checkShape(a, b) -// return if (a is BufferNDElement && b is BufferNDElement) { -// BufferNDElement(this,bufferFactory(strides.linearSize){i-> field.run { a.buffer[i] + b.buffer[i]}}) -// } else { -// produce { field.run { a[it] + b[it] } } -// } -// } -// -// override fun NDStructure.plus(b: Number): NDElement { -// checkShape(this) -// return if (this is BufferNDElement) { -// BufferNDElement(this@BufferNDField,bufferFactory(strides.linearSize){i-> field.run { this@plus.buffer[i] + b}}) -// } else { -// produce {index -> field.run { this@plus[index] + b } } -// } -// } + open fun NDBuffer.mapIndexed(transform: F.(index: IntArray, T) -> T) = + BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(strides.index(offset), buffer[offset]) }) + + open fun combine(a: NDBuffer, b: NDBuffer, transform: F.(T, T) -> T) = + BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.transform(a[offset], b[offset]) }) + + /** + * Convert any [NDStructure] to buffered structure using strides from this context. + * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes + */ + fun NDStructure.toBuffer(): NDBuffer = + this as? NDBuffer ?: produce { index -> get(index) } + + override val zero: NDBuffer by lazy { produce { field.zero } } + + override fun add(a: NDBuffer, b: NDBuffer): NDBuffer = combine(a, b) { aValue, bValue -> add(aValue, bValue) } + + override fun multiply(a: NDBuffer, k: Double): NDBuffer = a.map { it * k } + + override val one: NDBuffer by lazy { produce { field.one } } + + override fun multiply(a: NDBuffer, b: NDBuffer): NDBuffer = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } + + override fun divide(a: NDBuffer, b: NDBuffer): NDBuffer = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } } -class BufferNDElement>(override val context: BufferNDField, val buffer: Buffer) : NDElement { +class BufferNDElement>(override val context: BufferNDField, override val buffer: Buffer) : + NDBuffer, + FieldElement, BufferNDElement, BufferNDField>, + NDElement { + + override val elementField: F get() = context.field + + override fun unwrap(): NDBuffer = this + + override fun NDBuffer.wrap(): BufferNDElement = BufferNDElement(context, this.buffer) + + override val strides get() = context.strides - override val self: NDStructure get() = this override val shape: IntArray get() = context.shape - override fun get(index: IntArray): T = buffer[context.strides.offset(index)] + override fun get(index: IntArray): T = buffer[strides.offset(index)] - override fun elements(): Sequence> = context.strides.indices().map { it to get(it) } + override fun elements(): Sequence> = strides.indices().map { it to get(it) } + override fun map(action: F.(T) -> T): BufferNDElement = context.run { map(action) } + + override fun mapIndexed(transform: F.(index: IntArray, T) -> T): BufferNDElement = context.run { mapIndexed(transform) } } @@ -47,7 +67,7 @@ class BufferNDElement>(override val context: BufferNDField * Element by element application of any operation on elements to the whole array. Just like in numpy */ operator fun > Function1.invoke(ndElement: BufferNDElement) = - ndElement.context.produceBuffered { i -> invoke(ndElement.buffer[i]) } + ndElement.context.run { ndElement.map { invoke(it) } } /* plus and minus */ @@ -55,24 +75,24 @@ operator fun > Function1.invoke(ndElement: BufferNDE * Summation operation for [BufferNDElement] and single element */ operator fun > BufferNDElement.plus(arg: T) = - context.produceBuffered { i -> buffer[i] + arg } + context.run { map { it + arg } } /** * Subtraction operation between [BufferNDElement] and single element */ -operator fun > BufferNDElement.minus(arg: T) = - context.produceBuffered { i -> buffer[i] - arg } +operator fun > BufferNDElement.minus(arg: T) = + context.run { map { it - arg } } /* prod and div */ /** * Product operation for [BufferNDElement] and single element */ -operator fun > BufferNDElement.times(arg: T) = - context.produceBuffered { i -> buffer[i] * arg } +operator fun > BufferNDElement.times(arg: T) = + context.run { map { it * arg } } /** * Division operation between [BufferNDElement] and single element */ -operator fun > BufferNDElement.div(arg: T) = - context.produceBuffered { i -> buffer[i] / arg } \ No newline at end of file +operator fun > BufferNDElement.div(arg: T) = + context.run { map { it / arg } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index d274f92bd..7c1b63c23 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -6,40 +6,40 @@ import scientifik.kmath.operations.PowerOperations import scientifik.kmath.operations.TrigonometricOperations -interface ExtendedNDField> : - NDField, - TrigonometricOperations>, - PowerOperations>, - ExponentialOperations> +interface ExtendedNDField, N : NDStructure> : + NDField, + TrigonometricOperations, + PowerOperations, + ExponentialOperations /** * NDField that supports [ExtendedField] operations on its elements */ -inline class ExtendedNDFieldWrapper>(private val ndField: NDField) : ExtendedNDField { +class ExtendedNDFieldWrapper, N : NDStructure>(private val ndField: NDField) : ExtendedNDField, NDField by ndField { override val shape: IntArray get() = ndField.shape override val field: F get() = ndField.field - override fun produce(initializer: F.(IntArray) -> T): NDElement = ndField.produce(initializer) + override fun produce(initializer: F.(IntArray) -> T) = ndField.produce(initializer) - override fun power(arg: NDStructure, pow: Double): NDElement { + override fun power(arg: N, pow: Double): N { return produce { with(field) { power(arg[it], pow) } } } - override fun exp(arg: NDStructure): NDElement { + override fun exp(arg: N): N { return produce { with(field) { exp(arg[it]) } } } - override fun ln(arg: NDStructure): NDElement { + override fun ln(arg: N): N { return produce { with(field) { ln(arg[it]) } } } - override fun sin(arg: NDStructure): NDElement { + override fun sin(arg: N): N { return produce { with(field) { sin(arg[it]) } } } - override fun cos(arg: NDStructure): NDElement { + override fun cos(arg: N): N { return produce { with(field) { cos(arg[it]) } } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt new file mode 100644 index 000000000..7f9f77d4e --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt @@ -0,0 +1,126 @@ +package scientifik.kmath.structures + +import scientifik.kmath.operations.DoubleField +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.FieldElement + + +interface NDElement> : NDStructure { + val elementField: F + + fun mapIndexed(transform: F.(index: IntArray, T) -> T): NDElement + fun map(action: F.(T) -> T) = mapIndexed { _, value -> action(value) } +} + + +object NDElements { + /** + * Create a optimized NDArray of doubles + */ + fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }) = + NDField.real(shape).produce(initializer) + + + fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) = + real(intArrayOf(dim)) { initializer(it[0]) } + + + fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) = + real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } + + fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) = + real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } + + + /** + * Simple boxing NDArray + */ + fun > generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): GenericNDElement { + val ndField = GenericNDField(shape, field) + val structure = ndStructure(shape) { index -> field.initializer(index) } + return GenericNDElement(ndField, structure) + } + + inline fun > inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): GenericNDElement { + val ndField = GenericNDField(shape, field) + val structure = ndStructure(shape, ::inlineBuffer) { index -> field.initializer(index) } + return GenericNDElement(ndField, structure) + } +} + + +/** + * Element by element application of any operation on elements to the whole array. Just like in numpy + */ +operator fun > Function1.invoke(ndElement: NDElement) = ndElement.map { value -> this@invoke(value) } + +/* plus and minus */ + +/** + * Summation operation for [NDElements] and single element + */ +operator fun > NDElement.plus(arg: T): NDElement = this.map { value -> elementField.run { arg + value } } + +/** + * Subtraction operation between [NDElements] and single element + */ +operator fun > NDElement.minus(arg: T): NDElement = this.map { value -> elementField.run { arg - value } } + +/* prod and div */ + +/** + * Product operation for [NDElements] and single element + */ +operator fun > NDElement.times(arg: T): NDElement = this.map { value -> elementField.run { arg * value } } + +/** + * Division operation between [NDElements] and single element + */ +operator fun > NDElement.div(arg: T): NDElement = this.map { value -> elementField.run { arg / value } } + + +// /** +// * Reverse sum operation +// */ +// operator fun T.plus(arg: NDStructure): NDElement = produce { index -> +// field.run { this@plus + arg[index] } +// } +// +// /** +// * Reverse minus operation +// */ +// operator fun T.minus(arg: NDStructure): NDElement = produce { index -> +// field.run { this@minus - arg[index] } +// } +// +// /** +// * Reverse product operation +// */ +// operator fun T.times(arg: NDStructure): NDElement = produce { index -> +// field.run { this@times * arg[index] } +// } +// +// /** +// * Reverse division operation +// */ +// operator fun T.div(arg: NDStructure): NDElement = produce { index -> +// field.run { this@div / arg[index] } +// } + + +/** + * Read-only [NDStructure] coupled to the context. + */ +class GenericNDElement>(override val context: NDField>, private val structure: NDStructure) : + NDStructure by structure, + NDElement, + FieldElement, GenericNDElement, NDField>> { + override val elementField: F get() = context.field + + override fun unwrap(): NDStructure = structure + + override fun NDStructure.wrap() = GenericNDElement(context, this) + + override fun mapIndexed(transform: F.(index: IntArray, T) -> T) = + ndStructure(context.shape) { index: IntArray -> context.field.transform(index, get(index)) }.wrap() +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt index 4fd6c3ee5..e42cebfed 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt @@ -1,8 +1,6 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.Field -import scientifik.kmath.operations.FieldElement /** * An exception is thrown when the expected ans actual shape of NDArray differs @@ -13,27 +11,19 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run * Field for n-dimensional arrays. * @param shape - the list of dimensions of the array * @param field - operations field defined on individual array element - * @param T the type of the element contained in NDArray + * @param T - the type of the element contained in ND structure + * @param F - field over structure elements + * @param R - actual nd-element type of this field */ -interface NDField> : Field> { +interface NDField, N : NDStructure> : Field { val shape: IntArray val field: F - /** - * Create new instance of NDArray using field shape and given initializer - * The producer takes list of indices as argument and returns contained value - */ - fun produce(initializer: F.(IntArray) -> T): NDElement - - override val zero: NDElement get() = produce { zero } - - override val one: NDElement get() = produce { one } - /** * Check the shape of given NDArray and throw exception if it does not coincide with shape of the field */ - fun checkShape(vararg elements: NDStructure) { + fun checkShape(vararg elements: N) { elements.forEach { if (!shape.contentEquals(it.shape)) { throw ShapeMismatchException(shape, it.shape) @@ -41,37 +31,8 @@ interface NDField> : Field> { } } - /** - * Element-by-element addition - */ - override fun add(a: NDStructure, b: NDStructure): NDElement { - checkShape(a, b) - return produce { field.run { a[it] + b[it] } } - } + fun produce(initializer: F.(IntArray) -> T): N - /** - * Multiply all elements by cinstant - */ - override fun multiply(a: NDStructure, k: Double): NDElement { - checkShape(a) - return produce { field.run { a[it] * k } } - } - - /** - * Element-by-element multiplication - */ - override fun multiply(a: NDStructure, b: NDStructure): NDElement { - checkShape(a) - return produce { field.run { a[it] * b[it] } } - } - - /** - * Element-by-element division - */ - override fun divide(a: NDStructure, b: NDStructure): NDElement { - checkShape(a) - return produce { field.run { a[it] / b[it] } } - } companion object { /** @@ -82,7 +43,7 @@ interface NDField> : Field> { /** * Create a nd-field with boxing generic buffer */ - fun > generic(shape: IntArray, field: F) = BufferNDField(shape, field, ::boxingBuffer) + fun > generic(shape: IntArray, field: F) = GenericNDField(shape, field) /** * Create a most suitable implementation for nd-field using reified class @@ -92,123 +53,43 @@ interface NDField> : Field> { } -interface NDElement> : FieldElement, NDField>, NDStructure { - companion object { - /** - * Create a platform-optimized NDArray of doubles - */ - fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement { - return NDField.real(shape).produce(initializer) - } +class GenericNDField>(override val shape: IntArray, override val field: F, val bufferFactory: BufferFactory = ::boxingBuffer) : NDField> { + override fun produce(initializer: F.(IntArray) -> T): NDStructure = ndStructure(shape, bufferFactory) { field.initializer(it) } - fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement { - return real(intArrayOf(dim)) { initializer(it[0]) } - } + override val zero: NDStructure by lazy { produce { zero } } - fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement { - return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - } + override val one: NDStructure by lazy { produce { one } } - fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement { - return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } - } -// inline fun real(shape: IntArray, block: ExtendedNDField.() -> NDStructure): NDElement { -// val field = NDField.real(shape) -// return GenericNDElement(field, field.run(block)) -// } - - /** - * Simple boxing NDArray - */ - fun > generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement { - return NDField.generic(shape, field).produce(initializer) - } - - inline fun > inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement { - return NDField.inline(shape, field).produce(initializer) - } + /** + * Element-by-element addition + */ + override fun add(a: NDStructure, b: NDStructure): NDStructure { + checkShape(a, b) + return produce { field.run { a[it] + b[it] } } } -} -inline fun > NDElement.transformIndexed(crossinline action: F.(IntArray, T) -> T): NDElement = context.produce { action(it, get(*it)) } -inline fun > NDElement.transform(crossinline action: F.(T) -> T): NDElement = context.produce { action(get(*it)) } + /** + * Multiply all elements by cinstant + */ + override fun multiply(a: NDStructure, k: Double): NDStructure { + checkShape(a) + return produce { field.run { a[it] * k } } + } + /** + * Element-by-element multiplication + */ + override fun multiply(a: NDStructure, b: NDStructure): NDStructure { + checkShape(a) + return produce { field.run { a[it] * b[it] } } + } -/** - * Element by element application of any operation on elements to the whole array. Just like in numpy - */ -operator fun > Function1.invoke(ndElement: NDElement): NDElement = ndElement.transform { value -> this@invoke(value) } - -/* plus and minus */ - -/** - * Summation operation for [NDElement] and single element - */ -operator fun > NDElement.plus(arg: T): NDElement = transform { value -> - context.field.run { arg + value } -} - -/** - * Subtraction operation between [NDElement] and single element - */ -operator fun > NDElement.minus(arg: T): NDElement = transform { value -> - context.field.run { arg - value } -} - -/* prod and div */ - -/** - * Product operation for [NDElement] and single element - */ -operator fun > NDElement.times(arg: T): NDElement = transform { value -> - context.field.run { arg * value } -} - -/** - * Division operation between [NDElement] and single element - */ -operator fun > NDElement.div(arg: T): NDElement = transform { value -> - context.field.run { arg / value } -} - - -// /** -// * Reverse sum operation -// */ -// operator fun T.plus(arg: NDStructure): NDElement = produce { index -> -// field.run { this@plus + arg[index] } -// } -// -// /** -// * Reverse minus operation -// */ -// operator fun T.minus(arg: NDStructure): NDElement = produce { index -> -// field.run { this@minus - arg[index] } -// } -// -// /** -// * Reverse product operation -// */ -// operator fun T.times(arg: NDStructure): NDElement = produce { index -> -// field.run { this@times * arg[index] } -// } -// -// /** -// * Reverse division operation -// */ -// operator fun T.div(arg: NDStructure): NDElement = produce { index -> -// field.run { this@div / arg[index] } -// } - -class GenericNDField>(override val shape: IntArray, override val field: F) : NDField { - override fun produce(initializer: F.(IntArray) -> T): NDElement = GenericNDElement(this, produceStructure(initializer)) - private inline fun produceStructure(crossinline initializer: F.(IntArray) -> T): NDStructure = NdStructure(shape, ::boxingBuffer) { field.initializer(it) } -} - -/** - * Read-only [NDStructure] coupled to the context. - */ -class GenericNDElement>(override val context: NDField, private val structure: NDStructure) : NDElement, NDStructure by structure { - override val self: NDElement get() = this -} + /** + * Element-by-element division + */ + override fun divide(a: NDStructure, b: NDStructure): NDStructure { + checkShape(a) + return produce { field.run { a[it] / b[it] } } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 765d7148b..5e55bea78 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -112,26 +112,24 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides } } -abstract class GenericNDStructure> : NDStructure { - abstract val buffer: B - abstract val strides: Strides +interface NDBuffer : NDStructure { + val buffer: Buffer + val strides: Strides override fun get(index: IntArray): T = buffer[strides.offset(index)] - override val shape: IntArray - get() = strides.shape + override val shape: IntArray get() = strides.shape - override fun elements() = - strides.indices().map { it to this[it] } + override fun elements() = strides.indices().map { it to this[it] } } /** * Boxing generic [NDStructure] */ -class BufferNDStructure( +data class BufferNDStructure( override val strides: Strides, override val buffer: Buffer -) : GenericNDStructure>() { +) : NDBuffer { init { if (strides.linearSize != buffer.size) { @@ -158,7 +156,7 @@ class BufferNDStructure( /** * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure] */ -inline fun NDStructure.map(factory: BufferFactory = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure { +inline fun NDStructure.mapToBuffer(factory: BufferFactory = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure { return if (this is BufferNDStructure) { BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) } else { @@ -172,8 +170,7 @@ inline fun NDStructure.map(factory: BufferFactory = : * * Strides should be reused if possible */ -@Suppress("FunctionName") -fun NdStructure(strides: Strides, bufferFactory: BufferFactory, initializer: (IntArray) -> T) = +fun ndStructure(strides: Strides, bufferFactory: BufferFactory = ::boxingBuffer, initializer: (IntArray) -> T) = BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) /** @@ -182,9 +179,8 @@ fun NdStructure(strides: Strides, bufferFactory: BufferFactory, ini inline fun inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = BufferNDStructure(strides, inlineBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) -@Suppress("FunctionName") -fun NdStructure(shape: IntArray, bufferFactory: BufferFactory, initializer: (IntArray) -> T) = - NdStructure(DefaultStrides(shape), bufferFactory, initializer) +fun ndStructure(shape: IntArray, bufferFactory: BufferFactory = ::boxingBuffer, initializer: (IntArray) -> T) = + ndStructure(DefaultStrides(shape), bufferFactory, initializer) inline fun inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = inlineNDStructure(DefaultStrides(shape), initializer) @@ -195,7 +191,7 @@ inline fun inlineNdStructure(shape: IntArray, crossinline init class MutableBufferNDStructure( override val strides: Strides, override val buffer: MutableBuffer -) : GenericNDStructure>(), MutableNDStructure { +) : NDBuffer, MutableNDStructure { init { if (strides.linearSize != buffer.size) { @@ -209,16 +205,14 @@ class MutableBufferNDStructure( /** * The same as [inlineNDStructure], but mutable */ -@Suppress("FunctionName") -fun MutableNdStructure(strides: Strides, bufferFactory: MutableBufferFactory, initializer: (IntArray) -> T) = +fun mutableNdStructure(strides: Strides, bufferFactory: MutableBufferFactory = ::boxingMutableBuffer, initializer: (IntArray) -> T) = MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) inline fun inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = MutableBufferNDStructure(strides, inlineMutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) -@Suppress("FunctionName") -fun MutableNdStructure(shape: IntArray, bufferFactory: MutableBufferFactory, initializer: (IntArray) -> T) = - MutableNdStructure(DefaultStrides(shape), bufferFactory, initializer) +fun mutableNdStructure(shape: IntArray, bufferFactory: MutableBufferFactory = ::boxingMutableBuffer, initializer: (IntArray) -> T) = + mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer) inline fun inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = inlineMutableNdStructure(DefaultStrides(shape), initializer) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 7705c710e..3158c38c1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -4,44 +4,41 @@ import scientifik.kmath.operations.DoubleField typealias RealNDElement = BufferNDElement -class RealNDField(shape: IntArray) : BufferNDField(shape, DoubleField, DoubleBufferFactory), ExtendedNDField { +class RealNDField(shape: IntArray) : + BufferNDField(shape, DoubleField, DoubleBufferFactory), + ExtendedNDField> { /** * Inline map an NDStructure to */ - private inline fun NDStructure.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement { - return if (this is BufferNDElement) { - val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) } - BufferNDElement(this@RealNDField, DoubleBuffer(array)) - } else { - produce { index -> DoubleField.operation(get(index)) } - } + private inline fun NDBuffer.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement { + val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) } + return BufferNDElement(this@RealNDField, DoubleBuffer(array)) } - @Suppress("OVERRIDE_BY_INLINE") override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> field.initializer(strides.index(offset)) } return BufferNDElement(this, DoubleBuffer(array)) } - override fun power(arg: NDStructure, pow: Double) = arg.mapInline { power(it, pow) } + override fun power(arg: NDBuffer, pow: Double) = arg.mapInline { power(it, pow) } - override fun exp(arg: NDStructure) = arg.mapInline { exp(it) } + override fun exp(arg: NDBuffer) = arg.mapInline { exp(it) } - override fun ln(arg: NDStructure) = arg.mapInline { ln(it) } + override fun ln(arg: NDBuffer) = arg.mapInline { ln(it) } - override fun sin(arg: NDStructure) = arg.mapInline { sin(it) } + override fun sin(arg: NDBuffer) = arg.mapInline { sin(it) } - override fun cos(arg: NDStructure) = arg.mapInline { cos(it) } + override fun cos(arg: NDBuffer) = arg.mapInline { cos(it) } - override fun NDStructure.times(k: Number) = mapInline { value -> value * k.toDouble() } + override fun NDBuffer.times(k: Number) = mapInline { value -> value * k.toDouble() } - override fun NDStructure.div(k: Number) = mapInline { value -> value / k.toDouble() } + override fun NDBuffer.div(k: Number) = mapInline { value -> value / k.toDouble() } - override fun Number.times(b: NDStructure) = b * this + override fun Number.times(b: NDBuffer) = b * this - override fun Number.div(b: NDStructure) = b * (1.0 / this.toDouble()) + override fun Number.div(b: NDBuffer) = b * (1.0 / this.toDouble()) } /** diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt index 39cce5c67..637147629 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt @@ -7,7 +7,7 @@ import kotlin.test.assertEquals class NDFieldTest { @Test fun testStrides() { - val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } + val ndArray = NDElements.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } assertEquals(ndArray[5, 5], 10.0) } } \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt index 3c4160329..7c5227293 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt @@ -1,7 +1,6 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Norm -import scientifik.kmath.structures.NDElement.Companion.real2D import kotlin.math.abs import kotlin.math.pow import kotlin.test.Test diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt index ad38100cb..8b790da99 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/structures/LazyNDField.kt @@ -5,7 +5,7 @@ import scientifik.kmath.operations.Field class LazyNDField>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) : BufferNDField(shape,field, ::boxingBuffer) { - override fun add(a: NDStructure, b: NDStructure): NDElement { + override fun add(a: NDStructure, b: NDStructure): NDElements { return LazyNDStructure(this) { index -> val aDeferred = a.deferred(index) val bDeferred = b.deferred(index) @@ -13,11 +13,11 @@ class LazyNDField>(shape: IntArray, field: F, val scope: Corouti } } - override fun multiply(a: NDStructure, k: Double): NDElement { + override fun multiply(a: NDStructure, k: Double): NDElements { return LazyNDStructure(this) { index -> a.await(index) * k } } - override fun multiply(a: NDStructure, b: NDStructure): NDElement { + override fun multiply(a: NDStructure, b: NDStructure): NDElements { return LazyNDStructure(this) { index -> val aDeferred = a.deferred(index) val bDeferred = b.deferred(index) @@ -25,7 +25,7 @@ class LazyNDField>(shape: IntArray, field: F, val scope: Corouti } } - override fun divide(a: NDStructure, b: NDStructure): NDElement { + override fun divide(a: NDStructure, b: NDStructure): NDElements { return LazyNDStructure(this) { index -> val aDeferred = a.deferred(index) val bDeferred = b.deferred(index) @@ -34,8 +34,8 @@ class LazyNDField>(shape: IntArray, field: F, val scope: Corouti } } -class LazyNDStructure>(override val context: LazyNDField, val function: suspend F.(IntArray) -> T) : NDElement, NDStructure { - override val self: NDElement get() = this +class LazyNDStructure>(override val context: LazyNDField, val function: suspend F.(IntArray) -> T) : NDElements, NDStructure { + override val self: NDElements get() = this override val shape: IntArray get() = context.shape private val cache = HashMap>() @@ -58,7 +58,7 @@ fun NDStructure.deferred(index: IntArray) = if (this is LazyNDStructure NDStructure.await(index: IntArray) = if (this is LazyNDStructure) this.await(index) else get(index) -fun > NDElement.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure { +fun > NDElements.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure { return if (this is LazyNDStructure) { this } else {