Fix typo, introduce KG protocol delegating to algebra
This commit is contained in:
parent
2b7803290f
commit
7f8abbdd20
@ -2,10 +2,10 @@ package kscience.kmath.ast
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.kotlingrad.toMst
|
||||
import kscience.kmath.kotlingrad.tSFun
|
||||
import kscience.kmath.kotlingrad.toSVar
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.kotlingrad.toMst
|
||||
import kscience.kmath.kotlingrad.toSFun
|
||||
import kscience.kmath.kotlingrad.toSVar
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
@ -15,7 +15,7 @@ import kscience.kmath.operations.RealField
|
||||
fun main() {
|
||||
val proto = DoublePrecision.prototype
|
||||
val x by MstAlgebra.symbol("x").toSVar(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto)
|
||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||
|
@ -2,8 +2,13 @@ package kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.*
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MstAlgebra
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.ast.MstExtendedField
|
||||
import kscience.kmath.ast.MstExtendedField.unaryMinus
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.operations.*
|
||||
|
||||
/**
|
||||
@ -80,28 +85,52 @@ public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, va
|
||||
* @param proto the prototype instance.
|
||||
* @return a scalar function.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.tSFun(proto: X): SFun<X> = when (this) {
|
||||
public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
|
||||
is MST.Numeric -> toSConst()
|
||||
is MST.Symbolic -> toSVar(proto)
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> value.tSFun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> -value.tSFun(proto)
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(value.tSFun(proto))
|
||||
TrigonometricOperations.COS_OPERATION -> cos(value.tSFun(proto))
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(value.tSFun(proto))
|
||||
PowerOperations.SQRT_OPERATION -> value.tSFun(proto).sqrt()
|
||||
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.tSFun(proto)
|
||||
ExponentialOperations.LN_OPERATION -> value.tSFun(proto).ln()
|
||||
SpaceOperations.PLUS_OPERATION -> value.toSFun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto)
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto))
|
||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto))
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto))
|
||||
PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt()
|
||||
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.toSFun(proto)
|
||||
ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln()
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
is MST.Binary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left.tSFun(proto) + right.tSFun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> left.tSFun(proto) - right.tSFun(proto)
|
||||
RingOperations.TIMES_OPERATION -> left.tSFun(proto) * right.tSFun(proto)
|
||||
FieldOperations.DIV_OPERATION -> left.tSFun(proto) / right.tSFun(proto)
|
||||
PowerOperations.POW_OPERATION -> left.tSFun(proto) pow (right as MST.Numeric).toSConst()
|
||||
SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto)
|
||||
RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto)
|
||||
FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto)
|
||||
PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst()
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
}
|
||||
|
||||
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||
override val proto: KMathNumber<T, A> by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) }
|
||||
}
|
||||
|
||||
public class KMathProtocol<T, A>(algebra: A) :
|
||||
Protocol<KMathNumber<T, A>>(KMathNumber(algebra, algebra.number(Double.NaN)))
|
||||
where T : Number, A : NumericAlgebra<T>
|
||||
|
||||
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
|
||||
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
||||
public val proto by lazy { KMathProtocol(algebra).prototype }
|
||||
public val expr by lazy { MstExpression(algebra, mst) }
|
||||
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||
|
||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
|
||||
val sfun = mst.toSFun(proto)
|
||||
val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) }
|
||||
TODO()
|
||||
}
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ internal class AdaptingTests {
|
||||
fun symbol() {
|
||||
val c1 = MstAlgebra.symbol("x")
|
||||
assertTrue(c1.toSVar(proto).name == "x")
|
||||
val c2 = "kitten".parseMath().tSFun(proto)
|
||||
val c2 = "kitten".parseMath().toSFun(proto)
|
||||
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||
}
|
||||
|
||||
@ -27,15 +27,15 @@ internal class AdaptingTests {
|
||||
fun number() {
|
||||
val c1 = MstAlgebra.number(12354324)
|
||||
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
|
||||
val c2 = "0.234".parseMath().tSFun(proto)
|
||||
val c2 = "0.234".parseMath().toSFun(proto)
|
||||
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||
val c3 = "1e-3".parseMath().tSFun(proto)
|
||||
val c3 = "1e-3".parseMath().toSFun(proto)
|
||||
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun simpleFunctionShape() {
|
||||
val linear = "2*x+16".parseMath().tSFun(proto)
|
||||
val linear = "2*x+16".parseMath().toSFun(proto)
|
||||
if (linear !is Sum) fail()
|
||||
if (linear.left !is Prod) fail()
|
||||
if (linear.right !is SConst) fail()
|
||||
@ -44,7 +44,7 @@ internal class AdaptingTests {
|
||||
@Test
|
||||
fun simpleFunctionDerivative() {
|
||||
val x = MstAlgebra.symbol("x").toSVar(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto)
|
||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||
@ -53,7 +53,7 @@ internal class AdaptingTests {
|
||||
@Test
|
||||
fun moreComplexDerivative() {
|
||||
val x = MstAlgebra.symbol("x").toSVar(proto)
|
||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().tSFun(proto)
|
||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun(proto)
|
||||
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
|
||||
|
||||
val expectedDerivative = MstExpression(
|
||||
|
Loading…
Reference in New Issue
Block a user