diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt deleted file mode 100644 index 61703cac7..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTExpression.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.FunctionalExpressionField -import scientifik.kmath.expressions.FunctionalExpressionRing -import scientifik.kmath.expressions.FunctionalExpressionSpace -import scientifik.kmath.operations.* - -/** - * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. - */ -class MSTExpression(val algebra: Algebra, val mst: MST) : Expression { - - /** - * Substitute algebra raw value - */ - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra{ - override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T =algebra.binaryOperation(operation, left, right) - - override fun number(value: Number): T = if(algebra is NumericAlgebra){ - algebra.number(value) - } else{ - error("Numeric nodes are not supported by $this") - } - } - - override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} - - -inline fun , E : Algebra> A.mst( - mstAlgebra: E, - block: E.() -> MST -): MSTExpression = MSTExpression(this, mstAlgebra.block()) - -inline fun Space.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - MSTExpression(this, MSTSpace.block()) - -inline fun Ring.mstInRing(block: MSTRing.() -> MST): MSTExpression = - MSTExpression(this, MSTRing.block()) - -inline fun Field.mstInField(block: MSTField.() -> MST): MSTExpression = - MSTExpression(this, MSTField.block()) - -inline fun > FunctionalExpressionSpace.mstInSpace(block: MSTSpace.() -> MST): MSTExpression = - algebra.mstInSpace(block) - -inline fun > FunctionalExpressionRing.mstInRing(block: MSTRing.() -> MST): MSTExpression = - algebra.mstInRing(block) - -inline fun > FunctionalExpressionField.mstInField(block: MSTField.() -> MST): MSTExpression = - algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt similarity index 62% rename from kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt rename to kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt index f741fc8c4..007cf57c4 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt @@ -2,7 +2,7 @@ package scientifik.kmath.ast import scientifik.kmath.operations.* -object MSTAlgebra : NumericAlgebra { +object MstAlgebra : NumericAlgebra { override fun number(value: Number): MST = MST.Numeric(value) override fun symbol(value: String): MST = MST.Symbolic(value) @@ -14,12 +14,11 @@ object MSTAlgebra : NumericAlgebra { MST.Binary(operation, left, right) } -object MSTSpace : Space, NumericAlgebra { +object MstSpace : Space, NumericAlgebra { override val zero: MST = number(0.0) - override fun number(value: Number): MST = MST.Numeric(value) - - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) @@ -28,46 +27,46 @@ object MSTSpace : Space, NumericAlgebra { binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } -object MSTRing : Ring, NumericAlgebra { - override val zero: MST = MSTSpace.number(0.0) +object MstRing : Ring, NumericAlgebra { + override val zero: MST = number(0.0) override val one: MST = number(1.0) - override fun number(value: Number): MST = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MstAlgebra.number(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } -object MSTField : Field { - override val zero: MST = MSTSpace.number(0.0) +object MstField : Field { + override val zero: MST = number(0.0) override val one: MST = number(1.0) - override fun symbol(value: String): MST = MST.Symbolic(value) - override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MstAlgebra.symbol(value) + override fun number(value: Number): MST = MstAlgebra.number(value) override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MstSpace.number(k)) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MSTAlgebra.binaryOperation(operation, left, right) + MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) } diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt new file mode 100644 index 000000000..1468c3ad4 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt @@ -0,0 +1,55 @@ +package scientifik.kmath.ast + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.FunctionalExpressionField +import scientifik.kmath.expressions.FunctionalExpressionRing +import scientifik.kmath.expressions.FunctionalExpressionSpace +import scientifik.kmath.operations.* + +/** + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than ASM-generated expressions. + */ +class MstExpression(val algebra: Algebra, val mst: MST) : Expression { + + /** + * Substitute algebra raw value + */ + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) + override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T = + algebra.binaryOperation(operation, left, right) + + override fun number(value: Number): T = if (algebra is NumericAlgebra) + algebra.number(value) + else + error("Numeric nodes are not supported by $this") + } + + override fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} + + +inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST +): MstExpression = MstExpression(this, mstAlgebra.block()) + +inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression = + MstExpression(this, MstSpace.block()) + +inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression = + MstExpression(this, MstRing.block()) + +inline fun Field.mstInField(block: MstField.() -> MST): MstExpression = + MstExpression(this, MstField.block()) + +inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression = + algebra.mstInSpace(block) + +inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression = + algebra.mstInRing(block) + +inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression = + algebra.mstInField(block) \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index af39d9091..468ed01ba 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -6,7 +6,7 @@ import scientifik.kmath.asm.internal.buildExpectationStack import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTExpression +import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra @@ -80,6 +80,6 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) /** - * Optimize performance of an [MSTExpression] using ASM codegen + * Optimize performance of an [MstExpression] using ASM codegen */ -inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) +inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra)