diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt index 07194a7bb..f741fc8c4 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MSTAlgebra.kt @@ -34,43 +34,40 @@ object MSTSpace : Space, NumericAlgebra { } object MSTRing : Ring, NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) - override val zero: MST = MSTSpace.number(0.0) override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + + override fun number(value: Number): MST = MST.Numeric(value) + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } -object MSTField : Field{ - override fun symbol(value: String): MST = MST.Symbolic(value) - override fun number(value: Number): MST = MST.Numeric(value) - +object MSTField : Field { override val zero: MST = MSTSpace.number(0.0) override val one: MST = number(1.0) - override fun add(a: MST, b: MST): MST = - MSTAlgebra.binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun symbol(value: String): MST = MST.Symbolic(value) + override fun number(value: Number): MST = MST.Numeric(value) + override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: MST, k: Number): MST = - MSTAlgebra.binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) + binaryOperation(RingOperations.TIMES_OPERATION, a, MSTSpace.number(k)) - override fun multiply(a: MST, b: MST): MST = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun divide(a: MST, b: MST): MST = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) + override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST = MSTAlgebra.binaryOperation(operation, left, right) + + override fun unaryOperation(operation: String, arg: MST): MST = MSTAlgebra.unaryOperation(operation, arg) } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt similarity index 68% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt rename to kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt index f3b07df56..b571e076f 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt @@ -7,39 +7,40 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -internal class TestSpecialization { +internal class TestAsmSpecialization { @Test fun testUnaryPlus() { val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(2.0, res) + assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(-2.0, res) + assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(4.0, res) + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(0.0, res) + assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() - val res = expr("x" to 2.0) - assertEquals(1.0, res) + assertEquals(1.0, expr("x" to 2.0)) } }