forked from kscience/kmath
Rename converter functions, add symbol delegate provider for MstAlgebra
This commit is contained in:
parent
54069fd37e
commit
4bf430b2c0
@ -3,8 +3,8 @@ package kscience.kmath.ast
|
||||
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.ast.kotlingrad.mst
|
||||
import kscience.kmath.ast.kotlingrad.sfun
|
||||
import kscience.kmath.ast.kotlingrad.svar
|
||||
import kscience.kmath.ast.kotlingrad.sFun
|
||||
import kscience.kmath.ast.kotlingrad.sVar
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
@ -14,8 +14,8 @@ import kscience.kmath.operations.RealField
|
||||
*/
|
||||
fun main() {
|
||||
val proto = DoublePrecision.prototype
|
||||
val x by MstAlgebra.symbol("x").svar(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().sfun(proto)
|
||||
val x by MstAlgebra.symbol("x").sVar(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().sFun(proto)
|
||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile()
|
||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||
|
@ -55,7 +55,7 @@ public fun <X : SFun<X>> SFun<X>.mst(): MST = MstExtendedField {
|
||||
* @receiver the node.
|
||||
* @return a new constant.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.Numeric.sconst(): SConst<X> = SConst(value)
|
||||
public fun <X : SFun<X>> MST.Numeric.sConst(): SConst<X> = SConst(value)
|
||||
|
||||
/**
|
||||
* Maps [MST.Symbolic] to [SVar] directly.
|
||||
@ -64,7 +64,7 @@ public fun <X : SFun<X>> MST.Numeric.sconst(): SConst<X> = SConst(value)
|
||||
* @param proto the prototype instance.
|
||||
* @return a new variable.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.Symbolic.svar(proto: X): SVar<X> = SVar(proto, value)
|
||||
public fun <X : SFun<X>> MST.Symbolic.sVar(proto: X): SVar<X> = SVar(proto, value)
|
||||
|
||||
/**
|
||||
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||
@ -80,28 +80,28 @@ public fun <X : SFun<X>> MST.Symbolic.svar(proto: X): SVar<X> = SVar(proto, valu
|
||||
* @param proto the prototype instance.
|
||||
* @return a scalar function.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.sfun(proto: X): SFun<X> = when (this) {
|
||||
is MST.Numeric -> sconst()
|
||||
is MST.Symbolic -> svar(proto)
|
||||
public fun <X : SFun<X>> MST.sFun(proto: X): SFun<X> = when (this) {
|
||||
is MST.Numeric -> sConst()
|
||||
is MST.Symbolic -> sVar(proto)
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> value.sfun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> Negative(value.sfun(proto))
|
||||
TrigonometricOperations.SIN_OPERATION -> Sine(value.sfun(proto))
|
||||
TrigonometricOperations.COS_OPERATION -> Cosine(value.sfun(proto))
|
||||
TrigonometricOperations.TAN_OPERATION -> Tangent(value.sfun(proto))
|
||||
PowerOperations.SQRT_OPERATION -> Power(value.sfun(proto), SConst(0.5))
|
||||
ExponentialOperations.EXP_OPERATION -> Power(value.sfun(proto), E())
|
||||
ExponentialOperations.LN_OPERATION -> Log(value.sfun(proto))
|
||||
SpaceOperations.PLUS_OPERATION -> value.sFun(proto)
|
||||
SpaceOperations.MINUS_OPERATION -> Negative(value.sFun(proto))
|
||||
TrigonometricOperations.SIN_OPERATION -> Sine(value.sFun(proto))
|
||||
TrigonometricOperations.COS_OPERATION -> Cosine(value.sFun(proto))
|
||||
TrigonometricOperations.TAN_OPERATION -> Tangent(value.sFun(proto))
|
||||
PowerOperations.SQRT_OPERATION -> Power(value.sFun(proto), SConst(0.5))
|
||||
ExponentialOperations.EXP_OPERATION -> Power(value.sFun(proto), E())
|
||||
ExponentialOperations.LN_OPERATION -> Log(value.sFun(proto))
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
is MST.Binary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> Sum(left.sfun(proto), right.sfun(proto))
|
||||
SpaceOperations.MINUS_OPERATION -> Sum(left.sfun(proto), Negative(right.sfun(proto)))
|
||||
RingOperations.TIMES_OPERATION -> Prod(left.sfun(proto), right.sfun(proto))
|
||||
FieldOperations.DIV_OPERATION -> Prod(left.sfun(proto), Power(right.sfun(proto), Negative(One())))
|
||||
PowerOperations.POW_OPERATION -> Power(left.sfun(proto), SConst((right as MST.Numeric).value))
|
||||
SpaceOperations.PLUS_OPERATION -> Sum(left.sFun(proto), right.sFun(proto))
|
||||
SpaceOperations.MINUS_OPERATION -> Sum(left.sFun(proto), Negative(right.sFun(proto)))
|
||||
RingOperations.TIMES_OPERATION -> Prod(left.sFun(proto), right.sFun(proto))
|
||||
FieldOperations.DIV_OPERATION -> Prod(left.sFun(proto), Power(right.sFun(proto), Negative(One())))
|
||||
PowerOperations.POW_OPERATION -> Power(left.sFun(proto), SConst((right as MST.Numeric).value))
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
}
|
||||
|
@ -18,24 +18,24 @@ internal class AdaptingTests {
|
||||
@Test
|
||||
fun symbol() {
|
||||
val c1 = MstAlgebra.symbol("x")
|
||||
assertTrue(c1.svar(proto).name == "x")
|
||||
val c2 = "kitten".parseMath().sfun(proto)
|
||||
assertTrue(c1.sVar(proto).name == "x")
|
||||
val c2 = "kitten".parseMath().sFun(proto)
|
||||
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun number() {
|
||||
val c1 = MstAlgebra.number(12354324)
|
||||
assertTrue(c1.sconst<DReal>().doubleValue == 12354324.0)
|
||||
val c2 = "0.234".parseMath().sfun(proto)
|
||||
assertTrue(c1.sConst<DReal>().doubleValue == 12354324.0)
|
||||
val c2 = "0.234".parseMath().sFun(proto)
|
||||
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||
val c3 = "1e-3".parseMath().sfun(proto)
|
||||
val c3 = "1e-3".parseMath().sFun(proto)
|
||||
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun simpleFunctionShape() {
|
||||
val linear = "2*x+16".parseMath().sfun(proto)
|
||||
val linear = "2*x+16".parseMath().sFun(proto)
|
||||
if (linear !is Sum) fail()
|
||||
if (linear.left !is Prod) fail()
|
||||
if (linear.right !is SConst) fail()
|
||||
@ -43,8 +43,8 @@ internal class AdaptingTests {
|
||||
|
||||
@Test
|
||||
fun simpleFunctionDerivative() {
|
||||
val x = MstAlgebra.symbol("x").svar(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().sfun(proto)
|
||||
val x = MstAlgebra.symbol("x").sVar(proto)
|
||||
val quadratic = "x^2-4*x-44".parseMath().sFun(proto)
|
||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile()
|
||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||
@ -52,8 +52,8 @@ internal class AdaptingTests {
|
||||
|
||||
@Test
|
||||
fun moreComplexDerivative() {
|
||||
val x = MstAlgebra.symbol("x").svar(proto)
|
||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sfun(proto)
|
||||
val x = MstAlgebra.symbol("x").sVar(proto)
|
||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sFun(proto)
|
||||
val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile()
|
||||
|
||||
val expectedDerivative = MstExpression(
|
||||
|
@ -0,0 +1,22 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
import kotlin.reflect.KProperty
|
||||
|
||||
/**
|
||||
* Stores `provideDelegate` method returning property of [MST.Symbolic].
|
||||
*/
|
||||
public object MstSymbolDelegateProvider {
|
||||
/**
|
||||
* Returns [ReadOnlyProperty] of [MST.Symbolic] with its value equal to the name of the property.
|
||||
*/
|
||||
public operator fun provideDelegate(thisRef: Any?, prop: KProperty<*>): ReadOnlyProperty<Any?, MST.Symbolic> =
|
||||
ReadOnlyProperty { _, property -> MST.Symbolic(property.name) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [MstSymbolDelegateProvider].
|
||||
*/
|
||||
public val Algebra<MST>.symbol: MstSymbolDelegateProvider
|
||||
get() = MstSymbolDelegateProvider
|
Loading…
Reference in New Issue
Block a user