From 57bdee49368208c67eea5259b39223f37fcae386 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 22:34:05 +0700 Subject: [PATCH] Add test, update MstAlgebra a bit to return concrete types --- .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 2 +- .../kmath/ast/kotlingrad/AdaptingTests.kt | 66 ++++++++++ .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 113 +++++++++--------- 3 files changed, 125 insertions(+), 56 deletions(-) create mode 100644 kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index cfd1f2702..16c96646a 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -36,7 +36,7 @@ public fun > SFun.mst(): MST = MstExtendedField { is SConst -> number(doubleValue) is Sum -> left.mst() + right.mst() is Prod -> left.mst() * right.mst() - is Power -> power(left.mst(), (right() as SConst<*>).doubleValue) + is Power -> power(left.mst(), (right as SConst<*>).doubleValue) is Negative -> -input.mst() is Log -> ln(left.mst()) / ln(right.mst()) is Sine -> sin(input.mst()) diff --git a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt new file mode 100644 index 000000000..94d25e411 --- /dev/null +++ b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt @@ -0,0 +1,66 @@ +package kscience.kmath.ast.kotlingrad + +import edu.umontreal.kotlingrad.experimental.* +import kscience.kmath.asm.compile +import kscience.kmath.ast.MstAlgebra +import kscience.kmath.ast.MstExpression +import kscience.kmath.ast.parseMath +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.fail + +internal class AdaptingTests { + private val proto: DReal = DoublePrecision.prototype + + @Test + fun symbol() { + val c1 = MstAlgebra.symbol("x") + assertTrue(c1.svar(proto).name == "x") + val c2 = "kitten".parseMath().sfun(proto) + if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() + } + + @Test + fun number() { + val c1 = MstAlgebra.number(12354324) + assertTrue(c1.sconst().doubleValue == 12354324.0) + val c2 = "0.234".parseMath().sfun(proto) + if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() + val c3 = "1e-3".parseMath().sfun(proto) + if (c3 is SConst) assertEquals(0.001, c3.value) else fail() + } + + @Test + fun simpleFunctionShape() { + val linear = "2*x+16".parseMath().sfun(proto) + if (linear !is Sum) fail() + if (linear.left !is Prod) fail() + if (linear.right !is SConst) fail() + } + + @Test + fun simpleFunctionDerivative() { + val x = MstAlgebra.symbol("x").svar(proto) + val quadratic = "x^2-4*x-44".parseMath().sfun(proto) + val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() + val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() + assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) + } + + @Test + fun moreComplexDerivative() { + val x = MstAlgebra.symbol("x").svar(proto) + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sfun(proto) + val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile() + + val expectedDerivative = MstExpression( + RealField, + "-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath() + ).compile() + + assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1)) + } +} diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 64a820b20..6ee6ab9af 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -6,14 +6,14 @@ import kscience.kmath.operations.* * [Algebra] over [MST] nodes. */ public object MstAlgebra : NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) + override fun number(value: Number): MST.Numeric = MST.Numeric(value) - override fun symbol(value: String): MST = MST.Symbolic(value) + override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) - override fun unaryOperation(operation: String, arg: MST): MST = + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MST.Unary(operation, arg) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MST.Binary(operation, left, right) } @@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra { * [Space] over [MST] nodes. */ public object MstSpace : Space, NumericAlgebra { - override val zero: MST = number(0.0) + override val zero: MST.Numeric by lazy { number(0.0) } - override fun number(value: Number): MST = MstAlgebra.number(value) - override fun symbol(value: String): MST = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) + override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) + override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstAlgebra.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg) } /** * [Ring] over [MST] nodes. */ public object MstRing : Ring, NumericAlgebra { - override val zero: MST + override val zero: MST.Numeric get() = MstSpace.zero - override val one: MST = number(1.0) - override fun number(value: Number): MST = MstSpace.number(value) - override fun symbol(value: String): MST = MstSpace.symbol(value) - override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) + override val one: MST.Numeric by lazy { number(1.0) } - override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) + override fun number(value: Number): MST.Numeric = MstSpace.number(value) + override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) + override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) + override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) + override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b) - override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstSpace.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg) } /** * [Field] over [MST] nodes. */ public object MstField : Field { - public override val zero: MST + public override val zero: MST.Numeric get() = MstRing.zero - public override val one: MST + public override val one: MST.Numeric get() = MstRing.one - public override fun symbol(value: String): MST = MstRing.symbol(value) - public override fun number(value: Number): MST = MstRing.number(value) - public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) - public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) + public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b) - public override fun binaryOperation(operation: String, left: MST, right: MST): MST = + public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstRing.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg) } /** * [ExtendedField] over [MST] nodes. */ public object MstExtendedField : ExtendedField { - override val zero: MST + override val zero: MST.Numeric get() = MstField.zero - override val one: MST + override val one: MST.Numeric get() = MstField.one - override fun symbol(value: String): MST = MstField.symbol(value) - override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) - override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) - override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) - override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) - override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) - override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) - override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) - override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) - override fun add(a: MST, b: MST): MST = MstField.add(a, b) - override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) - override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) - override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) - override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) - override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) - override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) + override fun number(value: Number): MST.Numeric = MstField.number(value) + override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) + override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) + override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) + override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) + override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) + override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) + override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) + override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) + override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) + override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) + override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) + override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) - override fun binaryOperation(operation: String, left: MST, right: MST): MST = + override fun power(arg: MST, pow: Number): MST.Binary = + binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + + override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + + override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = MstField.binaryOperation(operation, left, right) - override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg) + override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg) }