diff --git a/benchmarks/src/main/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt b/benchmarks/src/main/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt index fcacc44f5..85150ad32 100644 --- a/benchmarks/src/main/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt +++ b/benchmarks/src/main/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt @@ -9,7 +9,7 @@ fun main(args: Array) { val bufferedField = NDField.auto(intArrayOf(dim, dim), RealField) val specializedField = NDField.real(intArrayOf(dim, dim)) - val genericField = NDField.generic(intArrayOf(dim, dim), RealField) + val genericField = NDField.buffered(intArrayOf(dim, dim), RealField) val lazyNDField = NDField.lazy(intArrayOf(dim, dim), RealField) // val action: NDField>.() -> Unit = { @@ -75,7 +75,7 @@ fun main(args: Array) { val genericTime = measureTimeMillis { //genericField.run(action) genericField.run { - var res = one + var res: NDBuffer = one repeat(n) { res += 1.0 } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt index c16b03608..e586bca71 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt @@ -32,6 +32,6 @@ fun ClosedFloatingPointRange.toSequence(step: Double): Sequence * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] */ fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { - if (numPoints < 2) error("Can't generic grid with less than two points") + if (numPoints < 2) error("Can't buffered grid with less than two points") return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt index 6beea1a04..022a4e61f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt @@ -35,24 +35,22 @@ object NDElements { /** * Simple boxing NDArray */ - fun > generic( + fun > buffered( shape: IntArray, field: F, initializer: F.(IntArray) -> T - ): GenericNDElement { - val ndField = GenericNDField(shape, field) - val structure = ndStructure(shape) { index -> field.initializer(index) } - return GenericNDElement(ndField, structure) + ): BufferNDElement { + val ndField = BufferNDField(shape, field, ::boxingBuffer) + return ndField.produce(initializer) } - inline fun > inline( + inline fun > auto( shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T - ): GenericNDElement { - val ndField = GenericNDField(shape, field) - val structure = ndStructure(shape, ::autoBuffer) { index -> field.initializer(index) } - return GenericNDElement(ndField, structure) + ): BufferNDElement { + val ndField = NDField.auto(shape, field) + return ndField.produce(initializer) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt index 9ed82faa1..c1abf8acd 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt @@ -22,43 +22,10 @@ interface NDField, N : NDStructure> : Field { val elementField: F fun produce(initializer: F.(IntArray) -> T): N - fun map(arg: N, transform: F.(T) -> T): N - fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N - fun combine(a: N, b: N, transform: F.(T, T) -> T): N - /** - * Element by element application of any operation on elements to the whole array. Just like in numpy - */ - operator fun Function1.invoke(structure: N): N - - /** - * Summation operation for [NDElements] and single element - */ - operator fun N.plus(arg: T): N - - /** - * Subtraction operation between [NDElements] and single element - */ - operator fun N.minus(arg: T): N - - /** - * Product operation for [NDElements] and single element - */ - operator fun N.times(arg: T): N - - /** - * Division operation between [NDElements] and single element - */ - operator fun N.div(arg: T): N - - operator fun T.plus(arg: N): N - operator fun T.minus(arg: N): N - operator fun T.times(arg: N): N - operator fun T.div(arg: N): N - companion object { /** * Create a nd-field for [Double] values @@ -68,14 +35,14 @@ interface NDField, N : NDStructure> : Field { /** * Create a nd-field with boxing generic buffer */ - fun > generic(shape: IntArray, field: F) = GenericNDField(shape, field) + fun > buffered(shape: IntArray, field: F) = + BufferNDField(shape, field, ::boxingBuffer) /** * Create a most suitable implementation for nd-field using reified class. */ - inline fun > auto(shape: IntArray, field: F): BufferNDField { - return BufferNDField(shape, field, ::autoBuffer) - } + inline fun > auto(shape: IntArray, field: F) = + BufferNDField(shape, field, ::autoBuffer) } } @@ -88,16 +55,16 @@ abstract class AbstractNDField, N : NDStructure>( override val one: N by lazy { produce { one } } - final override operator fun Function1.invoke(structure: N) = map(structure) { value -> this@invoke(value) } - final override operator fun N.plus(arg: T) = map(this) { value -> elementField.run { arg + value } } - final override operator fun N.minus(arg: T) = map(this) { value -> elementField.run { arg - value } } - final override operator fun N.times(arg: T) = map(this) { value -> elementField.run { arg * value } } - final override operator fun N.div(arg: T) = map(this) { value -> elementField.run { arg / value } } + operator fun Function1.invoke(structure: N) = map(structure) { value -> this@invoke(value) } + operator fun N.plus(arg: T) = map(this) { value -> elementField.run { arg + value } } + operator fun N.minus(arg: T) = map(this) { value -> elementField.run { arg - value } } + operator fun N.times(arg: T) = map(this) { value -> elementField.run { arg * value } } + operator fun N.div(arg: T) = map(this) { value -> elementField.run { arg / value } } - final override operator fun T.plus(arg: N) = arg + this - final override operator fun T.minus(arg: N) = arg - this - final override operator fun T.times(arg: N) = arg * this - final override operator fun T.div(arg: N) = arg / this + operator fun T.plus(arg: N) = arg + this + operator fun T.minus(arg: N) = arg - this + operator fun T.times(arg: N) = arg * this + operator fun T.div(arg: N) = arg / this /** @@ -135,23 +102,4 @@ abstract class AbstractNDField, N : NDStructure>( } } } -} - -class GenericNDField>( - shape: IntArray, - elementField: F, - val bufferFactory: BufferFactory = ::boxingBuffer -) : AbstractNDField>(shape, elementField) { - - override fun produce(initializer: F.(IntArray) -> T): NDStructure = - ndStructure(shape, bufferFactory) { elementField.initializer(it) } - - override fun map(arg: NDStructure, transform: F.(T) -> T): NDStructure = - produce { index -> transform(arg.get(index)) } - - override fun mapIndexed(arg: NDStructure, transform: F.(index: IntArray, T) -> T): NDStructure = - produce { index -> transform(index, arg.get(index)) } - - override fun combine(a: NDStructure, b: NDStructure, transform: F.(T, T) -> T): NDStructure = - produce { index -> transform(a[index], b[index]) } } \ No newline at end of file diff --git a/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt b/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt index 51e1db5ae..d5f43d3eb 100644 --- a/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt +++ b/kmath-coroutines/src/commonTest/kotlin/scientifik/kmath/structures/LazyNDFieldTest.kt @@ -9,7 +9,7 @@ class LazyNDFieldTest { @Test fun testLazyStructure() { var counter = 0 - val regularStructure = NDField.generic(intArrayOf(2, 2, 2), IntField).produce { it[0] + it[1] - it[2] } + val regularStructure = NDField.auto(intArrayOf(2, 2, 2), IntField).produce { it[0] + it[1] - it[2] } val result = (regularStructure.lazy(IntField) + 2).map { counter++ it * it