Refactor toSFun, update KG, delete KMath algebra protocol, update DifferentiableMstExpr.

This commit is contained in:
Iaroslav Postovalov 2020-10-29 13:34:12 +07:00
parent 7f8abbdd20
commit 6f0f6577de
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
2 changed files with 26 additions and 26 deletions

View File

@ -3,6 +3,6 @@ plugins {
} }
dependencies { dependencies {
api("com.github.breandan:kotlingrad:0.3.2") api("com.github.breandan:kotlingrad:0.3.7")
api(project(":kmath-ast")) api(project(":kmath-ast"))
} }

View File

@ -69,7 +69,7 @@ public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
* @param proto the prototype instance. * @param proto the prototype instance.
* @return a new variable. * @return a new variable.
*/ */
public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, value) public fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
/** /**
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
@ -85,28 +85,28 @@ public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, va
* @param proto the prototype instance. * @param proto the prototype instance.
* @return a scalar function. * @return a scalar function.
*/ */
public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) { public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is MST.Numeric -> toSConst() is MST.Numeric -> toSConst()
is MST.Symbolic -> toSVar(proto) is MST.Symbolic -> toSVar()
is MST.Unary -> when (operation) { is MST.Unary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> value.toSFun(proto) SpaceOperations.PLUS_OPERATION -> value.toSFun<X>()
SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto) SpaceOperations.MINUS_OPERATION -> (-value).toSFun()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto)) TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto)) TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto)) TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt() PowerOperations.SQRT_OPERATION -> value.toSFun<X>().sqrt()
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.toSFun(proto) ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln() ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
else -> error("Unary operation $operation not defined in $this") else -> error("Unary operation $operation not defined in $this")
} }
is MST.Binary -> when (operation) { is MST.Binary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto) SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto) SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto) RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto) FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst() PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }
} }
@ -114,23 +114,23 @@ public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
public class KMathNumber<T, A>(public val algebra: A, value: T) : public class KMathNumber<T, A>(public val algebra: A, value: T) :
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> { RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number)) public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
override val proto: KMathNumber<T, A> by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) }
} }
public class KMathProtocol<T, A>(algebra: A) :
Protocol<KMathNumber<T, A>>(KMathNumber(algebra, algebra.number(Double.NaN)))
where T : Number, A : NumericAlgebra<T>
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) : public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number { DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
public val proto by lazy { KMathProtocol(algebra).prototype }
public val expr by lazy { MstExpression(algebra, mst) } public val expr by lazy { MstExpression(algebra, mst) }
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments) public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> { public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
val sfun = mst.toSFun(proto) TODO()
val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) } }
public fun derivativeOrNull(orders: List<Symbol>): Expression<T> {
orders.map { MstAlgebra.symbol(it.identity).toSVar<KMathNumber<T, A>>() }
.fold<SVar<KMathNumber<T, A>>, SFun<KMathNumber<T, A>>>(mst.toSFun()) { result, sVar -> result.d(sVar) }
.toMst()
TODO() TODO()
} }
} }