diff --git a/kmath-kotlingrad/build.gradle.kts b/kmath-kotlingrad/build.gradle.kts index 0fe6e6b93..f2245c3d5 100644 --- a/kmath-kotlingrad/build.gradle.kts +++ b/kmath-kotlingrad/build.gradle.kts @@ -3,6 +3,6 @@ plugins { } dependencies { - api("com.github.breandan:kotlingrad:0.3.2") + api("com.github.breandan:kotlingrad:0.3.7") api(project(":kmath-ast")) } diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt index a9a8a14b2..24e8377bd 100644 --- a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -69,7 +69,7 @@ public fun > MST.Numeric.toSConst(): SConst = SConst(value) * @param proto the prototype instance. * @return a new variable. */ -public fun > MST.Symbolic.toSVar(proto: X): SVar = SVar(proto, value) +public fun > MST.Symbolic.toSVar(): SVar = SVar(value) /** * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. @@ -85,28 +85,28 @@ public fun > MST.Symbolic.toSVar(proto: X): SVar = SVar(proto, va * @param proto the prototype instance. * @return a scalar function. */ -public fun > MST.toSFun(proto: X): SFun = when (this) { +public fun > MST.toSFun(): SFun = when (this) { is MST.Numeric -> toSConst() - is MST.Symbolic -> toSVar(proto) + is MST.Symbolic -> toSVar() is MST.Unary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> value.toSFun(proto) - SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto) - TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto)) - TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto)) - TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto)) - PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt() - ExponentialOperations.EXP_OPERATION -> E() pow value.toSFun(proto) - ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln() + SpaceOperations.PLUS_OPERATION -> value.toSFun() + SpaceOperations.MINUS_OPERATION -> (-value).toSFun() + TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) + TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) + TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) + PowerOperations.SQRT_OPERATION -> value.toSFun().sqrt() + ExponentialOperations.EXP_OPERATION -> exp(value.toSFun()) + ExponentialOperations.LN_OPERATION -> value.toSFun().ln() else -> error("Unary operation $operation not defined in $this") } is MST.Binary -> when (operation) { - SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto) - SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto) - RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto) - FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto) - PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst() + SpaceOperations.PLUS_OPERATION -> left.toSFun() + right.toSFun() + SpaceOperations.MINUS_OPERATION -> left.toSFun() - right.toSFun() + RingOperations.TIMES_OPERATION -> left.toSFun() * right.toSFun() + FieldOperations.DIV_OPERATION -> left.toSFun() / right.toSFun() + PowerOperations.POW_OPERATION -> left.toSFun() pow (right as MST.Numeric).toSConst() else -> error("Binary operation $operation not defined in $this") } } @@ -114,23 +114,23 @@ public fun > MST.toSFun(proto: X): SFun = when (this) { public class KMathNumber(public val algebra: A, value: T) : RealNumber, T>(value) where T : Number, A : NumericAlgebra { public override fun wrap(number: Number): SConst> = SConst(algebra.number(number)) - override val proto: KMathNumber by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) } } -public class KMathProtocol(algebra: A) : - Protocol>(KMathNumber(algebra, algebra.number(Double.NaN))) - where T : Number, A : NumericAlgebra - public class DifferentiableMstExpression(public val algebra: A, public val mst: MST) : DifferentiableExpression where A : NumericAlgebra, T : Number { - public val proto by lazy { KMathProtocol(algebra).prototype } public val expr by lazy { MstExpression(algebra, mst) } - public override fun invoke(arguments: Map): T = expr(arguments) + public override fun invoke(arguments: Map): T = expr(arguments) public override fun derivativeOrNull(orders: Map): Expression { - val sfun = mst.toSFun(proto) - val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) } + TODO() + } + + public fun derivativeOrNull(orders: List): Expression { + orders.map { MstAlgebra.symbol(it.identity).toSVar>() } + .fold>, SFun>>(mst.toSFun()) { result, sVar -> result.d(sVar) } + .toMst() + TODO() } }