Refactor toSFun, update KG, delete KMath algebra protocol, update DifferentiableMstExpr.

This commit is contained in:
Iaroslav Postovalov 2020-10-29 13:34:12 +07:00
parent 7f8abbdd20
commit 6f0f6577de
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
2 changed files with 26 additions and 26 deletions

View File

@ -3,6 +3,6 @@ plugins {
}
dependencies {
api("com.github.breandan:kotlingrad:0.3.2")
api("com.github.breandan:kotlingrad:0.3.7")
api(project(":kmath-ast"))
}

View File

@ -69,7 +69,7 @@ public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
* @param proto the prototype instance.
* @return a new variable.
*/
public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, value)
public fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
/**
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
@ -85,28 +85,28 @@ public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, va
* @param proto the prototype instance.
* @return a scalar function.
*/
public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is MST.Numeric -> toSConst()
is MST.Symbolic -> toSVar(proto)
is MST.Symbolic -> toSVar()
is MST.Unary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> value.toSFun(proto)
SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto)
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto))
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto))
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto))
PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt()
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.toSFun(proto)
ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln()
SpaceOperations.PLUS_OPERATION -> value.toSFun<X>()
SpaceOperations.MINUS_OPERATION -> (-value).toSFun()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
PowerOperations.SQRT_OPERATION -> value.toSFun<X>().sqrt()
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
else -> error("Unary operation $operation not defined in $this")
}
is MST.Binary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto)
SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto)
RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto)
FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto)
PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst()
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this")
}
}
@ -114,23 +114,23 @@ public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
public class KMathNumber<T, A>(public val algebra: A, value: T) :
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
override val proto: KMathNumber<T, A> by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) }
}
public class KMathProtocol<T, A>(algebra: A) :
Protocol<KMathNumber<T, A>>(KMathNumber(algebra, algebra.number(Double.NaN)))
where T : Number, A : NumericAlgebra<T>
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
public val proto by lazy { KMathProtocol(algebra).prototype }
public val expr by lazy { MstExpression(algebra, mst) }
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
val sfun = mst.toSFun(proto)
val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) }
TODO()
}
public fun derivativeOrNull(orders: List<Symbol>): Expression<T> {
orders.map { MstAlgebra.symbol(it.identity).toSVar<KMathNumber<T, A>>() }
.fold<SVar<KMathNumber<T, A>>, SFun<KMathNumber<T, A>>>(mst.toSFun()) { result, sVar -> result.d(sVar) }
.toMst()
TODO()
}
}