From 5de9d69237118889b74d4a5074b48db609db92ae Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 12 Oct 2020 21:06:15 +0700 Subject: [PATCH] Add more fine-grained converters from MST to SVar and SConst --- .../kmath/ast/kotlingrad/ScalarsAdapters.kt | 70 ++++++++++++------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt index 741a2534c..3a1191e6a 100644 --- a/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt +++ b/kmath-ast-kotlingrad/src/main/kotlin/kscience/kmath/ast/kotlingrad/ScalarsAdapters.kt @@ -8,8 +8,8 @@ import kscience.kmath.operations.* /** * Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then. * - * @receiver a scalar function. - * @return the [MST]. + * @receiver the scalar function. + * @return a node. */ public fun > SFun.mst(): MST = MstExtendedField { when (this@mst) { @@ -30,36 +30,52 @@ public fun > SFun.mst(): MST = MstExtendedField { } } +/** + * Maps [MST.Numeric] to [SConst] directly. + * + * @receiver the node. + * @return a new constant. + */ +public fun > MST.Numeric.sconst(): SConst = SConst(value) + +/** + * Maps [MST.Symbolic] to [SVar] directly. + * + * @receiver the node. + * @param proto the prototype instance. + * @return a new variable. + */ +public fun > MST.Symbolic.svar(proto: X): SVar = SVar(proto, value) + /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. * - * @receiver an [MST]. - * @return the scalar function. + * @receiver the node. + * @param proto the prototype instance. + * @return a scalar function. */ -public fun > MST.sfun(proto: X): SFun { - return when (this) { - is MST.Numeric -> SConst(value) - is MST.Symbolic -> SVar(proto, value) +public fun > MST.sfun(proto: X): SFun = when (this) { + is MST.Numeric -> sconst() + is MST.Symbolic -> svar(proto) - is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.sfun(proto) - SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto)) - TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto)) - TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto)) - TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto)) - PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5)) - ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E()) - ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto)) - else -> error("Unary operation $operation not defined in $this") - } + is MST.Unary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> value.sfun(proto) + SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto)) + TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto)) + TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto)) + TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto)) + PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5)) + ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E()) + ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto)) + else -> error("Unary operation $operation not defined in $this") + } - is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto)) - SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) - RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) - FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) - PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value)) - else -> error("Binary operation $operation not defined in $this") - } + is MST.Binary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto)) + SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto))) + RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto)) + FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One()))) + PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value)) + else -> error("Binary operation $operation not defined in $this") } }