forked from kscience/kmath
Refactor toSFun, update KG, delete KMath algebra protocol, update DifferentiableMstExpr.
This commit is contained in:
parent
7f8abbdd20
commit
6f0f6577de
@ -3,6 +3,6 @@ plugins {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api("com.github.breandan:kotlingrad:0.3.2")
|
api("com.github.breandan:kotlingrad:0.3.7")
|
||||||
api(project(":kmath-ast"))
|
api(project(":kmath-ast"))
|
||||||
}
|
}
|
||||||
|
@ -69,7 +69,7 @@ public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
|||||||
* @param proto the prototype instance.
|
* @param proto the prototype instance.
|
||||||
* @return a new variable.
|
* @return a new variable.
|
||||||
*/
|
*/
|
||||||
public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, value)
|
public fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||||
@ -85,28 +85,28 @@ public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, va
|
|||||||
* @param proto the prototype instance.
|
* @param proto the prototype instance.
|
||||||
* @return a scalar function.
|
* @return a scalar function.
|
||||||
*/
|
*/
|
||||||
public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
|
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||||
is MST.Numeric -> toSConst()
|
is MST.Numeric -> toSConst()
|
||||||
is MST.Symbolic -> toSVar(proto)
|
is MST.Symbolic -> toSVar()
|
||||||
|
|
||||||
is MST.Unary -> when (operation) {
|
is MST.Unary -> when (operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> value.toSFun(proto)
|
SpaceOperations.PLUS_OPERATION -> value.toSFun<X>()
|
||||||
SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto)
|
SpaceOperations.MINUS_OPERATION -> (-value).toSFun()
|
||||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto))
|
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto))
|
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto))
|
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||||
PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt()
|
PowerOperations.SQRT_OPERATION -> value.toSFun<X>().sqrt()
|
||||||
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.toSFun(proto)
|
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||||
ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln()
|
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||||
else -> error("Unary operation $operation not defined in $this")
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Binary -> when (operation) {
|
is MST.Binary -> when (operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto)
|
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||||
SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto)
|
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||||
RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto)
|
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||||
FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto)
|
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||||
PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst()
|
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
||||||
else -> error("Binary operation $operation not defined in $this")
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -114,23 +114,23 @@ public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
|
|||||||
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||||
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||||
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||||
override val proto: KMathNumber<T, A> by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class KMathProtocol<T, A>(algebra: A) :
|
|
||||||
Protocol<KMathNumber<T, A>>(KMathNumber(algebra, algebra.number(Double.NaN)))
|
|
||||||
where T : Number, A : NumericAlgebra<T>
|
|
||||||
|
|
||||||
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
|
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
|
||||||
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
||||||
public val proto by lazy { KMathProtocol(algebra).prototype }
|
|
||||||
public val expr by lazy { MstExpression(algebra, mst) }
|
public val expr by lazy { MstExpression(algebra, mst) }
|
||||||
|
|
||||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||||
|
|
||||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
|
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
|
||||||
val sfun = mst.toSFun(proto)
|
TODO()
|
||||||
val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) }
|
}
|
||||||
|
|
||||||
|
public fun derivativeOrNull(orders: List<Symbol>): Expression<T> {
|
||||||
|
orders.map { MstAlgebra.symbol(it.identity).toSVar<KMathNumber<T, A>>() }
|
||||||
|
.fold<SVar<KMathNumber<T, A>>, SFun<KMathNumber<T, A>>>(mst.toSFun()) { result, sVar -> result.d(sVar) }
|
||||||
|
.toMst()
|
||||||
|
|
||||||
TODO()
|
TODO()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user