forked from kscience/kmath
Rename KG module
This commit is contained in:
parent
fcfd79cb69
commit
381137724d
@ -20,7 +20,7 @@ sourceSets.register("benchmarks")
|
|||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation(project(":kmath-ast"))
|
implementation(project(":kmath-ast"))
|
||||||
implementation(project(":kmath-ast-kotlingrad"))
|
implementation(project(":kmath-kotlingrad"))
|
||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
|
@ -2,9 +2,9 @@ package kscience.kmath.ast
|
|||||||
|
|
||||||
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
||||||
import kscience.kmath.asm.compile
|
import kscience.kmath.asm.compile
|
||||||
import kscience.kmath.ast.kotlingrad.mst
|
import kscience.kmath.kotlingrad.toMst
|
||||||
import kscience.kmath.ast.kotlingrad.sFun
|
import kscience.kmath.kotlingrad.tSFun
|
||||||
import kscience.kmath.ast.kotlingrad.sVar
|
import kscience.kmath.kotlingrad.toSVar
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
@ -14,9 +14,9 @@ import kscience.kmath.operations.RealField
|
|||||||
*/
|
*/
|
||||||
fun main() {
|
fun main() {
|
||||||
val proto = DoublePrecision.prototype
|
val proto = DoublePrecision.prototype
|
||||||
val x by MstAlgebra.symbol("x").sVar(proto)
|
val x by MstAlgebra.symbol("x").toSVar(proto)
|
||||||
val quadratic = "x^2-4*x-44".parseMath().sFun(proto)
|
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
|
||||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile()
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.ast.kotlingrad
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
import edu.umontreal.kotlingrad.experimental.*
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
@ -30,22 +30,22 @@ import kscience.kmath.operations.*
|
|||||||
* @receiver the scalar function.
|
* @receiver the scalar function.
|
||||||
* @return a node.
|
* @return a node.
|
||||||
*/
|
*/
|
||||||
public fun <X : SFun<X>> SFun<X>.mst(): MST = MstExtendedField {
|
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
|
||||||
when (this@mst) {
|
when (this@toMst) {
|
||||||
is SVar -> symbol(name)
|
is SVar -> symbol(name)
|
||||||
is SConst -> number(doubleValue)
|
is SConst -> number(doubleValue)
|
||||||
is Sum -> left.mst() + right.mst()
|
is Sum -> left.toMst() + right.toMst()
|
||||||
is Prod -> left.mst() * right.mst()
|
is Prod -> left.toMst() * right.toMst()
|
||||||
is Power -> power(left.mst(), (right as SConst<*>).doubleValue)
|
is Power -> power(left.toMst(), (right as SConst<*>).doubleValue)
|
||||||
is Negative -> -input.mst()
|
is Negative -> -input.toMst()
|
||||||
is Log -> ln(left.mst()) / ln(right.mst())
|
is Log -> ln(left.toMst()) / ln(right.toMst())
|
||||||
is Sine -> sin(input.mst())
|
is Sine -> sin(input.toMst())
|
||||||
is Cosine -> cos(input.mst())
|
is Cosine -> cos(input.toMst())
|
||||||
is Tangent -> tan(input.mst())
|
is Tangent -> tan(input.toMst())
|
||||||
is DProd -> this@mst().mst()
|
is DProd -> this@toMst().toMst()
|
||||||
is SComposition -> this@mst().mst()
|
is SComposition -> this@toMst().toMst()
|
||||||
is VSumAll<X, *> -> this@mst().mst()
|
is VSumAll<X, *> -> this@toMst().toMst()
|
||||||
is Derivative -> this@mst().mst()
|
is Derivative -> this@toMst().toMst()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ public fun <X : SFun<X>> SFun<X>.mst(): MST = MstExtendedField {
|
|||||||
* @receiver the node.
|
* @receiver the node.
|
||||||
* @return a new constant.
|
* @return a new constant.
|
||||||
*/
|
*/
|
||||||
public fun <X : SFun<X>> MST.Numeric.sConst(): SConst<X> = SConst(value)
|
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps [MST.Symbolic] to [SVar] directly.
|
* Maps [MST.Symbolic] to [SVar] directly.
|
||||||
@ -64,7 +64,7 @@ public fun <X : SFun<X>> MST.Numeric.sConst(): SConst<X> = SConst(value)
|
|||||||
* @param proto the prototype instance.
|
* @param proto the prototype instance.
|
||||||
* @return a new variable.
|
* @return a new variable.
|
||||||
*/
|
*/
|
||||||
public fun <X : SFun<X>> MST.Symbolic.sVar(proto: X): SVar<X> = SVar(proto, value)
|
public fun <X : SFun<X>> MST.Symbolic.toSVar(proto: X): SVar<X> = SVar(proto, value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||||
@ -80,28 +80,28 @@ public fun <X : SFun<X>> MST.Symbolic.sVar(proto: X): SVar<X> = SVar(proto, valu
|
|||||||
* @param proto the prototype instance.
|
* @param proto the prototype instance.
|
||||||
* @return a scalar function.
|
* @return a scalar function.
|
||||||
*/
|
*/
|
||||||
public fun <X : SFun<X>> MST.sFun(proto: X): SFun<X> = when (this) {
|
public fun <X : SFun<X>> MST.tSFun(proto: X): SFun<X> = when (this) {
|
||||||
is MST.Numeric -> sConst()
|
is MST.Numeric -> toSConst()
|
||||||
is MST.Symbolic -> sVar(proto)
|
is MST.Symbolic -> toSVar(proto)
|
||||||
|
|
||||||
is MST.Unary -> when (operation) {
|
is MST.Unary -> when (operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> value.sFun(proto)
|
SpaceOperations.PLUS_OPERATION -> value.tSFun(proto)
|
||||||
SpaceOperations.MINUS_OPERATION -> Negative(value.sFun(proto))
|
SpaceOperations.MINUS_OPERATION -> Negative(value.tSFun(proto))
|
||||||
TrigonometricOperations.SIN_OPERATION -> Sine(value.sFun(proto))
|
TrigonometricOperations.SIN_OPERATION -> Sine(value.tSFun(proto))
|
||||||
TrigonometricOperations.COS_OPERATION -> Cosine(value.sFun(proto))
|
TrigonometricOperations.COS_OPERATION -> Cosine(value.tSFun(proto))
|
||||||
TrigonometricOperations.TAN_OPERATION -> Tangent(value.sFun(proto))
|
TrigonometricOperations.TAN_OPERATION -> Tangent(value.tSFun(proto))
|
||||||
PowerOperations.SQRT_OPERATION -> Power(value.sFun(proto), SConst(0.5))
|
PowerOperations.SQRT_OPERATION -> Power(value.tSFun(proto), SConst(0.5))
|
||||||
ExponentialOperations.EXP_OPERATION -> Power(value.sFun(proto), E())
|
ExponentialOperations.EXP_OPERATION -> Power(value.tSFun(proto), E())
|
||||||
ExponentialOperations.LN_OPERATION -> Log(value.sFun(proto))
|
ExponentialOperations.LN_OPERATION -> Log(value.tSFun(proto))
|
||||||
else -> error("Unary operation $operation not defined in $this")
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Binary -> when (operation) {
|
is MST.Binary -> when (operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> Sum(left.sFun(proto), right.sFun(proto))
|
SpaceOperations.PLUS_OPERATION -> Sum(left.tSFun(proto), right.tSFun(proto))
|
||||||
SpaceOperations.MINUS_OPERATION -> Sum(left.sFun(proto), Negative(right.sFun(proto)))
|
SpaceOperations.MINUS_OPERATION -> Sum(left.tSFun(proto), Negative(right.tSFun(proto)))
|
||||||
RingOperations.TIMES_OPERATION -> Prod(left.sFun(proto), right.sFun(proto))
|
RingOperations.TIMES_OPERATION -> Prod(left.tSFun(proto), right.tSFun(proto))
|
||||||
FieldOperations.DIV_OPERATION -> Prod(left.sFun(proto), Power(right.sFun(proto), Negative(One())))
|
FieldOperations.DIV_OPERATION -> Prod(left.tSFun(proto), Power(right.tSFun(proto), Negative(One())))
|
||||||
PowerOperations.POW_OPERATION -> Power(left.sFun(proto), SConst((right as MST.Numeric).value))
|
PowerOperations.POW_OPERATION -> Power(left.tSFun(proto), SConst((right as MST.Numeric).value))
|
||||||
else -> error("Binary operation $operation not defined in $this")
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.ast.kotlingrad
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
import edu.umontreal.kotlingrad.experimental.*
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
import kscience.kmath.asm.compile
|
import kscience.kmath.asm.compile
|
||||||
@ -18,24 +18,24 @@ internal class AdaptingTests {
|
|||||||
@Test
|
@Test
|
||||||
fun symbol() {
|
fun symbol() {
|
||||||
val c1 = MstAlgebra.symbol("x")
|
val c1 = MstAlgebra.symbol("x")
|
||||||
assertTrue(c1.sVar(proto).name == "x")
|
assertTrue(c1.toSVar(proto).name == "x")
|
||||||
val c2 = "kitten".parseMath().sFun(proto)
|
val c2 = "kitten".parseMath().tSFun(proto)
|
||||||
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun number() {
|
fun number() {
|
||||||
val c1 = MstAlgebra.number(12354324)
|
val c1 = MstAlgebra.number(12354324)
|
||||||
assertTrue(c1.sConst<DReal>().doubleValue == 12354324.0)
|
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
|
||||||
val c2 = "0.234".parseMath().sFun(proto)
|
val c2 = "0.234".parseMath().tSFun(proto)
|
||||||
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||||
val c3 = "1e-3".parseMath().sFun(proto)
|
val c3 = "1e-3".parseMath().tSFun(proto)
|
||||||
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun simpleFunctionShape() {
|
fun simpleFunctionShape() {
|
||||||
val linear = "2*x+16".parseMath().sFun(proto)
|
val linear = "2*x+16".parseMath().tSFun(proto)
|
||||||
if (linear !is Sum) fail()
|
if (linear !is Sum) fail()
|
||||||
if (linear.left !is Prod) fail()
|
if (linear.left !is Prod) fail()
|
||||||
if (linear.right !is SConst) fail()
|
if (linear.right !is SConst) fail()
|
||||||
@ -43,18 +43,18 @@ internal class AdaptingTests {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun simpleFunctionDerivative() {
|
fun simpleFunctionDerivative() {
|
||||||
val x = MstAlgebra.symbol("x").sVar(proto)
|
val x = MstAlgebra.symbol("x").toSVar(proto)
|
||||||
val quadratic = "x^2-4*x-44".parseMath().sFun(proto)
|
val quadratic = "x^2-4*x-44".parseMath().tSFun(proto)
|
||||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile()
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun moreComplexDerivative() {
|
fun moreComplexDerivative() {
|
||||||
val x = MstAlgebra.symbol("x").sVar(proto)
|
val x = MstAlgebra.symbol("x").toSVar(proto)
|
||||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().sFun(proto)
|
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().tSFun(proto)
|
||||||
val actualDerivative = MstExpression(RealField, composition.d(x).mst()).compile()
|
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
|
||||||
|
|
||||||
val expectedDerivative = MstExpression(
|
val expectedDerivative = MstExpression(
|
||||||
RealField,
|
RealField,
|
@ -40,5 +40,5 @@ include(
|
|||||||
":kmath-ast",
|
":kmath-ast",
|
||||||
":examples",
|
":examples",
|
||||||
":kmath-ejml",
|
":kmath-ejml",
|
||||||
":kmath-ast-kotlingrad"
|
":kmath-kotlingrad"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user