forked from kscience/kmath
Refactor MST
This commit is contained in:
parent
c2bab5d138
commit
af4866e876
@ -5,6 +5,7 @@
|
||||
- ScaleOperations interface
|
||||
- Field extends ScaleOperations
|
||||
- Basic integration API
|
||||
- Basic MPP distributions and samplers
|
||||
|
||||
### Changed
|
||||
- Exponential operations merged with hyperbolic functions
|
||||
@ -14,6 +15,8 @@
|
||||
- NDStructure and NDAlgebra to StructureND and AlgebraND respectively
|
||||
- Real -> Double
|
||||
- DataSets are moved from functions to core
|
||||
- Redesign advanced Chain API
|
||||
- Redesign MST. Remove MSTExpression.
|
||||
|
||||
### Deprecated
|
||||
|
||||
@ -21,6 +24,7 @@
|
||||
- Nearest in Domain. To be implemented in geometry package.
|
||||
- Number multiplication and division in main Algebra chain
|
||||
- `contentEquals` from Buffer. It moved to the companion.
|
||||
- MSTExpression
|
||||
|
||||
### Fixed
|
||||
|
||||
|
@ -1,15 +1,17 @@
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.Symbol.Companion.x
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.invoke
|
||||
|
||||
fun main() {
|
||||
val expr = DoubleField.mstInField {
|
||||
val x = bindSymbol("x")
|
||||
val expr = MstField {
|
||||
val x = bindSymbol(x)
|
||||
x * 2.0 + number(2.0) / x - 16.0
|
||||
}
|
||||
|
||||
repeat(10000000) {
|
||||
expr.invoke("x" to 1.0)
|
||||
expr.interpret(DoubleField, x to 1.0)
|
||||
}
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.asm.compile
|
||||
import space.kscience.kmath.asm.compileToExpression
|
||||
import space.kscience.kmath.expressions.derivative
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.kotlingrad.differentiable
|
||||
import space.kscience.kmath.kotlingrad.toDiffExpression
|
||||
import space.kscience.kmath.misc.symbol
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
|
||||
@ -14,11 +14,11 @@ import space.kscience.kmath.operations.DoubleField
|
||||
fun main() {
|
||||
val x by symbol
|
||||
|
||||
val actualDerivative = MstExpression(DoubleField, "x^2-4*x-44".parseMath())
|
||||
.differentiable()
|
||||
val actualDerivative = "x^2-4*x-44".parseMath()
|
||||
.toDiffExpression(DoubleField)
|
||||
.derivative(x)
|
||||
.compile()
|
||||
|
||||
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
|
||||
|
||||
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
|
||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.misc.StringSymbol
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
@ -76,11 +79,51 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
}
|
||||
}
|
||||
|
||||
internal class InnerAlgebra<T : Any>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun bindSymbol(value: String): T = try {
|
||||
algebra.bindSymbol(value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
null
|
||||
} ?: arguments.getValue(StringSymbol(value))
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T =
|
||||
algebra.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
|
||||
algebra.unaryOperationFunction(operation)
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
|
||||
algebra.binaryOperationFunction(operation)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||
(algebra as NumericAlgebra<T>).number(value)
|
||||
else
|
||||
error("Numeric nodes are not supported by $this")
|
||||
}
|
||||
|
||||
/**
|
||||
* Interprets the [MST] node with this [Algebra].
|
||||
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||
*/
|
||||
public fun <T : Any> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||
InnerAlgebra(algebra, arguments).evaluate(this)
|
||||
|
||||
/**
|
||||
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||
*
|
||||
* @receiver the node to evaluate.
|
||||
* @param algebra the algebra that provides operations.
|
||||
* @return the value of expression.
|
||||
*/
|
||||
public fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)
|
||||
public fun <T : Any> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
||||
interpret(algebra, mapOf(*arguments))
|
||||
|
||||
/**
|
||||
* Interpret this [MST] as expression.
|
||||
*/
|
||||
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
|
||||
interpret(algebra, arguments)
|
||||
}
|
||||
|
@ -1,138 +0,0 @@
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.misc.StringSymbol
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.*
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
/**
|
||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
|
||||
* ASM-generated expressions.
|
||||
*
|
||||
* @property algebra the algebra that provides operations.
|
||||
* @property mst the [MST] node.
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun bindSymbol(value: String): T = try {
|
||||
algebra.bindSymbol(value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
null
|
||||
} ?: arguments.getValue(StringSymbol(value))
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T =
|
||||
algebra.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
|
||||
algebra.unaryOperationFunction(operation)
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
|
||||
algebra.binaryOperationFunction(operation)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||
(algebra as NumericAlgebra<T>).number(value)
|
||||
else
|
||||
error("Numeric nodes are not supported by $this")
|
||||
}
|
||||
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Algebra].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||
mstAlgebra: E,
|
||||
block: E.() -> MST,
|
||||
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Group].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Group<T>> A.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstGroup.block())
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Ring].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstRing.block())
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Field].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstField.block())
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [ExtendedField].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstExtendedField.block())
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionGroup].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Group<T>> FunctionalExpressionGroup<T, A>.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInGroup(block)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionRing].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInRing(block)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionField].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInField(block)
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||
block: MstExtendedField.() -> MST,
|
||||
): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInExtendedField(block)
|
||||
}
|
@ -2,10 +2,11 @@ package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.ast.MST
|
||||
import space.kscience.kmath.ast.MST.*
|
||||
import space.kscience.kmath.ast.MstExpression
|
||||
import space.kscience.kmath.estree.internal.ESTreeBuilder
|
||||
import space.kscience.kmath.estree.internal.estree.BaseExpression
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
@ -64,19 +65,21 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a compiled expression with given [MST] and given [algebra].
|
||||
*/
|
||||
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra)
|
||||
|
||||
|
||||
/**
|
||||
* Compiles an [MST] to ESTree generated expression using given algebra.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
* Compile given MST to expression and evaluate it against [arguments]
|
||||
*/
|
||||
public fun <T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
|
||||
mst.compileWith(this)
|
||||
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||
compileToExpression(algebra).invoke(arguments)
|
||||
|
||||
|
||||
/**
|
||||
* Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
* Compile given MST to expression and evaluate it against [arguments]
|
||||
*/
|
||||
public fun <T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
||||
mst.compileWith(algebra)
|
||||
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
|
||||
compileToExpression(algebra).invoke(*arguments)
|
||||
|
@ -3,16 +3,19 @@ package space.kscience.kmath.estree
|
||||
import space.kscience.kmath.ast.*
|
||||
import space.kscience.kmath.complex.ComplexField
|
||||
import space.kscience.kmath.complex.toComplex
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeConsistencyWithInterpreter {
|
||||
|
||||
@Test
|
||||
fun mstSpace() {
|
||||
val res1 = MstGroup.mstInGroup {
|
||||
|
||||
val mst = MstGroup {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||
@ -23,27 +26,17 @@ internal class TestESTreeConsistencyWithInterpreter {
|
||||
|
||||
number(1)
|
||||
) + bindSymbol("x") + zero
|
||||
}("x" to MST.Numeric(2))
|
||||
}
|
||||
|
||||
val res2 = MstGroup.mstInGroup {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + bindSymbol("x") + zero
|
||||
}.compile()("x" to MST.Numeric(2))
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(MstGroup, Symbol.x to MST.Numeric(2)),
|
||||
mst.compile(MstGroup, Symbol.x to MST.Numeric(2))
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun byteRing() {
|
||||
val res1 = ByteRing.mstInRing {
|
||||
val mst = MstRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(bindSymbol("x") - (2.toByte() + (scale(
|
||||
@ -54,62 +47,43 @@ internal class TestESTreeConsistencyWithInterpreter {
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}("x" to 3.toByte())
|
||||
}
|
||||
|
||||
val res2 = ByteRing.mstInRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(bindSymbol("x") - (2.toByte() + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
number(1)
|
||||
) * number(2)
|
||||
}.compile()("x" to 3.toByte())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(ByteRing, Symbol.x to 3.toByte()),
|
||||
mst.compile(ByteRing, Symbol.x to 3.toByte())
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun realField() {
|
||||
val res1 = DoubleField.mstInField {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0)
|
||||
}
|
||||
|
||||
val res2 = DoubleField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0)
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(DoubleField, Symbol.x to 2.0),
|
||||
mst.compile(DoubleField, Symbol.x to 2.0)
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun complexField() {
|
||||
val res1 = ComplexField.mstInField {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0.toComplex())
|
||||
}
|
||||
|
||||
val res2 = ComplexField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0.toComplex())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(ComplexField, Symbol.x to 2.0.toComplex()),
|
||||
mst.compile(ComplexField, Symbol.x to 2.0.toComplex())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,9 @@
|
||||
package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.ast.mstInExtendedField
|
||||
import space.kscience.kmath.ast.mstInField
|
||||
import space.kscience.kmath.ast.mstInGroup
|
||||
import space.kscience.kmath.ast.MstExtendedField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.random.Random
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
@ -12,29 +11,29 @@ import kotlin.test.assertEquals
|
||||
internal class TestESTreeOperationsSupport {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
|
||||
val expression = MstExtendedField { -bindSymbol("x") }.compileToExpression(DoubleField)
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBinaryOperationInvocation() {
|
||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
|
||||
val expression = MstExtendedField { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-1.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
|
||||
val res = MstExtendedField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMultipleCalls() {
|
||||
val e =
|
||||
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
||||
.compile()
|
||||
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
||||
.compileToExpression(DoubleField)
|
||||
val r = Random(0)
|
||||
var s = 0.0
|
||||
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
||||
|
@ -1,53 +1,63 @@
|
||||
package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.ast.mstInField
|
||||
import space.kscience.kmath.ast.MstExtendedField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||
assertEquals(2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||
assertEquals(-2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("+")(bindSymbol("x"),
|
||||
bindSymbol("x"))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr("x" to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("-")(bindSymbol("x"),
|
||||
bindSymbol("x"))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("/")(bindSymbol("x"),
|
||||
bindSymbol("x"))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(1.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = DoubleField
|
||||
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
|
||||
.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
|
||||
}.compileToExpression(DoubleField)
|
||||
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
@ -1,8 +1,9 @@
|
||||
package space.kscience.kmath.estree
|
||||
|
||||
import space.kscience.kmath.ast.mstInRing
|
||||
import space.kscience.kmath.ast.MstRing
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
|
||||
internal class TestESTreeVariables {
|
||||
@Test
|
||||
fun testVariable() {
|
||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
||||
val expr = MstRing{ bindSymbol("x") }.compileToExpression(ByteRing)
|
||||
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUndefinedVariableFails() {
|
||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
||||
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
|
||||
assertFailsWith<NoSuchElementException> { expr() }
|
||||
}
|
||||
}
|
||||
|
@ -4,8 +4,9 @@ import space.kscience.kmath.asm.internal.AsmBuilder
|
||||
import space.kscience.kmath.asm.internal.buildName
|
||||
import space.kscience.kmath.ast.MST
|
||||
import space.kscience.kmath.ast.MST.*
|
||||
import space.kscience.kmath.ast.MstExpression
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
@ -70,18 +71,22 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
||||
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
|
||||
}
|
||||
|
||||
/**
|
||||
* Compiles an [MST] to ASM using given algebra.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
|
||||
mst.compileWith(T::class.java, this)
|
||||
|
||||
/**
|
||||
* Optimizes performance of an [MstExpression] using ASM codegen.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
* Create a compiled expression with given [MST] and given [algebra].
|
||||
*/
|
||||
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
||||
mst.compileWith(T::class.java, algebra)
|
||||
public inline fun <reified T: Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
|
||||
compileWith(T::class.java, algebra)
|
||||
|
||||
|
||||
/**
|
||||
* Compile given MST to expression and evaluate it against [arguments]
|
||||
*/
|
||||
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||
compileToExpression(algebra).invoke(arguments)
|
||||
|
||||
/**
|
||||
* Compile given MST to expression and evaluate it against [arguments]
|
||||
*/
|
||||
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
|
||||
compileToExpression(algebra).invoke(*arguments)
|
||||
|
@ -3,16 +3,19 @@ package space.kscience.kmath.asm
|
||||
import space.kscience.kmath.ast.*
|
||||
import space.kscience.kmath.complex.ComplexField
|
||||
import space.kscience.kmath.complex.toComplex
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.misc.Symbol.Companion.x
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmConsistencyWithInterpreter {
|
||||
|
||||
@Test
|
||||
fun mstSpace() {
|
||||
val res1 = MstGroup.mstInGroup {
|
||||
|
||||
val mst = MstGroup {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||
@ -23,27 +26,17 @@ internal class TestAsmConsistencyWithInterpreter {
|
||||
|
||||
number(1)
|
||||
) + bindSymbol("x") + zero
|
||||
}("x" to MST.Numeric(2))
|
||||
}
|
||||
|
||||
val res2 = MstGroup.mstInGroup {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + bindSymbol("x") + zero
|
||||
}.compile()("x" to MST.Numeric(2))
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(MstGroup, x to MST.Numeric(2)),
|
||||
mst.compile(MstGroup, x to MST.Numeric(2))
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun byteRing() {
|
||||
val res1 = ByteRing.mstInRing {
|
||||
val mst = MstRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(bindSymbol("x") - (2.toByte() + (scale(
|
||||
@ -54,62 +47,43 @@ internal class TestAsmConsistencyWithInterpreter {
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}("x" to 3.toByte())
|
||||
}
|
||||
|
||||
val res2 = ByteRing.mstInRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(bindSymbol("x") - (2.toByte() + (scale(
|
||||
add(number(1), number(1)),
|
||||
2.0
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
number(1)
|
||||
) * number(2)
|
||||
}.compile()("x" to 3.toByte())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(ByteRing, x to 3.toByte()),
|
||||
mst.compile(ByteRing, x to 3.toByte())
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun realField() {
|
||||
val res1 = DoubleField.mstInField {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0)
|
||||
}
|
||||
|
||||
val res2 = DoubleField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0)
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(DoubleField, x to 2.0),
|
||||
mst.compile(DoubleField, x to 2.0)
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun complexField() {
|
||||
val res1 = ComplexField.mstInField {
|
||||
val mst = MstField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0.toComplex())
|
||||
}
|
||||
|
||||
val res2 = ComplexField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0.toComplex())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
assertEquals(
|
||||
mst.interpret(ComplexField, x to 2.0.toComplex()),
|
||||
mst.compile(ComplexField, x to 2.0.toComplex())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,11 @@
|
||||
package space.kscience.kmath.asm
|
||||
|
||||
import space.kscience.kmath.ast.mstInExtendedField
|
||||
import space.kscience.kmath.ast.mstInField
|
||||
import space.kscience.kmath.ast.mstInGroup
|
||||
import space.kscience.kmath.ast.MstExtendedField
|
||||
import space.kscience.kmath.ast.MstField
|
||||
import space.kscience.kmath.ast.MstGroup
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.random.Random
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
@ -12,29 +13,29 @@ import kotlin.test.assertEquals
|
||||
internal class TestAsmOperationsSupport {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
|
||||
val expression = MstGroup { -bindSymbol("x") }.compileToExpression(DoubleField)
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBinaryOperationInvocation() {
|
||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
|
||||
val expression = MstGroup { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-1.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
|
||||
val res = MstField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMultipleCalls() {
|
||||
val e =
|
||||
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
||||
.compile()
|
||||
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
||||
.compileToExpression(DoubleField)
|
||||
val r = Random(0)
|
||||
var s = 0.0
|
||||
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
||||
|
@ -1,53 +1,63 @@
|
||||
package space.kscience.kmath.asm
|
||||
|
||||
import space.kscience.kmath.ast.mstInField
|
||||
import space.kscience.kmath.ast.MstExtendedField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||
assertEquals(2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||
assertEquals(-2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("+")(bindSymbol("x"),
|
||||
bindSymbol("x"))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr("x" to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("-")(bindSymbol("x"),
|
||||
bindSymbol("x"))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(0.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("/")(bindSymbol("x"),
|
||||
bindSymbol("x"))
|
||||
}.compileToExpression(DoubleField)
|
||||
assertEquals(1.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = DoubleField
|
||||
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
|
||||
.compile()
|
||||
val expr = MstExtendedField {
|
||||
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
|
||||
}.compileToExpression(DoubleField)
|
||||
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
@ -1,8 +1,9 @@
|
||||
package space.kscience.kmath.asm
|
||||
|
||||
import space.kscience.kmath.ast.mstInRing
|
||||
import space.kscience.kmath.ast.MstRing
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.ByteRing
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
|
||||
internal class TestAsmVariables {
|
||||
@Test
|
||||
fun testVariable() {
|
||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
||||
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
|
||||
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUndefinedVariableFails() {
|
||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
||||
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
|
||||
assertFailsWith<NoSuchElementException> { expr() }
|
||||
}
|
||||
}
|
||||
|
@ -2,9 +2,9 @@ package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.complex.Complex
|
||||
import space.kscience.kmath.complex.ComplexField
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@ -18,7 +18,7 @@ internal class ParserTest {
|
||||
|
||||
@Test
|
||||
fun `evaluate MSTExpression`() {
|
||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||
val res = MstField.invoke { number(2) + number(2) * (number(2) + number(2)) }.interpret(ComplexField)
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
|
||||
|
@ -3,8 +3,9 @@ package space.kscience.kmath.kotlingrad
|
||||
import edu.umontreal.kotlingrad.api.SFun
|
||||
import space.kscience.kmath.ast.MST
|
||||
import space.kscience.kmath.ast.MstAlgebra
|
||||
import space.kscience.kmath.ast.MstExpression
|
||||
import space.kscience.kmath.ast.interpret
|
||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.misc.Symbol
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
@ -18,38 +19,26 @@ import space.kscience.kmath.operations.NumericAlgebra
|
||||
* @param A the [NumericAlgebra] of [T].
|
||||
* @property expr the underlying [MstExpression].
|
||||
*/
|
||||
public inline class DifferentiableMstExpression<T: Number, A>(
|
||||
public val expr: MstExpression<T, A>,
|
||||
) : DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T> {
|
||||
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
|
||||
public val algebra: A,
|
||||
public val mst: MST,
|
||||
) : DifferentiableExpression<T, Expression<T>> {
|
||||
|
||||
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
||||
|
||||
/**
|
||||
* The [MstExpression.algebra] of [expr].
|
||||
*/
|
||||
public val algebra: A
|
||||
get() = expr.algebra
|
||||
|
||||
/**
|
||||
* The [MstExpression.mst] of [expr].
|
||||
*/
|
||||
public val mst: MST
|
||||
get() = expr.mst
|
||||
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
|
||||
algebra,
|
||||
symbols.map(Symbol::identity)
|
||||
.map(MstAlgebra::bindSymbol)
|
||||
.map { it.toSVar<KMathNumber<T, A>>() }
|
||||
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
||||
.toMst(),
|
||||
)
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
|
||||
DifferentiableMstExpression(
|
||||
algebra,
|
||||
symbols.map(Symbol::identity)
|
||||
.map(MstAlgebra::bindSymbol)
|
||||
.map { it.toSVar<KMathNumber<T, A>>() }
|
||||
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
||||
.toMst(),
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
||||
* Wraps this [MST] into [DifferentiableMstExpression].
|
||||
*/
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
|
||||
DifferentiableMstExpression(this)
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
|
||||
DifferentiableMstExpression(algebra, this)
|
||||
|
@ -1,9 +1,8 @@
|
||||
package space.kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.api.*
|
||||
import space.kscience.kmath.asm.compile
|
||||
import space.kscience.kmath.asm.compileToExpression
|
||||
import space.kscience.kmath.ast.MstAlgebra
|
||||
import space.kscience.kmath.ast.MstExpression
|
||||
import space.kscience.kmath.ast.parseMath
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
@ -43,8 +42,8 @@ internal class AdaptingTests {
|
||||
fun simpleFunctionDerivative() {
|
||||
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
|
||||
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
|
||||
val actualDerivative = MstExpression(DoubleField, quadratic.d(x).toMst()).compile()
|
||||
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
|
||||
val actualDerivative = quadratic.d(x).toMst().compileToExpression(DoubleField)
|
||||
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
|
||||
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||
}
|
||||
|
||||
@ -52,12 +51,11 @@ internal class AdaptingTests {
|
||||
fun moreComplexDerivative() {
|
||||
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
|
||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
|
||||
val actualDerivative = MstExpression(DoubleField, composition.d(x).toMst()).compile()
|
||||
val actualDerivative = composition.d(x).toMst().compileToExpression(DoubleField)
|
||||
|
||||
val expectedDerivative =
|
||||
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath().compileToExpression(DoubleField)
|
||||
|
||||
val expectedDerivative = MstExpression(
|
||||
DoubleField,
|
||||
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
|
||||
).compile()
|
||||
|
||||
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user