forked from kscience/kmath
Add test, update MstAlgebra a bit to return concrete types
This commit is contained in:
parent
31c71e0fad
commit
57bdee4936
@ -36,7 +36,7 @@ public fun <X : SFun<X>> SFun<X>.mst(): MST = MstExtendedField {
|
|||||||
is SConst -> number(doubleValue)
|
is SConst -> number(doubleValue)
|
||||||
is Sum -> left.mst() + right.mst()
|
is Sum -> left.mst() + right.mst()
|
||||||
is Prod -> left.mst() * right.mst()
|
is Prod -> left.mst() * right.mst()
|
||||||
is Power -> power(left.mst(), (right() as SConst<*>).doubleValue)
|
is Power -> power(left.mst(), (right as SConst<*>).doubleValue)
|
||||||
is Negative -> -input.mst()
|
is Negative -> -input.mst()
|
||||||
is Log -> ln(left.mst()) / ln(right.mst())
|
is Log -> ln(left.mst()) / ln(right.mst())
|
||||||
is Sine -> sin(input.mst())
|
is Sine -> sin(input.mst())
|
||||||
|
@ -0,0 +1,66 @@
|
|||||||
|
package kscience.kmath.ast.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
|
import kscience.kmath.asm.compile
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExpression
|
||||||
|
import kscience.kmath.ast.parseMath
|
||||||
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
internal class AdaptingTests {
|
||||||
|
private val proto: DReal = DoublePrecision.prototype
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun symbol() {
|
||||||
|
val c1 = MstAlgebra.symbol("x")
|
||||||
|
assertTrue(c1.svar(proto).name == "x")
|
||||||
|
val c2 = "kitten".parseMath().sfun(proto)
|
||||||
|
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun number() {
|
||||||
|
val c1 = MstAlgebra.number(12354324)
|
||||||
|
assertTrue(c1.sconst<DReal>().doubleValue == 12354324.0)
|
||||||
|
val c2 = "0.234".parseMath().sfun(proto)
|
||||||
|
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||||
|
val c3 = "1e-3".parseMath().sfun(proto)
|
||||||
|
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun simpleFunctionShape() {
|
||||||
|
val linear = "2*x+16".parseMath().sfun(proto)
|
||||||
|
if (linear !is Sum) fail()
|
||||||
|
if (linear.left !is Prod) fail()
|
||||||
|
if (linear.right !is SConst) fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun simpleFunctionDerivative() {
|
||||||
|
val x = MstAlgebra.symbol("x").svar(proto)
|
||||||
|
val quadratic = "x^2-4*x-44".parseMath().sfun(proto)
|
||||||
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile()
|
||||||
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
|
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun moreComplexDerivative() {
|
||||||
|
val x = MstAlgebra.symbol("x").svar(proto)
|
||||||
|
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sfun(proto)
|
||||||
|
val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile()
|
||||||
|
|
||||||
|
val expectedDerivative = MstExpression(
|
||||||
|
RealField,
|
||||||
|
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
|
||||||
|
).compile()
|
||||||
|
|
||||||
|
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
|
||||||
|
}
|
||||||
|
}
|
@ -6,14 +6,14 @@ import kscience.kmath.operations.*
|
|||||||
* [Algebra] over [MST] nodes.
|
* [Algebra] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstAlgebra : NumericAlgebra<MST> {
|
public object MstAlgebra : NumericAlgebra<MST> {
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST =
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
||||||
MST.Unary(operation, arg)
|
MST.Unary(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MST.Binary(operation, left, right)
|
MST.Binary(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra<MST> {
|
|||||||
* [Space] over [MST] nodes.
|
* [Space] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST = number(0.0)
|
override val zero: MST.Numeric by lazy { number(0.0) }
|
||||||
|
|
||||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstAlgebra.binaryOperation(operation, left, right)
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Ring] over [MST] nodes.
|
* [Ring] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST
|
override val zero: MST.Numeric
|
||||||
get() = MstSpace.zero
|
get() = MstSpace.zero
|
||||||
override val one: MST = number(1.0)
|
|
||||||
|
|
||||||
override fun number(value: Number): MST = MstSpace.number(value)
|
override val one: MST.Numeric by lazy { number(1.0) }
|
||||||
override fun symbol(value: String): MST = MstSpace.symbol(value)
|
|
||||||
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
|
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||||
|
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||||
|
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||||
|
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||||
|
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
|
||||||
MstSpace.binaryOperation(operation, left, right)
|
MstSpace.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Field] over [MST] nodes.
|
* [Field] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstField : Field<MST> {
|
public object MstField : Field<MST> {
|
||||||
public override val zero: MST
|
public override val zero: MST.Numeric
|
||||||
get() = MstRing.zero
|
get() = MstRing.zero
|
||||||
|
|
||||||
public override val one: MST
|
public override val one: MST.Numeric
|
||||||
get() = MstRing.one
|
get() = MstRing.one
|
||||||
|
|
||||||
public override fun symbol(value: String): MST = MstRing.symbol(value)
|
public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
|
||||||
public override fun number(value: Number): MST = MstRing.number(value)
|
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||||
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||||
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||||
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||||
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstRing.binaryOperation(operation, left, right)
|
MstRing.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [ExtendedField] over [MST] nodes.
|
* [ExtendedField] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstExtendedField : ExtendedField<MST> {
|
public object MstExtendedField : ExtendedField<MST> {
|
||||||
override val zero: MST
|
override val zero: MST.Numeric
|
||||||
get() = MstField.zero
|
get() = MstField.zero
|
||||||
|
|
||||||
override val one: MST
|
override val one: MST.Numeric
|
||||||
get() = MstField.one
|
get() = MstField.one
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MstField.symbol(value)
|
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||||
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
override fun number(value: Number): MST.Numeric = MstField.number(value)
|
||||||
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||||
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||||
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||||
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||||
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||||
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||||
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||||
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||||
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||||
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||||
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||||
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||||
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||||
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||||
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
|
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||||
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||||
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
|
||||||
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||||
|
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||||
|
|
||||||
|
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||||
|
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstField.binaryOperation(operation, left, right)
|
MstField.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user