diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt index 45693645c..4a5900a48 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -1,4 +1,3 @@ -@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE") package scientifik.kmath.ast @@ -7,9 +6,11 @@ import scientifik.kmath.operations.* object MSTAlgebra : NumericAlgebra { override fun symbol(value: String): MST = MST.Symbolic(value) - override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST = + MST.Unary(operation, arg) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MST.Binary(operation, left, right) override fun number(value: Number): MST = MST.Numeric(value) } @@ -17,17 +18,28 @@ object MSTAlgebra : NumericAlgebra { object MSTSpace : Space, NumericAlgebra by MSTAlgebra { override val zero: MST = number(0.0) - override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun add(a: MST, b: MST): MST = + binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + override fun multiply(a: MST, k: Number): MST = + binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } object MSTRing : Ring, Space by MSTSpace { override val one: MST = number(1.0) override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTSpace.binaryOperation(operation, left, right) } object MSTField : Field, Ring by MSTRing { override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + override fun binaryOperation(operation: String, left: MST, right: MST): MST = + MSTRing.binaryOperation(operation, left, right) } \ No newline at end of file diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 48e368fc3..46739c49c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -4,11 +4,10 @@ import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MSTField -import scientifik.kmath.ast.MSTRing -import scientifik.kmath.ast.MSTSpace +import scientifik.kmath.ast.MSTExpression import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.* +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra import kotlin.reflect.KClass @@ -88,16 +87,21 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< inline fun Algebra.compile(mst: MST): Expression = mst.compileWith(T::class, this) -inline fun , E : Algebra> A.asm( - mstAlgebra: E, - block: E.() -> MST -): Expression = mstAlgebra.block().compileWith(T::class, this) +/** + * Optimize performance of an [MSTExpression] using ASM codegen + */ +inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) -inline fun > A.asmInSpace(block: MSTSpace.() -> MST): Expression = - MSTSpace.block().compileWith(T::class, this) - -inline fun > A.asmInRing(block: MSTRing.() -> MST): Expression = - MSTRing.block().compileWith(T::class, this) - -inline fun > A.asmInField(block: MSTField.() -> MST): Expression = - MSTField.block().compileWith(T::class, this) +//inline fun , E : Algebra> A.asm( +// mstAlgebra: E, +// block: E.() -> MST +//): Expression = mstAlgebra.block().compileWith(T::class, this) +// +//inline fun > A.asmInSpace(block: MSTSpace.() -> MST): Expression = +// MSTSpace.block().compileWith(T::class, this) +// +//inline fun > A.asmInRing(block: MSTRing.() -> MST): Expression = +// MSTRing.block().compileWith(T::class, this) +// +//inline fun > A.asmInField(block: MSTField.() -> MST): Expression = +// MSTField.block().compileWith(T::class, this)