Refactor MST

This commit is contained in:
Alexander Nozik 2021-04-01 20:15:49 +03:00
parent c2bab5d138
commit af4866e876
18 changed files with 241 additions and 365 deletions

View File

@ -5,6 +5,7 @@
- ScaleOperations interface
- Field extends ScaleOperations
- Basic integration API
- Basic MPP distributions and samplers
### Changed
- Exponential operations merged with hyperbolic functions
@ -14,6 +15,8 @@
- NDStructure and NDAlgebra to StructureND and AlgebraND respectively
- Real -> Double
- DataSets are moved from functions to core
- Redesign advanced Chain API
- Redesign MST. Remove MSTExpression.
### Deprecated
@ -21,6 +24,7 @@
- Nearest in Domain. To be implemented in geometry package.
- Number multiplication and division in main Algebra chain
- `contentEquals` from Buffer. It moved to the companion.
- MSTExpression
### Fixed

View File

@ -1,15 +1,17 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
fun main() {
val expr = DoubleField.mstInField {
val x = bindSymbol("x")
val expr = MstField {
val x = bindSymbol(x)
x * 2.0 + number(2.0) / x - 16.0
}
repeat(10000000) {
expr.invoke("x" to 1.0)
expr.interpret(DoubleField, x to 1.0)
}
}

View File

@ -1,9 +1,9 @@
package space.kscience.kmath.ast
import space.kscience.kmath.asm.compile
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.kotlingrad.differentiable
import space.kscience.kmath.kotlingrad.toDiffExpression
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
@ -14,11 +14,11 @@ import space.kscience.kmath.operations.DoubleField
fun main() {
val x by symbol
val actualDerivative = MstExpression(DoubleField, "x^2-4*x-44".parseMath())
.differentiable()
val actualDerivative = "x^2-4*x-44".parseMath()
.toDiffExpression(DoubleField)
.derivative(x)
.compile()
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
}

View File

@ -1,5 +1,8 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
@ -76,11 +79,51 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
}
}
internal class InnerAlgebra<T : Any>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbol(value: String): T = try {
algebra.bindSymbol(value)
} catch (ignored: IllegalStateException) {
null
} ?: arguments.getValue(StringSymbol(value))
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
/**
* Interprets the [MST] node with this [Algebra].
* Interprets the [MST] node with this [Algebra] and optional [arguments]
*/
public fun <T : Any> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
InnerAlgebra(algebra, arguments).evaluate(this)
/**
* Interprets the [MST] node with this [Algebra] and optional [arguments]
*
* @receiver the node to evaluate.
* @param algebra the algebra that provides operations.
* @return the value of expression.
*/
public fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)
public fun <T : Any> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
interpret(algebra, mapOf(*arguments))
/**
* Interpret this [MST] as expression.
*/
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
interpret(algebra, arguments)
}

View File

@ -1,138 +0,0 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
* ASM-generated expressions.
*
* @property algebra the algebra that provides operations.
* @property mst the [MST] node.
* @author Alexander Nozik
*/
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbol(value: String): T = try {
algebra.bindSymbol(value)
} catch (ignored: IllegalStateException) {
null
} ?: arguments.getValue(StringSymbol(value))
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
}
/**
* Builds [MstExpression] over [Algebra].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E,
block: E.() -> MST,
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
/**
* Builds [MstExpression] over [Group].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Group<T>> A.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstGroup.block())
}
/**
* Builds [MstExpression] over [Ring].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block())
}
/**
* Builds [MstExpression] over [Field].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block())
}
/**
* Builds [MstExpression] over [ExtendedField].
*
* @author Iaroslav Postovalov
*/
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block())
}
/**
* Builds [MstExpression] over [FunctionalExpressionGroup].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Group<T>> FunctionalExpressionGroup<T, A>.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInGroup(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionRing].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionField].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
*
* @author Iaroslav Postovalov
*/
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST,
): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block)
}

View File

@ -2,10 +2,11 @@ package space.kscience.kmath.estree
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MST.*
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.estree.internal.ESTreeBuilder
import space.kscience.kmath.estree.internal.estree.BaseExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
@ -64,19 +65,21 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
}
/**
* Create a compiled expression with given [MST] and given [algebra].
*/
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra)
/**
* Compiles an [MST] to ESTree generated expression using given algebra.
*
* @author Iaroslav Postovalov
* Compile given MST to expression and evaluate it against [arguments]
*/
public fun <T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
mst.compileWith(this)
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments)
/**
* Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression.
*
* @author Iaroslav Postovalov
* Compile given MST to expression and evaluate it against [arguments]
*/
public fun <T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
mst.compileWith(algebra)
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
compileToExpression(algebra).invoke(*arguments)

View File

@ -3,16 +3,19 @@ package space.kscience.kmath.estree
import space.kscience.kmath.ast.*
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeConsistencyWithInterpreter {
@Test
fun mstSpace() {
val res1 = MstGroup.mstInGroup {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
@ -23,27 +26,17 @@ internal class TestESTreeConsistencyWithInterpreter {
number(1)
) + bindSymbol("x") + zero
}("x" to MST.Numeric(2))
}
val res2 = MstGroup.mstInGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol("x") + zero
}.compile()("x" to MST.Numeric(2))
assertEquals(res1, res2)
assertEquals(
mst.interpret(MstGroup, Symbol.x to MST.Numeric(2)),
mst.compile(MstGroup, Symbol.x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val res1 = ByteRing.mstInRing {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
@ -54,62 +47,43 @@ internal class TestESTreeConsistencyWithInterpreter {
number(1)
) * number(2)
}("x" to 3.toByte())
}
val res2 = ByteRing.mstInRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ByteRing, Symbol.x to 3.toByte()),
mst.compile(ByteRing, Symbol.x to 3.toByte())
)
}
@Test
fun realField() {
val res1 = DoubleField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0)
}
val res2 = DoubleField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0)
assertEquals(res1, res2)
assertEquals(
mst.interpret(DoubleField, Symbol.x to 2.0),
mst.compile(DoubleField, Symbol.x to 2.0)
)
}
@Test
fun complexField() {
val res1 = ComplexField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0.toComplex())
}
val res2 = ComplexField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0.toComplex())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ComplexField, Symbol.x to 2.0.toComplex()),
mst.compile(ComplexField, Symbol.x to 2.0.toComplex())
)
}
}

View File

@ -1,10 +1,9 @@
package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInExtendedField
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.mstInGroup
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
@ -12,29 +11,29 @@ import kotlin.test.assertEquals
internal class TestESTreeOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
val expression = MstExtendedField { -bindSymbol("x") }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
val expression = MstExtendedField { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
val res = MstExtendedField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
assertEquals(4.0, res)
}
@Test
fun testMultipleCalls() {
val e =
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compile()
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compileToExpression(DoubleField)
val r = Random(0)
var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) }

View File

@ -1,53 +1,63 @@
package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeSpecialization {
@Test
fun testUnaryPlus() {
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(2.0, expr("x" to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr("x" to 2.0))
}
@Test
fun testAdd() {
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("+")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}
@Test
fun testSine() {
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 0.0))
}
@Test
fun testMinus() {
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 2.0))
}
@Test
fun testDivide() {
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr("x" to 2.0))
}
@Test
fun testPower() {
val expr = DoubleField
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
.compile()
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}

View File

@ -1,8 +1,9 @@
package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInRing
import space.kscience.kmath.ast.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
internal class TestESTreeVariables {
@Test
fun testVariable() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing{ bindSymbol("x") }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr("x" to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
}

View File

@ -4,8 +4,9 @@ import space.kscience.kmath.asm.internal.AsmBuilder
import space.kscience.kmath.asm.internal.buildName
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MST.*
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
@ -70,18 +71,22 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
}
/**
* Compiles an [MST] to ASM using given algebra.
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
mst.compileWith(T::class.java, this)
/**
* Optimizes performance of an [MstExpression] using ASM codegen.
*
* @author Alexander Nozik
* Create a compiled expression with given [MST] and given [algebra].
*/
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
mst.compileWith(T::class.java, algebra)
public inline fun <reified T: Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
compileWith(T::class.java, algebra)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
compileToExpression(algebra).invoke(*arguments)

View File

@ -3,16 +3,19 @@ package space.kscience.kmath.asm
import space.kscience.kmath.ast.*
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmConsistencyWithInterpreter {
@Test
fun mstSpace() {
val res1 = MstGroup.mstInGroup {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
@ -23,27 +26,17 @@ internal class TestAsmConsistencyWithInterpreter {
number(1)
) + bindSymbol("x") + zero
}("x" to MST.Numeric(2))
}
val res2 = MstGroup.mstInGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol("x") + zero
}.compile()("x" to MST.Numeric(2))
assertEquals(res1, res2)
assertEquals(
mst.interpret(MstGroup, x to MST.Numeric(2)),
mst.compile(MstGroup, x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val res1 = ByteRing.mstInRing {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
@ -54,62 +47,43 @@ internal class TestAsmConsistencyWithInterpreter {
number(1)
) * number(2)
}("x" to 3.toByte())
}
val res2 = ByteRing.mstInRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ByteRing, x to 3.toByte()),
mst.compile(ByteRing, x to 3.toByte())
)
}
@Test
fun realField() {
val res1 = DoubleField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0)
}
val res2 = DoubleField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0)
assertEquals(res1, res2)
assertEquals(
mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0)
)
}
@Test
fun complexField() {
val res1 = ComplexField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0.toComplex())
}
val res2 = ComplexField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0.toComplex())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ComplexField, x to 2.0.toComplex()),
mst.compile(ComplexField, x to 2.0.toComplex())
)
}
}

View File

@ -1,10 +1,11 @@
package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInExtendedField
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.mstInGroup
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.ast.MstField
import space.kscience.kmath.ast.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
@ -12,29 +13,29 @@ import kotlin.test.assertEquals
internal class TestAsmOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
val expression = MstGroup { -bindSymbol("x") }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
val expression = MstGroup { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
val res = MstField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
assertEquals(4.0, res)
}
@Test
fun testMultipleCalls() {
val e =
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compile()
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compileToExpression(DoubleField)
val r = Random(0)
var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) }

View File

@ -1,53 +1,63 @@
package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmSpecialization {
@Test
fun testUnaryPlus() {
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(2.0, expr("x" to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr("x" to 2.0))
}
@Test
fun testAdd() {
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("+")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}
@Test
fun testSine() {
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 0.0))
}
@Test
fun testMinus() {
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 2.0))
}
@Test
fun testDivide() {
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr("x" to 2.0))
}
@Test
fun testPower() {
val expr = DoubleField
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
.compile()
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}

View File

@ -1,8 +1,9 @@
package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInRing
import space.kscience.kmath.ast.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
internal class TestAsmVariables {
@Test
fun testVariable() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr("x" to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
}

View File

@ -2,9 +2,9 @@ package space.kscience.kmath.ast
import space.kscience.kmath.complex.Complex
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
@ -18,7 +18,7 @@ internal class ParserTest {
@Test
fun `evaluate MSTExpression`() {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
val res = MstField.invoke { number(2) + number(2) * (number(2) + number(2)) }.interpret(ComplexField)
assertEquals(Complex(10.0, 0.0), res)
}

View File

@ -3,8 +3,9 @@ package space.kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.api.SFun
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MstAlgebra
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.ast.interpret
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.NumericAlgebra
@ -18,38 +19,26 @@ import space.kscience.kmath.operations.NumericAlgebra
* @param A the [NumericAlgebra] of [T].
* @property expr the underlying [MstExpression].
*/
public inline class DifferentiableMstExpression<T: Number, A>(
public val expr: MstExpression<T, A>,
) : DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T> {
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
public val algebra: A,
public val mst: MST,
) : DifferentiableExpression<T, Expression<T>> {
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
/**
* The [MstExpression.algebra] of [expr].
*/
public val algebra: A
get() = expr.algebra
/**
* The [MstExpression.mst] of [expr].
*/
public val mst: MST
get() = expr.mst
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::bindSymbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::bindSymbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
}
/**
* Wraps this [MstExpression] into [DifferentiableMstExpression].
* Wraps this [MST] into [DifferentiableMstExpression].
*/
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(this)
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(algebra, this)

View File

@ -1,9 +1,8 @@
package space.kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.api.*
import space.kscience.kmath.asm.compile
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.ast.MstAlgebra
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.ast.parseMath
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
@ -43,8 +42,8 @@ internal class AdaptingTests {
fun simpleFunctionDerivative() {
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = MstExpression(DoubleField, quadratic.d(x).toMst()).compile()
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
val actualDerivative = quadratic.d(x).toMst().compileToExpression(DoubleField)
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
}
@ -52,12 +51,11 @@ internal class AdaptingTests {
fun moreComplexDerivative() {
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = MstExpression(DoubleField, composition.d(x).toMst()).compile()
val actualDerivative = composition.d(x).toMst().compileToExpression(DoubleField)
val expectedDerivative =
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath().compileToExpression(DoubleField)
val expectedDerivative = MstExpression(
DoubleField,
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
).compile()
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
}