Refactor MST

This commit is contained in:
Alexander Nozik 2021-04-01 20:15:49 +03:00
parent c2bab5d138
commit af4866e876
18 changed files with 241 additions and 365 deletions

View File

@ -5,6 +5,7 @@
- ScaleOperations interface - ScaleOperations interface
- Field extends ScaleOperations - Field extends ScaleOperations
- Basic integration API - Basic integration API
- Basic MPP distributions and samplers
### Changed ### Changed
- Exponential operations merged with hyperbolic functions - Exponential operations merged with hyperbolic functions
@ -14,6 +15,8 @@
- NDStructure and NDAlgebra to StructureND and AlgebraND respectively - NDStructure and NDAlgebra to StructureND and AlgebraND respectively
- Real -> Double - Real -> Double
- DataSets are moved from functions to core - DataSets are moved from functions to core
- Redesign advanced Chain API
- Redesign MST. Remove MSTExpression.
### Deprecated ### Deprecated
@ -21,6 +24,7 @@
- Nearest in Domain. To be implemented in geometry package. - Nearest in Domain. To be implemented in geometry package.
- Number multiplication and division in main Algebra chain - Number multiplication and division in main Algebra chain
- `contentEquals` from Buffer. It moved to the companion. - `contentEquals` from Buffer. It moved to the companion.
- MSTExpression
### Fixed ### Fixed

View File

@ -1,15 +1,17 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
fun main() { fun main() {
val expr = DoubleField.mstInField { val expr = MstField {
val x = bindSymbol("x") val x = bindSymbol(x)
x * 2.0 + number(2.0) / x - 16.0 x * 2.0 + number(2.0) / x - 16.0
} }
repeat(10000000) { repeat(10000000) {
expr.invoke("x" to 1.0) expr.interpret(DoubleField, x to 1.0)
} }
} }

View File

@ -1,9 +1,9 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.asm.compile import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.derivative import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.kotlingrad.differentiable import space.kscience.kmath.kotlingrad.toDiffExpression
import space.kscience.kmath.misc.symbol import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
@ -14,11 +14,11 @@ import space.kscience.kmath.operations.DoubleField
fun main() { fun main() {
val x by symbol val x by symbol
val actualDerivative = MstExpression(DoubleField, "x^2-4*x-44".parseMath()) val actualDerivative = "x^2-4*x-44".parseMath()
.differentiable() .toDiffExpression(DoubleField)
.derivative(x) .derivative(x)
.compile()
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
} }

View File

@ -1,5 +1,8 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.NumericAlgebra
@ -76,11 +79,51 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
} }
} }
internal class InnerAlgebra<T : Any>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbol(value: String): T = try {
algebra.bindSymbol(value)
} catch (ignored: IllegalStateException) {
null
} ?: arguments.getValue(StringSymbol(value))
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
/** /**
* Interprets the [MST] node with this [Algebra]. * Interprets the [MST] node with this [Algebra] and optional [arguments]
*/
public fun <T : Any> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
InnerAlgebra(algebra, arguments).evaluate(this)
/**
* Interprets the [MST] node with this [Algebra] and optional [arguments]
* *
* @receiver the node to evaluate. * @receiver the node to evaluate.
* @param algebra the algebra that provides operations. * @param algebra the algebra that provides operations.
* @return the value of expression. * @return the value of expression.
*/ */
public fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this) public fun <T : Any> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
interpret(algebra, mapOf(*arguments))
/**
* Interpret this [MST] as expression.
*/
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
interpret(algebra, arguments)
}

View File

@ -1,138 +0,0 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
* ASM-generated expressions.
*
* @property algebra the algebra that provides operations.
* @property mst the [MST] node.
* @author Alexander Nozik
*/
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbol(value: String): T = try {
algebra.bindSymbol(value)
} catch (ignored: IllegalStateException) {
null
} ?: arguments.getValue(StringSymbol(value))
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
}
/**
* Builds [MstExpression] over [Algebra].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E,
block: E.() -> MST,
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
/**
* Builds [MstExpression] over [Group].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Group<T>> A.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstGroup.block())
}
/**
* Builds [MstExpression] over [Ring].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block())
}
/**
* Builds [MstExpression] over [Field].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block())
}
/**
* Builds [MstExpression] over [ExtendedField].
*
* @author Iaroslav Postovalov
*/
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block())
}
/**
* Builds [MstExpression] over [FunctionalExpressionGroup].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Group<T>> FunctionalExpressionGroup<T, A>.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInGroup(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionRing].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionField].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
*
* @author Iaroslav Postovalov
*/
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST,
): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block)
}

View File

@ -2,10 +2,11 @@ package space.kscience.kmath.estree
import space.kscience.kmath.ast.MST import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MST.* import space.kscience.kmath.ast.MST.*
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.estree.internal.ESTreeBuilder import space.kscience.kmath.estree.internal.ESTreeBuilder
import space.kscience.kmath.estree.internal.estree.BaseExpression import space.kscience.kmath.estree.internal.estree.BaseExpression
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.NumericAlgebra
@ -64,19 +65,21 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
return ESTreeBuilder<T> { visit(this@compileWith) }.instance return ESTreeBuilder<T> { visit(this@compileWith) }.instance
} }
/**
* Create a compiled expression with given [MST] and given [algebra].
*/
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra)
/** /**
* Compiles an [MST] to ESTree generated expression using given algebra. * Compile given MST to expression and evaluate it against [arguments]
*
* @author Iaroslav Postovalov
*/ */
public fun <T : Any> Algebra<T>.expression(mst: MST): Expression<T> = public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
mst.compileWith(this) compileToExpression(algebra).invoke(arguments)
/** /**
* Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression. * Compile given MST to expression and evaluate it against [arguments]
*
* @author Iaroslav Postovalov
*/ */
public fun <T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> = public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
mst.compileWith(algebra) compileToExpression(algebra).invoke(*arguments)

View File

@ -3,16 +3,19 @@ package space.kscience.kmath.estree
import space.kscience.kmath.ast.* import space.kscience.kmath.ast.*
import space.kscience.kmath.complex.ComplexField import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.ByteRing import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestESTreeConsistencyWithInterpreter { internal class TestESTreeConsistencyWithInterpreter {
@Test @Test
fun mstSpace() { fun mstSpace() {
val res1 = MstGroup.mstInGroup {
val mst = MstGroup {
binaryOperationFunction("+")( binaryOperationFunction("+")(
unaryOperationFunction("+")( unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale( number(3.toByte()) - (number(2.toByte()) + (scale(
@ -23,27 +26,17 @@ internal class TestESTreeConsistencyWithInterpreter {
number(1) number(1)
) + bindSymbol("x") + zero ) + bindSymbol("x") + zero
}("x" to MST.Numeric(2)) }
val res2 = MstGroup.mstInGroup { assertEquals(
binaryOperationFunction("+")( mst.interpret(MstGroup, Symbol.x to MST.Numeric(2)),
unaryOperationFunction("+")( mst.compile(MstGroup, Symbol.x to MST.Numeric(2))
number(3.toByte()) - (number(2.toByte()) + (scale( )
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol("x") + zero
}.compile()("x" to MST.Numeric(2))
assertEquals(res1, res2)
} }
@Test @Test
fun byteRing() { fun byteRing() {
val res1 = ByteRing.mstInRing { val mst = MstRing {
binaryOperationFunction("+")( binaryOperationFunction("+")(
unaryOperationFunction("+")( unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale( (bindSymbol("x") - (2.toByte() + (scale(
@ -54,62 +47,43 @@ internal class TestESTreeConsistencyWithInterpreter {
number(1) number(1)
) * number(2) ) * number(2)
}("x" to 3.toByte()) }
val res2 = ByteRing.mstInRing { assertEquals(
binaryOperationFunction("+")( mst.interpret(ByteRing, Symbol.x to 3.toByte()),
unaryOperationFunction("+")( mst.compile(ByteRing, Symbol.x to 3.toByte())
(bindSymbol("x") - (2.toByte() + (scale( )
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
} }
@Test @Test
fun realField() { fun realField() {
val res1 = DoubleField.mstInField { val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one
) + zero ) + zero
}("x" to 2.0) }
val res2 = DoubleField.mstInField { assertEquals(
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( mst.interpret(DoubleField, Symbol.x to 2.0),
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 mst.compile(DoubleField, Symbol.x to 2.0)
+ number(1), )
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0)
assertEquals(res1, res2)
} }
@Test @Test
fun complexField() { fun complexField() {
val res1 = ComplexField.mstInField { val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one
) + zero ) + zero
}("x" to 2.0.toComplex()) }
val res2 = ComplexField.mstInField { assertEquals(
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( mst.interpret(ComplexField, Symbol.x to 2.0.toComplex()),
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 mst.compile(ComplexField, Symbol.x to 2.0.toComplex())
+ number(1), )
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0.toComplex())
assertEquals(res1, res2)
} }
} }

View File

@ -1,10 +1,9 @@
package space.kscience.kmath.estree package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInExtendedField import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.mstInGroup
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.random.Random import kotlin.random.Random
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -12,29 +11,29 @@ import kotlin.test.assertEquals
internal class TestESTreeOperationsSupport { internal class TestESTreeOperationsSupport {
@Test @Test
fun testUnaryOperationInvocation() { fun testUnaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile() val expression = MstExtendedField { -bindSymbol("x") }.compileToExpression(DoubleField)
val res = expression("x" to 2.0) val res = expression("x" to 2.0)
assertEquals(-2.0, res) assertEquals(-2.0, res)
} }
@Test @Test
fun testBinaryOperationInvocation() { fun testBinaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile() val expression = MstExtendedField { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
val res = expression("x" to 2.0) val res = expression("x" to 2.0)
assertEquals(-1.0, res) assertEquals(-1.0, res)
} }
@Test @Test
fun testConstProductInvocation() { fun testConstProductInvocation() {
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0) val res = MstExtendedField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
assertEquals(4.0, res) assertEquals(4.0, res)
} }
@Test @Test
fun testMultipleCalls() { fun testMultipleCalls() {
val e = val e =
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) } MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compile() .compileToExpression(DoubleField)
val r = Random(0) val r = Random(0)
var s = 0.0 var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) } repeat(1000000) { s += e("x" to r.nextDouble()) }

View File

@ -1,53 +1,63 @@
package space.kscience.kmath.estree package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInField import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestESTreeSpecialization { internal class TestESTreeSpecialization {
@Test @Test
fun testUnaryPlus() { fun testUnaryPlus() {
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile() val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(2.0, expr("x" to 2.0)) assertEquals(2.0, expr("x" to 2.0))
} }
@Test @Test
fun testUnaryMinus() { fun testUnaryMinus() {
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile() val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr("x" to 2.0)) assertEquals(-2.0, expr("x" to 2.0))
} }
@Test @Test
fun testAdd() { fun testAdd() {
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile() val expr = MstExtendedField {
binaryOperationFunction("+")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))
} }
@Test @Test
fun testSine() { fun testSine() {
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile() val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 0.0)) assertEquals(0.0, expr("x" to 0.0))
} }
@Test @Test
fun testMinus() { fun testMinus() {
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile() val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 2.0)) assertEquals(0.0, expr("x" to 2.0))
} }
@Test @Test
fun testDivide() { fun testDivide() {
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile() val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr("x" to 2.0)) assertEquals(1.0, expr("x" to 2.0))
} }
@Test @Test
fun testPower() { fun testPower() {
val expr = DoubleField val expr = MstExtendedField {
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) } binaryOperationFunction("pow")(bindSymbol("x"), number(2))
.compile() }.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))
} }

View File

@ -1,8 +1,9 @@
package space.kscience.kmath.estree package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInRing import space.kscience.kmath.ast.MstRing
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.ByteRing import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
internal class TestESTreeVariables { internal class TestESTreeVariables {
@Test @Test
fun testVariable() { fun testVariable() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() val expr = MstRing{ bindSymbol("x") }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr("x" to 1.toByte())) assertEquals(1.toByte(), expr("x" to 1.toByte()))
} }
@Test @Test
fun testUndefinedVariableFails() { fun testUndefinedVariableFails() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() } assertFailsWith<NoSuchElementException> { expr() }
} }
} }

View File

@ -4,8 +4,9 @@ import space.kscience.kmath.asm.internal.AsmBuilder
import space.kscience.kmath.asm.internal.buildName import space.kscience.kmath.asm.internal.buildName
import space.kscience.kmath.ast.MST import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MST.* import space.kscience.kmath.ast.MST.*
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.NumericAlgebra
@ -70,18 +71,22 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
} }
/**
* Compiles an [MST] to ASM using given algebra.
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
mst.compileWith(T::class.java, this)
/** /**
* Optimizes performance of an [MstExpression] using ASM codegen. * Create a compiled expression with given [MST] and given [algebra].
*
* @author Alexander Nozik
*/ */
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> = public inline fun <reified T: Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
mst.compileWith(T::class.java, algebra) compileWith(T::class.java, algebra)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
compileToExpression(algebra).invoke(*arguments)

View File

@ -3,16 +3,19 @@ package space.kscience.kmath.asm
import space.kscience.kmath.ast.* import space.kscience.kmath.ast.*
import space.kscience.kmath.complex.ComplexField import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.ByteRing import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestAsmConsistencyWithInterpreter { internal class TestAsmConsistencyWithInterpreter {
@Test @Test
fun mstSpace() { fun mstSpace() {
val res1 = MstGroup.mstInGroup {
val mst = MstGroup {
binaryOperationFunction("+")( binaryOperationFunction("+")(
unaryOperationFunction("+")( unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale( number(3.toByte()) - (number(2.toByte()) + (scale(
@ -23,27 +26,17 @@ internal class TestAsmConsistencyWithInterpreter {
number(1) number(1)
) + bindSymbol("x") + zero ) + bindSymbol("x") + zero
}("x" to MST.Numeric(2)) }
val res2 = MstGroup.mstInGroup { assertEquals(
binaryOperationFunction("+")( mst.interpret(MstGroup, x to MST.Numeric(2)),
unaryOperationFunction("+")( mst.compile(MstGroup, x to MST.Numeric(2))
number(3.toByte()) - (number(2.toByte()) + (scale( )
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol("x") + zero
}.compile()("x" to MST.Numeric(2))
assertEquals(res1, res2)
} }
@Test @Test
fun byteRing() { fun byteRing() {
val res1 = ByteRing.mstInRing { val mst = MstRing {
binaryOperationFunction("+")( binaryOperationFunction("+")(
unaryOperationFunction("+")( unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale( (bindSymbol("x") - (2.toByte() + (scale(
@ -54,62 +47,43 @@ internal class TestAsmConsistencyWithInterpreter {
number(1) number(1)
) * number(2) ) * number(2)
}("x" to 3.toByte()) }
val res2 = ByteRing.mstInRing { assertEquals(
binaryOperationFunction("+")( mst.interpret(ByteRing, x to 3.toByte()),
unaryOperationFunction("+")( mst.compile(ByteRing, x to 3.toByte())
(bindSymbol("x") - (2.toByte() + (scale( )
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
} }
@Test @Test
fun realField() { fun realField() {
val res1 = DoubleField.mstInField { val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one
) + zero ) + zero
}("x" to 2.0) }
val res2 = DoubleField.mstInField { assertEquals(
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( mst.interpret(DoubleField, x to 2.0),
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 mst.compile(DoubleField, x to 2.0)
+ number(1), )
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0)
assertEquals(res1, res2)
} }
@Test @Test
fun complexField() { fun complexField() {
val res1 = ComplexField.mstInField { val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one
) + zero ) + zero
}("x" to 2.0.toComplex()) }
val res2 = ComplexField.mstInField { assertEquals(
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( mst.interpret(ComplexField, x to 2.0.toComplex()),
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 mst.compile(ComplexField, x to 2.0.toComplex())
+ number(1), )
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0.toComplex())
assertEquals(res1, res2)
} }
} }

View File

@ -1,10 +1,11 @@
package space.kscience.kmath.asm package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInExtendedField import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.ast.mstInField import space.kscience.kmath.ast.MstField
import space.kscience.kmath.ast.mstInGroup import space.kscience.kmath.ast.MstGroup
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.random.Random import kotlin.random.Random
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -12,29 +13,29 @@ import kotlin.test.assertEquals
internal class TestAsmOperationsSupport { internal class TestAsmOperationsSupport {
@Test @Test
fun testUnaryOperationInvocation() { fun testUnaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile() val expression = MstGroup { -bindSymbol("x") }.compileToExpression(DoubleField)
val res = expression("x" to 2.0) val res = expression("x" to 2.0)
assertEquals(-2.0, res) assertEquals(-2.0, res)
} }
@Test @Test
fun testBinaryOperationInvocation() { fun testBinaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile() val expression = MstGroup { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
val res = expression("x" to 2.0) val res = expression("x" to 2.0)
assertEquals(-1.0, res) assertEquals(-1.0, res)
} }
@Test @Test
fun testConstProductInvocation() { fun testConstProductInvocation() {
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0) val res = MstField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
assertEquals(4.0, res) assertEquals(4.0, res)
} }
@Test @Test
fun testMultipleCalls() { fun testMultipleCalls() {
val e = val e =
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) } MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compile() .compileToExpression(DoubleField)
val r = Random(0) val r = Random(0)
var s = 0.0 var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) } repeat(1000000) { s += e("x" to r.nextDouble()) }

View File

@ -1,53 +1,63 @@
package space.kscience.kmath.asm package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInField import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestAsmSpecialization { internal class TestAsmSpecialization {
@Test @Test
fun testUnaryPlus() { fun testUnaryPlus() {
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile() val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(2.0, expr("x" to 2.0)) assertEquals(2.0, expr("x" to 2.0))
} }
@Test @Test
fun testUnaryMinus() { fun testUnaryMinus() {
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile() val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr("x" to 2.0)) assertEquals(-2.0, expr("x" to 2.0))
} }
@Test @Test
fun testAdd() { fun testAdd() {
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile() val expr = MstExtendedField {
binaryOperationFunction("+")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))
} }
@Test @Test
fun testSine() { fun testSine() {
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile() val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 0.0)) assertEquals(0.0, expr("x" to 0.0))
} }
@Test @Test
fun testMinus() { fun testMinus() {
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile() val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 2.0)) assertEquals(0.0, expr("x" to 2.0))
} }
@Test @Test
fun testDivide() { fun testDivide() {
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile() val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr("x" to 2.0)) assertEquals(1.0, expr("x" to 2.0))
} }
@Test @Test
fun testPower() { fun testPower() {
val expr = DoubleField val expr = MstExtendedField {
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) } binaryOperationFunction("pow")(bindSymbol("x"), number(2))
.compile() }.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0)) assertEquals(4.0, expr("x" to 2.0))
} }

View File

@ -1,8 +1,9 @@
package space.kscience.kmath.asm package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInRing import space.kscience.kmath.ast.MstRing
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.ByteRing import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
internal class TestAsmVariables { internal class TestAsmVariables {
@Test @Test
fun testVariable() { fun testVariable() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr("x" to 1.toByte())) assertEquals(1.toByte(), expr("x" to 1.toByte()))
} }
@Test @Test
fun testUndefinedVariableFails() { fun testUndefinedVariableFails() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() } assertFailsWith<NoSuchElementException> { expr() }
} }
} }

View File

@ -2,9 +2,9 @@ package space.kscience.kmath.ast
import space.kscience.kmath.complex.Complex import space.kscience.kmath.complex.Complex
import space.kscience.kmath.complex.ComplexField import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -18,7 +18,7 @@ internal class ParserTest {
@Test @Test
fun `evaluate MSTExpression`() { fun `evaluate MSTExpression`() {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() val res = MstField.invoke { number(2) + number(2) * (number(2) + number(2)) }.interpret(ComplexField)
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }

View File

@ -3,8 +3,9 @@ package space.kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.api.SFun import edu.umontreal.kotlingrad.api.SFun
import space.kscience.kmath.ast.MST import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MstAlgebra import space.kscience.kmath.ast.MstAlgebra
import space.kscience.kmath.ast.MstExpression import space.kscience.kmath.ast.interpret
import space.kscience.kmath.expressions.DifferentiableExpression import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.misc.Symbol import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.NumericAlgebra
@ -18,38 +19,26 @@ import space.kscience.kmath.operations.NumericAlgebra
* @param A the [NumericAlgebra] of [T]. * @param A the [NumericAlgebra] of [T].
* @property expr the underlying [MstExpression]. * @property expr the underlying [MstExpression].
*/ */
public inline class DifferentiableMstExpression<T: Number, A>( public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
public val expr: MstExpression<T, A>, public val algebra: A,
) : DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T> { public val mst: MST,
) : DifferentiableExpression<T, Expression<T>> {
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst)) public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
/** public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
* The [MstExpression.algebra] of [expr]. DifferentiableMstExpression(
*/ algebra,
public val algebra: A symbols.map(Symbol::identity)
get() = expr.algebra .map(MstAlgebra::bindSymbol)
.map { it.toSVar<KMathNumber<T, A>>() }
/** .fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
* The [MstExpression.mst] of [expr]. .toMst(),
*/ )
public val mst: MST
get() = expr.mst
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::bindSymbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
} }
/** /**
* Wraps this [MstExpression] into [DifferentiableMstExpression]. * Wraps this [MST] into [DifferentiableMstExpression].
*/ */
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> = public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(this) DifferentiableMstExpression(algebra, this)

View File

@ -1,9 +1,8 @@
package space.kscience.kmath.kotlingrad package space.kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.api.* import edu.umontreal.kotlingrad.api.*
import space.kscience.kmath.asm.compile import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.ast.MstAlgebra import space.kscience.kmath.ast.MstAlgebra
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.ast.parseMath import space.kscience.kmath.ast.parseMath
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
@ -43,8 +42,8 @@ internal class AdaptingTests {
fun simpleFunctionDerivative() { fun simpleFunctionDerivative() {
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>() val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>() val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = MstExpression(DoubleField, quadratic.d(x).toMst()).compile() val actualDerivative = quadratic.d(x).toMst().compileToExpression(DoubleField)
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile() val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
} }
@ -52,12 +51,11 @@ internal class AdaptingTests {
fun moreComplexDerivative() { fun moreComplexDerivative() {
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>() val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>() val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = MstExpression(DoubleField, composition.d(x).toMst()).compile() val actualDerivative = composition.d(x).toMst().compileToExpression(DoubleField)
val expectedDerivative =
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath().compileToExpression(DoubleField)
val expectedDerivative = MstExpression(
DoubleField,
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
).compile()
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1)) assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
} }