Refactor toSFun, update KG, delete KMath algebra protocol, update DifferentiableMstExpr.
This commit is contained in:
parent
7f8abbdd20
commit
6f0f6577de
@ -3,6 +3,6 @@ plugins {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api("com.github.breandan:kotlingrad:0.3.2")
|
||||
api("com.github.breandan:kotlingrad:0.3.7")
|
||||
api(project(":kmath-ast"))
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||
* @param proto the prototype instance.
|
||||
* @return a new variable.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, value)
|
||||
public fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
|
||||
|
||||
/**
|
||||
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||
@ -85,28 +85,28 @@ public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, va
|
||||
* @param proto the prototype instance.
|
||||
* @return a scalar function.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
|
||||
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||
is MST.Numeric -> toSConst()
|
||||
is MST.Symbolic -> toSVar(proto)
|
||||
is MST.Symbolic -> toSVar()
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> value.toSFun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto)
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto))
|
||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto))
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto))
|
||||
PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt()
|
||||
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.toSFun(proto)
|
||||
ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln()
|
||||
SpaceOperations.PLUS_OPERATION -> value.toSFun<X>()
|
||||
SpaceOperations.MINUS_OPERATION -> (-value).toSFun()
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||
PowerOperations.SQRT_OPERATION -> value.toSFun<X>().sqrt()
|
||||
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
is MST.Binary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto)
|
||||
RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto)
|
||||
FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto)
|
||||
PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst()
|
||||
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
}
|
||||
@ -114,23 +114,23 @@ public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
|
||||
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||
override val proto: KMathNumber<T, A> by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) }
|
||||
}
|
||||
|
||||
public class KMathProtocol<T, A>(algebra: A) :
|
||||
Protocol<KMathNumber<T, A>>(KMathNumber(algebra, algebra.number(Double.NaN)))
|
||||
where T : Number, A : NumericAlgebra<T>
|
||||
|
||||
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
|
||||
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
||||
public val proto by lazy { KMathProtocol(algebra).prototype }
|
||||
public val expr by lazy { MstExpression(algebra, mst) }
|
||||
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||
|
||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
|
||||
val sfun = mst.toSFun(proto)
|
||||
val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) }
|
||||
TODO()
|
||||
}
|
||||
|
||||
public fun derivativeOrNull(orders: List<Symbol>): Expression<T> {
|
||||
orders.map { MstAlgebra.symbol(it.identity).toSVar<KMathNumber<T, A>>() }
|
||||
.fold<SVar<KMathNumber<T, A>>, SFun<KMathNumber<T, A>>>(mst.toSFun()) { result, sVar -> result.d(sVar) }
|
||||
.toMst()
|
||||
|
||||
TODO()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user