diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 9764fccc3..511571fc9 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -7,6 +7,11 @@ repositories { } kotlin.sourceSets { +// all { +// languageSettings.apply{ +// enableLanguageFeature("NewInference") +// } +// } commonMain { dependencies { api(project(":kmath-core")) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt index 45693645c..07194a7bb 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -1,33 +1,76 @@ -@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE") - package scientifik.kmath.ast import scientifik.kmath.operations.* object MSTAlgebra : NumericAlgebra { + override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MST.Symbolic(value) - override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = + MST.Unary(operation, arg) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right) - - override fun number(value: Number): MST = MST.Numeric(value) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MST.Binary(operation, left, right) } -object MSTSpace : Space, NumericAlgebra by MSTAlgebra { +object MSTSpace : Space, NumericAlgebra { override val zero: MST = number(0.0) - override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun number(value: Number): MST = MST.Numeric(value) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + override fun symbol(value: String): MST = MST.Symbolic(value) + + override fun add(a: MST, b: MST): MST = + binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } -object MSTRing : Ring, Space by MSTSpace { +object MSTRing : Ring, NumericAlgebra { + override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MST.Symbolic(value) + + override val zero: MST = MSTSpace.number(0.0) override val one: MST = number(1.0) + override fun add(a: MST, b: MST): MST = + MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun multiply(a: MST, k: Number): MST = + MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTAlgebra.binaryOperation(operation, left, right) } -object MSTField : Field, Ring by MSTRing { - override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) -} \ No newline at end of file +object MSTField : Field{ + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MST.Numeric(value) + + override val zero: MST = MSTSpace.number(0.0) + override val one: MST = number(1.0) + override fun add(a: MST, b: MST): MST = + MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + + override fun multiply(a: MST, k: Number): MST = + MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + + override fun multiply(a: MST, b: MST): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, b) + + override fun divide(a: MST, b: MST): MST = + binaryOperation(FieldOperations.DIV_OPERATION, a, b) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTAlgebra.binaryOperation(operation, left, right) +} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt index dbd5238e3..61703cac7 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt @@ -1,19 +1,55 @@ package scientifik.kmath.ast import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.expressions.FunctionalExpressionField +import scientifik.kmath.expressions.FunctionalExpressionRing +import scientifik.kmath.expressions.FunctionalExpressionSpace +import scientifik.kmath.operations.* /** * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. */ -class MSTExpression(val algebra: NumericAlgebra, val mst: MST) : Expression { +class MSTExpression(val algebra: Algebra, val mst: MST) : Expression { /** * Substitute algebra raw value */ - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra by algebra { - override fun symbol(value: String): T = arguments[value] ?: super.symbol(value) + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra{ + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if(algebra is NumericAlgebra){ + algebra.number(value) + } else{ + error("Numeric nodes are not supported by $this") + } } override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} \ No newline at end of file +} + + +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MSTExpression = MSTExpression(this, mstAlgebra.block()) + +inline fun Space.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = + MSTExpression(this, MSTSpace.block()) + +inline fun Ring.mstInRing(block: MSTRing.() -> MST): MSTExpression = + MSTExpression(this, MSTRing.block()) + +inline fun Field.mstInField(block: MSTField.() -> MST): MSTExpression = + MSTExpression(this, MSTField.block()) + +inline fun > FunctionalExpressionSpace.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = + algebra.mstInSpace(block) + +inline fun > FunctionalExpressionRing.mstInRing(block: MSTRing.() -> MST): MSTExpression = + algebra.mstInRing(block) + +inline fun > FunctionalExpressionField.mstInField(block: MSTField.() -> MST): MSTExpression = + algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index e7644998e..bb091732e 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -5,11 +5,10 @@ import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTField -import scientifik.kmath.ast.MSTRing -import scientifik.kmath.ast.MSTSpace +import scientifik.kmath.ast.MSTExpression import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.* +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra import kotlin.reflect.KClass /** @@ -71,18 +70,12 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() } -inline fun Algebra.compile(mst: MST): Expression = mst.compileWith(T::class, this) +/** + * Compile an [MST] to ASM using given algebra + */ +inline fun Algebra.expresion(mst: MST): Expression = mst.compileWith(T::class, this) -inline fun , E : Algebra> A.asm( - mstAlgebra: E, - block: E.() -> MST -): Expression = mstAlgebra.block().compileWith(T::class, this) - -inline fun > A.asmInSpace(block: MSTSpace.() -> MST): Expression = - MSTSpace.block().compileWith(T::class, this) - -inline fun > A.asmInRing(block: MSTRing.() -> MST): Expression = - MSTRing.block().compileWith(T::class, this) - -inline fun > A.asmInField(block: MSTField.() -> MST): Expression = - MSTField.block().compileWith(T::class, this) +/** + * Optimize performance of an [MSTExpression] using ASM codegen + */ +inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) \ No newline at end of file diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index bfbf5e926..82af1a927 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -1,6 +1,8 @@ package scietifik.kmath.asm -import scientifik.kmath.asm.asmInField +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInSpace import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.RealField import kotlin.test.Test @@ -9,13 +11,14 @@ import kotlin.test.assertEquals class TestAsmExpressions { @Test fun testUnaryOperationInvocation() { - val res = RealField.asmInField { -symbol("x") }("x" to 2.0) + val expression = RealField.mstInSpace { -symbol("x") }.compile() + val res = expression("x" to 2.0) assertEquals(-2.0, res) } @Test fun testConstProductInvocation() { - val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0) + val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) assertEquals(4.0, res) } } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 752b3b601..f0f9a8bc1 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -1,8 +1,7 @@ package scietifik.kmath.ast -import scientifik.kmath.asm.compile +import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.parseMath -import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import kotlin.test.Test @@ -12,7 +11,7 @@ class AsmTest { @Test fun parsedExpression() { val mst = "2+2*(2+2)".parseMath() - val res = ComplexField.compile(mst)() + val res = ComplexField.evaluate(mst) assertEquals(Complex(10.0, 0.0), res) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt index cdd6f695b..9f1503285 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -7,17 +7,17 @@ import scientifik.kmath.operations.Space /** * Create a functional expression on this [Space] */ -fun Space.buildExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = +fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = FunctionalExpressionSpace(this).run(block) /** * Create a functional expression on this [Ring] */ -fun Ring.buildExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = +fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = FunctionalExpressionRing(this).run(block) /** * Create a functional expression on this [Field] */ -fun Field.buildExpression(block: FunctionalExpressionField>.() -> Expression): Expression = +fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression = FunctionalExpressionField(this).run(block) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index eaf3cd1d7..e512b1cd8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -8,11 +8,14 @@ import scientifik.kmath.operations.Algebra interface Expression { operator fun invoke(arguments: Map): T - companion object { - operator fun invoke(block: (Map) -> T): Expression = object : Expression { - override fun invoke(arguments: Map): T = block(arguments) - } - } + companion object +} + +/** + * Create simple lazily evaluated expression inside given algebra + */ +fun Algebra.expression(block: Algebra.(arguments: Map) -> T): Expression = object: Expression { + override fun invoke(arguments: Map): T = block(arguments) } operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt index a441a80c0..c8d6e8eb0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -69,10 +69,10 @@ interface FunctionalExpressionAlgebra> : ExpressionAlgebra(override val algebra: A) : FunctionalExpressionAlgebra, - Space> where A : Space { - override val zero: Expression - get() = const(algebra.zero) +open class FunctionalExpressionSpace>(override val algebra: A) : + FunctionalExpressionAlgebra, Space> { + + override val zero: Expression get() = const(algebra.zero) /** * Builds an Expression of addition of two another expressions.