Fix typo, introduce KG protocol delegating to algebra

This commit is contained in:
Iaroslav Postovalov 2020-10-29 02:22:34 +07:00
parent 2b7803290f
commit 7f8abbdd20
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
3 changed files with 53 additions and 24 deletions

View File

@ -2,10 +2,10 @@ package kscience.kmath.ast
import edu.umontreal.kotlingrad.experimental.DoublePrecision
import kscience.kmath.asm.compile
import kscience.kmath.kotlingrad.toMst
import kscience.kmath.kotlingrad.tSFun
import kscience.kmath.kotlingrad.toSVar
import kscience.kmath.expressions.invoke
import kscience.kmath.kotlingrad.toMst
import kscience.kmath.kotlingrad.toSFun
import kscience.kmath.kotlingrad.toSVar
import kscience.kmath.operations.RealField
/**
@ -15,7 +15,7 @@ import kscience.kmath.operations.RealField
fun main() {
val proto = DoublePrecision.prototype
val x by MstAlgebra.symbol("x").toSVar(proto)
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto)
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))

View File

@ -2,8 +2,13 @@ package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.*
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.ast.MstExtendedField
import kscience.kmath.ast.MstExtendedField.unaryMinus
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.*
/**
@ -80,28 +85,52 @@ public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, va
* @param proto the prototype instance.
* @return a scalar function.
*/
public fun <X : SFun<X>> MST.tSFun(proto: X): SFun<X> = when (this) {
public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
is MST.Numeric -> toSConst()
is MST.Symbolic -> toSVar(proto)
is MST.Unary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> value.tSFun(proto)
SpaceOperations.MINUS_OPERATION -> -value.tSFun(proto)
TrigonometricOperations.SIN_OPERATION -> sin(value.tSFun(proto))
TrigonometricOperations.COS_OPERATION -> cos(value.tSFun(proto))
TrigonometricOperations.TAN_OPERATION -> tan(value.tSFun(proto))
PowerOperations.SQRT_OPERATION -> value.tSFun(proto).sqrt()
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.tSFun(proto)
ExponentialOperations.LN_OPERATION -> value.tSFun(proto).ln()
SpaceOperations.PLUS_OPERATION -> value.toSFun(proto)
SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto)
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto))
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto))
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto))
PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt()
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.toSFun(proto)
ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln()
else -> error("Unary operation $operation not defined in $this")
}
is MST.Binary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> left.tSFun(proto) + right.tSFun(proto)
SpaceOperations.MINUS_OPERATION -> left.tSFun(proto) - right.tSFun(proto)
RingOperations.TIMES_OPERATION -> left.tSFun(proto) * right.tSFun(proto)
FieldOperations.DIV_OPERATION -> left.tSFun(proto) / right.tSFun(proto)
PowerOperations.POW_OPERATION -> left.tSFun(proto) pow (right as MST.Numeric).toSConst()
SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto)
SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto)
RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto)
FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto)
PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this")
}
}
public class KMathNumber<T, A>(public val algebra: A, value: T) :
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
override val proto: KMathNumber<T, A> by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) }
}
public class KMathProtocol<T, A>(algebra: A) :
Protocol<KMathNumber<T, A>>(KMathNumber(algebra, algebra.number(Double.NaN)))
where T : Number, A : NumericAlgebra<T>
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
public val proto by lazy { KMathProtocol(algebra).prototype }
public val expr by lazy { MstExpression(algebra, mst) }
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
val sfun = mst.toSFun(proto)
val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) }
TODO()
}
}

View File

@ -19,7 +19,7 @@ internal class AdaptingTests {
fun symbol() {
val c1 = MstAlgebra.symbol("x")
assertTrue(c1.toSVar(proto).name == "x")
val c2 = "kitten".parseMath().tSFun(proto)
val c2 = "kitten".parseMath().toSFun(proto)
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
}
@ -27,15 +27,15 @@ internal class AdaptingTests {
fun number() {
val c1 = MstAlgebra.number(12354324)
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
val c2 = "0.234".parseMath().tSFun(proto)
val c2 = "0.234".parseMath().toSFun(proto)
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
val c3 = "1e-3".parseMath().tSFun(proto)
val c3 = "1e-3".parseMath().toSFun(proto)
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
}
@Test
fun simpleFunctionShape() {
val linear = "2*x+16".parseMath().tSFun(proto)
val linear = "2*x+16".parseMath().toSFun(proto)
if (linear !is Sum) fail()
if (linear.left !is Prod) fail()
if (linear.right !is SConst) fail()
@ -44,7 +44,7 @@ internal class AdaptingTests {
@Test
fun simpleFunctionDerivative() {
val x = MstAlgebra.symbol("x").toSVar(proto)
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto)
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
@ -53,7 +53,7 @@ internal class AdaptingTests {
@Test
fun moreComplexDerivative() {
val x = MstAlgebra.symbol("x").toSVar(proto)
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().tSFun(proto)
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun(proto)
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
val expectedDerivative = MstExpression(