From 4bf430b2c07a351d116e383a0d7806802af026c2 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 23:17:54 +0700 Subject: [PATCH] Rename converter functions, add symbol delegate provider for MstAlgebra --- .../kscience/kmath/ast/KotlingradSupport.kt | 8 ++--- .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 36 +++++++++---------- .../kmath/ast/kotlingrad/AdaptingTests.kt | 20 +++++------ .../kotlin/kscience/kmath/ast/extensions.kt | 22 ++++++++++++ 4 files changed, 54 insertions(+), 32 deletions(-) create mode 100644 kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt index e63b0c9c0..c2e8456da 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -3,8 +3,8 @@ package kscience.kmath.ast import edu.umontreal.kotlingrad.experimental.DoublePrecision import kscience.kmath.asm.compile import kscience.kmath.ast.kotlingrad.mst -import kscience.kmath.ast.kotlingrad.sfun -import kscience.kmath.ast.kotlingrad.svar +import kscience.kmath.ast.kotlingrad.sFun +import kscience.kmath.ast.kotlingrad.sVar import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField @@ -14,8 +14,8 @@ import kscience.kmath.operations.RealField */ fun main() { val proto = DoublePrecision.prototype - val x by MstAlgebra.symbol("x").svar(proto) - val quadratic = "x^2-4*x-44".parseMath().sfun(proto) + val x by MstAlgebra.symbol("x").sVar(proto) + val quadratic = "x^2-4*x-44".parseMath().sFun(proto) val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index 16c96646a..3e7cd0439 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -55,7 +55,7 @@ public fun > SFun.mst(): MST = MstExtendedField { * @receiver the node. * @return a new constant. */ -public fun > MST.Numeric.sconst(): SConst = SConst(value) +public fun > MST.Numeric.sConst(): SConst = SConst(value) /** * Maps [MST.Symbolic] to [SVar] directly. @@ -64,7 +64,7 @@ public fun > MST.Numeric.sconst(): SConst = SConst(value) * @param proto the prototype instance. * @return a new variable. */ -public fun > MST.Symbolic.svar(proto: X): SVar = SVar(proto, value) +public fun > MST.Symbolic.sVar(proto: X): SVar = SVar(proto, value) /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. @@ -80,28 +80,28 @@ public fun > MST.Symbolic.svar(proto: X): SVar = SVar(proto, valu * @param proto the prototype instance. * @return a scalar function. */ -public fun > MST.sfun(proto: X): SFun = when (this) { - is MST.Numeric -> sconst() - is MST.Symbolic -> svar(proto) +public fun > MST.sFun(proto: X): SFun = when (this) { + is MST.Numeric -> sConst() + is MST.Symbolic -> sVar(proto) is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.sfun(proto) - SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto)) - TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto)) - TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto)) - TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto)) - PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5)) - ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E()) - ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto)) + SpaceOperations.PLUS_OPERATION -> value.sFun(proto) + SpaceOperations.MINUS_OPERATION -> Negative(value.sFun(proto)) + TrigonometricOperations.SIN_OPERATION -> Sine(value.sFun(proto)) + TrigonometricOperations.COS_OPERATION -> Cosine(value.sFun(proto)) + TrigonometricOperations.TAN_OPERATION -> Tangent(value.sFun(proto)) + PowerOperations.SQRT_OPERATION -> Power(value.sFun(proto), SConst(0.5)) + ExponentialOperations.EXP_OPERATION -> Power(value.sFun(proto), E()) + ExponentialOperations.LN_OPERATION -> Log(value.sFun(proto)) else -> error("Unary operation $operation not defined in $this") } is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto)) - SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) - RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) - FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) - PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value)) + SpaceOperations.PLUS_OPERATION -> Sum(left.sFun(proto), right.sFun(proto)) + SpaceOperations.MINUS_OPERATION -> Sum(left.sFun(proto), Negative(right.sFun(proto))) + RingOperations.TIMES_OPERATION -> Prod(left.sFun(proto), right.sFun(proto)) + FieldOperations.DIV_OPERATION -> Prod(left.sFun(proto), Power(right.sFun(proto), Negative(One()))) + PowerOperations.POW_OPERATION -> Power(left.sFun(proto), SConst((right as MST.Numeric).value)) else -> error("Binary operation $operation not defined in $this") } } diff --git a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt index 94d25e411..c3c4602ad 100644 --- a/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt +++ b/kmath-ast-kotlingrad/src/test/kotlin/kscience/kmath/ast/kotlingrad/AdaptingTests.kt @@ -18,24 +18,24 @@ internal class AdaptingTests { @Test fun symbol() { val c1 = MstAlgebra.symbol("x") - assertTrue(c1.svar(proto).name == "x") - val c2 = "kitten".parseMath().sfun(proto) + assertTrue(c1.sVar(proto).name == "x") + val c2 = "kitten".parseMath().sFun(proto) if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() } @Test fun number() { val c1 = MstAlgebra.number(12354324) - assertTrue(c1.sconst().doubleValue == 12354324.0) - val c2 = "0.234".parseMath().sfun(proto) + assertTrue(c1.sConst().doubleValue == 12354324.0) + val c2 = "0.234".parseMath().sFun(proto) if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() - val c3 = "1e-3".parseMath().sfun(proto) + val c3 = "1e-3".parseMath().sFun(proto) if (c3 is SConst) assertEquals(0.001, c3.value) else fail() } @Test fun simpleFunctionShape() { - val linear = "2*x+16".parseMath().sfun(proto) + val linear = "2*x+16".parseMath().sFun(proto) if (linear !is Sum) fail() if (linear.left !is Prod) fail() if (linear.right !is SConst) fail() @@ -43,8 +43,8 @@ internal class AdaptingTests { @Test fun simpleFunctionDerivative() { - val x = MstAlgebra.symbol("x").svar(proto) - val quadratic = "x^2-4*x-44".parseMath().sfun(proto) + val x = MstAlgebra.symbol("x").sVar(proto) + val quadratic = "x^2-4*x-44".parseMath().sFun(proto) val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile() val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) @@ -52,8 +52,8 @@ internal class AdaptingTests { @Test fun moreComplexDerivative() { - val x = MstAlgebra.symbol("x").svar(proto) - val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sfun(proto) + val x = MstAlgebra.symbol("x").sVar(proto) + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sFun(proto) val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile() val expectedDerivative = MstExpression( diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt new file mode 100644 index 000000000..cba4bbb13 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/extensions.kt @@ -0,0 +1,22 @@ +package kscience.kmath.ast + +import kscience.kmath.operations.Algebra +import kotlin.properties.ReadOnlyProperty +import kotlin.reflect.KProperty + +/** + * Stores `provideDelegate` method returning property of [MST.Symbolic]. + */ +public object MstSymbolDelegateProvider { + /** + * Returns [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name of the property. + */ + public operator fun provideDelegate(thisRef: Any?, prop: KProperty<*>): ReadOnlyProperty = + ReadOnlyProperty { _, property -> MST.Symbolic(property.name) } +} + +/** + * Returns [MstSymbolDelegateProvider]. + */ +public val Algebra.symbol: MstSymbolDelegateProvider + get() = MstSymbolDelegateProvider