diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LazyStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LazyStructure.kt new file mode 100644 index 000000000..0428535f9 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LazyStructure.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.structures + +// +//class LazyStructureField(val field: Field): Field>{ +// +//} +// +//class LazyStructure : NDStructure { +// +// override val shape: IntArray +// get() = TODO("not implemented") //To change initializer of created properties use File | Settings | File Templates. +// +// override fun get(index: IntArray): T { +// TODO("not implemented") //To change body of created functions use File | Settings | File Templates. +// } +// +// override fun iterator(): Iterator> { +// TODO("not implemented") //To change body of created functions use File | Settings | File Templates. +// } +//} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt index 9461bee85..7b991200b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt @@ -1,6 +1,5 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.Field import scientifik.kmath.operations.FieldElement @@ -107,14 +106,15 @@ abstract class NDField(val shape: IntArray, val field: Field) : Field(override val context: NDField, private val structure: NDStructure) : FieldElement, NDField>, NDStructure by structure { +class NDArray(override val context: NDField, private val structure: NDStructure) : FieldElement, NDField>, NDStructure by structure { //TODO ensure structure is immutable override val self: NDArray get() = this - fun transform(action: (IntArray, T) -> T): NDArray = context.produce { action(it, get(*it)) } + inline fun transform(crossinline action: (IntArray, T) -> T): NDArray = context.produce { action(it, get(*it)) } + inline fun transform(crossinline action: (T) -> T): NDArray = context.produce { action(get(*it)) } } /** @@ -173,7 +173,7 @@ object NDArrays { * Create a platform-optimized NDArray of doubles */ fun realNDArray(shape: IntArray, initializer: (IntArray) -> Double = { 0.0 }): NDArray { - return GenericNDField(shape, DoubleField).produce(initializer) + return RealNDField(shape).produce(initializer) } fun real1DArray(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDArray { @@ -188,6 +188,8 @@ object NDArrays { return realNDArray(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } } + inline fun produceReal(shape: IntArray, block: RealNDField.() -> RealNDArray) = RealNDField(shape).run(block) + // /** // * Simple boxing NDField // */ diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt new file mode 100644 index 000000000..2ba5fd706 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -0,0 +1,46 @@ +package scientifik.kmath.structures + +import scientifik.kmath.operations.DoubleField +import scientifik.kmath.operations.ExponentialOperations +import scientifik.kmath.operations.PowerOperations +import scientifik.kmath.operations.TrigonometricOperations +import kotlin.math.* + +typealias RealNDArray = NDArray + + +class RealNDField(shape: IntArray) : NDField(shape, DoubleField), + TrigonometricOperations, + PowerOperations, + ExponentialOperations { + + override fun produceStructure(initializer: (IntArray) -> Double): NDStructure { + return genericNdStructure(shape, initializer) + } + + override fun power(arg: RealNDArray, pow: Double): RealNDArray { + return arg.transform { d -> d.pow(pow) } + } + + override fun exp(arg: RealNDArray): RealNDArray { + return arg.transform { d -> exp(d) } + } + + override fun ln(arg: RealNDArray): RealNDArray { + return arg.transform { d -> ln(d) } + } + + override fun sin(arg: RealNDArray): RealNDArray { + return arg.transform { d -> sin(d) } + } + + override fun cos(arg: RealNDArray): RealNDArray { + return arg.transform { d -> cos(d) } + } + + fun abs(arg: RealNDArray): RealNDArray { + return arg.transform { d -> abs(d) } + } +} + + diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/FieldExpressionContextTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/FieldExpressionContextTest.kt index 543e13d79..8e2845b9e 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/FieldExpressionContextTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/FieldExpressionContextTest.kt @@ -39,4 +39,15 @@ class FieldExpressionContextTest { val expression = FieldExpressionContext(DoubleField).expression() assertEquals(expression("x" to 1.0), 4.0) } + + @Test + fun valueExpression() { + val expressionBuilder: FieldExpressionContext.()->Expression = { + val x = variable("x") + x * x + 2 * x + 1.0 + } + + val expression = FieldExpressionContext(DoubleField).expressionBuilder() + assertEquals(expression("x" to 1.0), 4.0) + } } \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt index 329f13189..179fb5ade 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt @@ -1,6 +1,8 @@ package scientifik.kmath.structures +import scientifik.kmath.structures.NDArrays.produceReal import scientifik.kmath.structures.NDArrays.real2DArray +import kotlin.math.abs import kotlin.math.pow import kotlin.test.Test import kotlin.test.assertEquals @@ -40,4 +42,18 @@ class RealNDFieldTest { val result = function(array1) + 1.0 assertEquals(10.0, result[1,1]) } + + @Test + fun testLibraryFunction() { + val abs: (Double) -> Double = ::abs + val result = abs(array1) + assertEquals(10.0, result[1,1]) + } + + @Test + fun testAbs(){ + val res = produceReal(array1.shape){ + 1 + abs(array1) + exp(array2) + } + } }