Fix typo, introduce KG protocol delegating to algebra
This commit is contained in:
parent
2b7803290f
commit
7f8abbdd20
@ -2,10 +2,10 @@ package kscience.kmath.ast
|
|||||||
|
|
||||||
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
||||||
import kscience.kmath.asm.compile
|
import kscience.kmath.asm.compile
|
||||||
import kscience.kmath.kotlingrad.toMst
|
|
||||||
import kscience.kmath.kotlingrad.tSFun
|
|
||||||
import kscience.kmath.kotlingrad.toSVar
|
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.kotlingrad.toMst
|
||||||
|
import kscience.kmath.kotlingrad.toSFun
|
||||||
|
import kscience.kmath.kotlingrad.toSVar
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -15,7 +15,7 @@ import kscience.kmath.operations.RealField
|
|||||||
fun main() {
|
fun main() {
|
||||||
val proto = DoublePrecision.prototype
|
val proto = DoublePrecision.prototype
|
||||||
val x by MstAlgebra.symbol("x").toSVar(proto)
|
val x by MstAlgebra.symbol("x").toSVar(proto)
|
||||||
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
|
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto)
|
||||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||||
|
@ -2,8 +2,13 @@ package kscience.kmath.kotlingrad
|
|||||||
|
|
||||||
import edu.umontreal.kotlingrad.experimental.*
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.ast.MstExtendedField
|
import kscience.kmath.ast.MstExtendedField
|
||||||
import kscience.kmath.ast.MstExtendedField.unaryMinus
|
import kscience.kmath.ast.MstExtendedField.unaryMinus
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.Expression
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
import kscience.kmath.operations.*
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -80,28 +85,52 @@ public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, va
|
|||||||
* @param proto the prototype instance.
|
* @param proto the prototype instance.
|
||||||
* @return a scalar function.
|
* @return a scalar function.
|
||||||
*/
|
*/
|
||||||
public fun <X : SFun<X>> MST.tSFun(proto: X): SFun<X> = when (this) {
|
public fun <X : SFun<X>> MST.toSFun(proto: X): SFun<X> = when (this) {
|
||||||
is MST.Numeric -> toSConst()
|
is MST.Numeric -> toSConst()
|
||||||
is MST.Symbolic -> toSVar(proto)
|
is MST.Symbolic -> toSVar(proto)
|
||||||
|
|
||||||
is MST.Unary -> when (operation) {
|
is MST.Unary -> when (operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> value.tSFun(proto)
|
SpaceOperations.PLUS_OPERATION -> value.toSFun(proto)
|
||||||
SpaceOperations.MINUS_OPERATION -> -value.tSFun(proto)
|
SpaceOperations.MINUS_OPERATION -> -value.toSFun(proto)
|
||||||
TrigonometricOperations.SIN_OPERATION -> sin(value.tSFun(proto))
|
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun(proto))
|
||||||
TrigonometricOperations.COS_OPERATION -> cos(value.tSFun(proto))
|
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun(proto))
|
||||||
TrigonometricOperations.TAN_OPERATION -> tan(value.tSFun(proto))
|
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun(proto))
|
||||||
PowerOperations.SQRT_OPERATION -> value.tSFun(proto).sqrt()
|
PowerOperations.SQRT_OPERATION -> value.toSFun(proto).sqrt()
|
||||||
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.tSFun(proto)
|
ExponentialOperations.EXP_OPERATION -> E<X>() pow value.toSFun(proto)
|
||||||
ExponentialOperations.LN_OPERATION -> value.tSFun(proto).ln()
|
ExponentialOperations.LN_OPERATION -> value.toSFun(proto).ln()
|
||||||
else -> error("Unary operation $operation not defined in $this")
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Binary -> when (operation) {
|
is MST.Binary -> when (operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> left.tSFun(proto) + right.tSFun(proto)
|
SpaceOperations.PLUS_OPERATION -> left.toSFun(proto) + right.toSFun(proto)
|
||||||
SpaceOperations.MINUS_OPERATION -> left.tSFun(proto) - right.tSFun(proto)
|
SpaceOperations.MINUS_OPERATION -> left.toSFun(proto) - right.toSFun(proto)
|
||||||
RingOperations.TIMES_OPERATION -> left.tSFun(proto) * right.tSFun(proto)
|
RingOperations.TIMES_OPERATION -> left.toSFun(proto) * right.toSFun(proto)
|
||||||
FieldOperations.DIV_OPERATION -> left.tSFun(proto) / right.tSFun(proto)
|
FieldOperations.DIV_OPERATION -> left.toSFun(proto) / right.toSFun(proto)
|
||||||
PowerOperations.POW_OPERATION -> left.tSFun(proto) pow (right as MST.Numeric).toSConst()
|
PowerOperations.POW_OPERATION -> left.toSFun(proto) pow (right as MST.Numeric).toSConst()
|
||||||
else -> error("Binary operation $operation not defined in $this")
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||||
|
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||||
|
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||||
|
override val proto: KMathNumber<T, A> by lazy { KMathNumber(algebra, algebra.number(Double.NaN)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class KMathProtocol<T, A>(algebra: A) :
|
||||||
|
Protocol<KMathNumber<T, A>>(KMathNumber(algebra, algebra.number(Double.NaN)))
|
||||||
|
where T : Number, A : NumericAlgebra<T>
|
||||||
|
|
||||||
|
public class DifferentiableMstExpression<T, A>(public val algebra: A, public val mst: MST) :
|
||||||
|
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
||||||
|
public val proto by lazy { KMathProtocol(algebra).prototype }
|
||||||
|
public val expr by lazy { MstExpression(algebra, mst) }
|
||||||
|
|
||||||
|
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||||
|
|
||||||
|
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T> {
|
||||||
|
val sfun = mst.toSFun(proto)
|
||||||
|
val orders2 = orders.mapKeys { (k, _) -> MstAlgebra.symbol(k.identity).toSVar(proto) }
|
||||||
|
TODO()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -19,7 +19,7 @@ internal class AdaptingTests {
|
|||||||
fun symbol() {
|
fun symbol() {
|
||||||
val c1 = MstAlgebra.symbol("x")
|
val c1 = MstAlgebra.symbol("x")
|
||||||
assertTrue(c1.toSVar(proto).name == "x")
|
assertTrue(c1.toSVar(proto).name == "x")
|
||||||
val c2 = "kitten".parseMath().tSFun(proto)
|
val c2 = "kitten".parseMath().toSFun(proto)
|
||||||
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,15 +27,15 @@ internal class AdaptingTests {
|
|||||||
fun number() {
|
fun number() {
|
||||||
val c1 = MstAlgebra.number(12354324)
|
val c1 = MstAlgebra.number(12354324)
|
||||||
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
|
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
|
||||||
val c2 = "0.234".parseMath().tSFun(proto)
|
val c2 = "0.234".parseMath().toSFun(proto)
|
||||||
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||||
val c3 = "1e-3".parseMath().tSFun(proto)
|
val c3 = "1e-3".parseMath().toSFun(proto)
|
||||||
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun simpleFunctionShape() {
|
fun simpleFunctionShape() {
|
||||||
val linear = "2*x+16".parseMath().tSFun(proto)
|
val linear = "2*x+16".parseMath().toSFun(proto)
|
||||||
if (linear !is Sum) fail()
|
if (linear !is Sum) fail()
|
||||||
if (linear.left !is Prod) fail()
|
if (linear.left !is Prod) fail()
|
||||||
if (linear.right !is SConst) fail()
|
if (linear.right !is SConst) fail()
|
||||||
@ -44,7 +44,7 @@ internal class AdaptingTests {
|
|||||||
@Test
|
@Test
|
||||||
fun simpleFunctionDerivative() {
|
fun simpleFunctionDerivative() {
|
||||||
val x = MstAlgebra.symbol("x").toSVar(proto)
|
val x = MstAlgebra.symbol("x").toSVar(proto)
|
||||||
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
|
val quadratic = "x^2-4*x-44".parseMath().toSFun(proto)
|
||||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||||
@ -53,7 +53,7 @@ internal class AdaptingTests {
|
|||||||
@Test
|
@Test
|
||||||
fun moreComplexDerivative() {
|
fun moreComplexDerivative() {
|
||||||
val x = MstAlgebra.symbol("x").toSVar(proto)
|
val x = MstAlgebra.symbol("x").toSVar(proto)
|
||||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().tSFun(proto)
|
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun(proto)
|
||||||
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
|
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
|
||||||
|
|
||||||
val expectedDerivative = MstExpression(
|
val expectedDerivative = MstExpression(
|
||||||
|
Loading…
Reference in New Issue
Block a user