Refactor MST
This commit is contained in:
parent
c2bab5d138
commit
af4866e876
@ -5,6 +5,7 @@
|
|||||||
- ScaleOperations interface
|
- ScaleOperations interface
|
||||||
- Field extends ScaleOperations
|
- Field extends ScaleOperations
|
||||||
- Basic integration API
|
- Basic integration API
|
||||||
|
- Basic MPP distributions and samplers
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Exponential operations merged with hyperbolic functions
|
- Exponential operations merged with hyperbolic functions
|
||||||
@ -14,6 +15,8 @@
|
|||||||
- NDStructure and NDAlgebra to StructureND and AlgebraND respectively
|
- NDStructure and NDAlgebra to StructureND and AlgebraND respectively
|
||||||
- Real -> Double
|
- Real -> Double
|
||||||
- DataSets are moved from functions to core
|
- DataSets are moved from functions to core
|
||||||
|
- Redesign advanced Chain API
|
||||||
|
- Redesign MST. Remove MSTExpression.
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
@ -21,6 +24,7 @@
|
|||||||
- Nearest in Domain. To be implemented in geometry package.
|
- Nearest in Domain. To be implemented in geometry package.
|
||||||
- Number multiplication and division in main Algebra chain
|
- Number multiplication and division in main Algebra chain
|
||||||
- `contentEquals` from Buffer. It moved to the companion.
|
- `contentEquals` from Buffer. It moved to the companion.
|
||||||
|
- MSTExpression
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
package space.kscience.kmath.ast
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.misc.Symbol.Companion.x
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.bindSymbol
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val expr = DoubleField.mstInField {
|
val expr = MstField {
|
||||||
val x = bindSymbol("x")
|
val x = bindSymbol(x)
|
||||||
x * 2.0 + number(2.0) / x - 16.0
|
x * 2.0 + number(2.0) / x - 16.0
|
||||||
}
|
}
|
||||||
|
|
||||||
repeat(10000000) {
|
repeat(10000000) {
|
||||||
expr.invoke("x" to 1.0)
|
expr.interpret(DoubleField, x to 1.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,9 +1,9 @@
|
|||||||
package space.kscience.kmath.ast
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
import space.kscience.kmath.asm.compile
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.expressions.derivative
|
import space.kscience.kmath.expressions.derivative
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.kotlingrad.differentiable
|
import space.kscience.kmath.kotlingrad.toDiffExpression
|
||||||
import space.kscience.kmath.misc.symbol
|
import space.kscience.kmath.misc.symbol
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
|
||||||
@ -14,11 +14,11 @@ import space.kscience.kmath.operations.DoubleField
|
|||||||
fun main() {
|
fun main() {
|
||||||
val x by symbol
|
val x by symbol
|
||||||
|
|
||||||
val actualDerivative = MstExpression(DoubleField, "x^2-4*x-44".parseMath())
|
val actualDerivative = "x^2-4*x-44".parseMath()
|
||||||
.differentiable()
|
.toDiffExpression(DoubleField)
|
||||||
.derivative(x)
|
.derivative(x)
|
||||||
.compile()
|
|
||||||
|
|
||||||
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
|
|
||||||
|
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
|
||||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
package space.kscience.kmath.ast
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.misc.StringSymbol
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
@ -76,11 +79,51 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal class InnerAlgebra<T : Any>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||||
|
override fun bindSymbol(value: String): T = try {
|
||||||
|
algebra.bindSymbol(value)
|
||||||
|
} catch (ignored: IllegalStateException) {
|
||||||
|
null
|
||||||
|
} ?: arguments.getValue(StringSymbol(value))
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: T): T =
|
||||||
|
algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
|
algebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
|
||||||
|
algebra.unaryOperationFunction(operation)
|
||||||
|
|
||||||
|
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
|
||||||
|
algebra.binaryOperationFunction(operation)
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||||
|
(algebra as NumericAlgebra<T>).number(value)
|
||||||
|
else
|
||||||
|
error("Numeric nodes are not supported by $this")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Interprets the [MST] node with this [Algebra].
|
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||||
|
*/
|
||||||
|
public fun <T : Any> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||||
|
InnerAlgebra(algebra, arguments).evaluate(this)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||||
*
|
*
|
||||||
* @receiver the node to evaluate.
|
* @receiver the node to evaluate.
|
||||||
* @param algebra the algebra that provides operations.
|
* @param algebra the algebra that provides operations.
|
||||||
* @return the value of expression.
|
* @return the value of expression.
|
||||||
*/
|
*/
|
||||||
public fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)
|
public fun <T : Any> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
||||||
|
interpret(algebra, mapOf(*arguments))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interpret this [MST] as expression.
|
||||||
|
*/
|
||||||
|
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
|
||||||
|
interpret(algebra, arguments)
|
||||||
|
}
|
||||||
|
@ -1,138 +0,0 @@
|
|||||||
package space.kscience.kmath.ast
|
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.*
|
|
||||||
import space.kscience.kmath.misc.StringSymbol
|
|
||||||
import space.kscience.kmath.misc.Symbol
|
|
||||||
import space.kscience.kmath.operations.*
|
|
||||||
import kotlin.contracts.InvocationKind
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
|
|
||||||
* ASM-generated expressions.
|
|
||||||
*
|
|
||||||
* @property algebra the algebra that provides operations.
|
|
||||||
* @property mst the [MST] node.
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
|
||||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
|
||||||
override fun bindSymbol(value: String): T = try {
|
|
||||||
algebra.bindSymbol(value)
|
|
||||||
} catch (ignored: IllegalStateException) {
|
|
||||||
null
|
|
||||||
} ?: arguments.getValue(StringSymbol(value))
|
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: T): T =
|
|
||||||
algebra.unaryOperation(operation, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
|
||||||
algebra.binaryOperation(operation, left, right)
|
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
|
|
||||||
algebra.unaryOperationFunction(operation)
|
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
|
|
||||||
algebra.binaryOperationFunction(operation)
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
|
||||||
(algebra as NumericAlgebra<T>).number(value)
|
|
||||||
else
|
|
||||||
error("Numeric nodes are not supported by $this")
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [Algebra].
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
|
||||||
mstAlgebra: E,
|
|
||||||
block: E.() -> MST,
|
|
||||||
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [Group].
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : Group<T>> A.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return MstExpression(this, MstGroup.block())
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [Ring].
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return MstExpression(this, MstRing.block())
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [Field].
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return MstExpression(this, MstField.block())
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [ExtendedField].
|
|
||||||
*
|
|
||||||
* @author Iaroslav Postovalov
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return MstExpression(this, MstExtendedField.block())
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [FunctionalExpressionGroup].
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : Group<T>> FunctionalExpressionGroup<T, A>.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return algebra.mstInGroup(block)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [FunctionalExpressionRing].
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return algebra.mstInRing(block)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [FunctionalExpressionField].
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return algebra.mstInField(block)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
|
|
||||||
*
|
|
||||||
* @author Iaroslav Postovalov
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
|
||||||
block: MstExtendedField.() -> MST,
|
|
||||||
): MstExpression<T, A> {
|
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return algebra.mstInExtendedField(block)
|
|
||||||
}
|
|
@ -2,10 +2,11 @@ package space.kscience.kmath.estree
|
|||||||
|
|
||||||
import space.kscience.kmath.ast.MST
|
import space.kscience.kmath.ast.MST
|
||||||
import space.kscience.kmath.ast.MST.*
|
import space.kscience.kmath.ast.MST.*
|
||||||
import space.kscience.kmath.ast.MstExpression
|
|
||||||
import space.kscience.kmath.estree.internal.ESTreeBuilder
|
import space.kscience.kmath.estree.internal.ESTreeBuilder
|
||||||
import space.kscience.kmath.estree.internal.estree.BaseExpression
|
import space.kscience.kmath.estree.internal.estree.BaseExpression
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.expressions.invoke
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
@ -64,19 +65,21 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
|||||||
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
|
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a compiled expression with given [MST] and given [algebra].
|
||||||
|
*/
|
||||||
|
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles an [MST] to ESTree generated expression using given algebra.
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*
|
|
||||||
* @author Iaroslav Postovalov
|
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
|
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||||
mst.compileWith(this)
|
compileToExpression(algebra).invoke(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression.
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*
|
|
||||||
* @author Iaroslav Postovalov
|
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
|
||||||
mst.compileWith(algebra)
|
compileToExpression(algebra).invoke(*arguments)
|
||||||
|
@ -3,16 +3,19 @@ package space.kscience.kmath.estree
|
|||||||
import space.kscience.kmath.ast.*
|
import space.kscience.kmath.ast.*
|
||||||
import space.kscience.kmath.complex.ComplexField
|
import space.kscience.kmath.complex.ComplexField
|
||||||
import space.kscience.kmath.complex.toComplex
|
import space.kscience.kmath.complex.toComplex
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.operations.ByteRing
|
import space.kscience.kmath.operations.ByteRing
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestESTreeConsistencyWithInterpreter {
|
internal class TestESTreeConsistencyWithInterpreter {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun mstSpace() {
|
fun mstSpace() {
|
||||||
val res1 = MstGroup.mstInGroup {
|
|
||||||
|
val mst = MstGroup {
|
||||||
binaryOperationFunction("+")(
|
binaryOperationFunction("+")(
|
||||||
unaryOperationFunction("+")(
|
unaryOperationFunction("+")(
|
||||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||||
@ -23,27 +26,17 @@ internal class TestESTreeConsistencyWithInterpreter {
|
|||||||
|
|
||||||
number(1)
|
number(1)
|
||||||
) + bindSymbol("x") + zero
|
) + bindSymbol("x") + zero
|
||||||
}("x" to MST.Numeric(2))
|
}
|
||||||
|
|
||||||
val res2 = MstGroup.mstInGroup {
|
assertEquals(
|
||||||
binaryOperationFunction("+")(
|
mst.interpret(MstGroup, Symbol.x to MST.Numeric(2)),
|
||||||
unaryOperationFunction("+")(
|
mst.compile(MstGroup, Symbol.x to MST.Numeric(2))
|
||||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
)
|
||||||
add(number(1), number(1)),
|
|
||||||
2.0
|
|
||||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
|
||||||
),
|
|
||||||
|
|
||||||
number(1)
|
|
||||||
) + bindSymbol("x") + zero
|
|
||||||
}.compile()("x" to MST.Numeric(2))
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun byteRing() {
|
fun byteRing() {
|
||||||
val res1 = ByteRing.mstInRing {
|
val mst = MstRing {
|
||||||
binaryOperationFunction("+")(
|
binaryOperationFunction("+")(
|
||||||
unaryOperationFunction("+")(
|
unaryOperationFunction("+")(
|
||||||
(bindSymbol("x") - (2.toByte() + (scale(
|
(bindSymbol("x") - (2.toByte() + (scale(
|
||||||
@ -54,62 +47,43 @@ internal class TestESTreeConsistencyWithInterpreter {
|
|||||||
|
|
||||||
number(1)
|
number(1)
|
||||||
) * number(2)
|
) * number(2)
|
||||||
}("x" to 3.toByte())
|
}
|
||||||
|
|
||||||
val res2 = ByteRing.mstInRing {
|
assertEquals(
|
||||||
binaryOperationFunction("+")(
|
mst.interpret(ByteRing, Symbol.x to 3.toByte()),
|
||||||
unaryOperationFunction("+")(
|
mst.compile(ByteRing, Symbol.x to 3.toByte())
|
||||||
(bindSymbol("x") - (2.toByte() + (scale(
|
)
|
||||||
add(number(1), number(1)),
|
|
||||||
2.0
|
|
||||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
|
||||||
),
|
|
||||||
number(1)
|
|
||||||
) * number(2)
|
|
||||||
}.compile()("x" to 3.toByte())
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun realField() {
|
fun realField() {
|
||||||
val res1 = DoubleField.mstInField {
|
val mst = MstField {
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
number(1) / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
) + zero
|
) + zero
|
||||||
}("x" to 2.0)
|
}
|
||||||
|
|
||||||
val res2 = DoubleField.mstInField {
|
assertEquals(
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
mst.interpret(DoubleField, Symbol.x to 2.0),
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
mst.compile(DoubleField, Symbol.x to 2.0)
|
||||||
+ number(1),
|
)
|
||||||
number(1) / 2 + number(2.0) * one
|
|
||||||
) + zero
|
|
||||||
}.compile()("x" to 2.0)
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun complexField() {
|
fun complexField() {
|
||||||
val res1 = ComplexField.mstInField {
|
val mst = MstField {
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
number(1) / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
) + zero
|
) + zero
|
||||||
}("x" to 2.0.toComplex())
|
}
|
||||||
|
|
||||||
val res2 = ComplexField.mstInField {
|
assertEquals(
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
mst.interpret(ComplexField, Symbol.x to 2.0.toComplex()),
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
mst.compile(ComplexField, Symbol.x to 2.0.toComplex())
|
||||||
+ number(1),
|
)
|
||||||
number(1) / 2 + number(2.0) * one
|
|
||||||
) + zero
|
|
||||||
}.compile()("x" to 2.0.toComplex())
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
package space.kscience.kmath.estree
|
package space.kscience.kmath.estree
|
||||||
|
|
||||||
import space.kscience.kmath.ast.mstInExtendedField
|
import space.kscience.kmath.ast.MstExtendedField
|
||||||
import space.kscience.kmath.ast.mstInField
|
|
||||||
import space.kscience.kmath.ast.mstInGroup
|
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
@ -12,29 +11,29 @@ import kotlin.test.assertEquals
|
|||||||
internal class TestESTreeOperationsSupport {
|
internal class TestESTreeOperationsSupport {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryOperationInvocation() {
|
fun testUnaryOperationInvocation() {
|
||||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
|
val expression = MstExtendedField { -bindSymbol("x") }.compileToExpression(DoubleField)
|
||||||
val res = expression("x" to 2.0)
|
val res = expression("x" to 2.0)
|
||||||
assertEquals(-2.0, res)
|
assertEquals(-2.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBinaryOperationInvocation() {
|
fun testBinaryOperationInvocation() {
|
||||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
|
val expression = MstExtendedField { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
|
||||||
val res = expression("x" to 2.0)
|
val res = expression("x" to 2.0)
|
||||||
assertEquals(-1.0, res)
|
assertEquals(-1.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testConstProductInvocation() {
|
fun testConstProductInvocation() {
|
||||||
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
|
val res = MstExtendedField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
|
||||||
assertEquals(4.0, res)
|
assertEquals(4.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testMultipleCalls() {
|
fun testMultipleCalls() {
|
||||||
val e =
|
val e =
|
||||||
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
||||||
.compile()
|
.compileToExpression(DoubleField)
|
||||||
val r = Random(0)
|
val r = Random(0)
|
||||||
var s = 0.0
|
var s = 0.0
|
||||||
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
||||||
|
@ -1,53 +1,63 @@
|
|||||||
package space.kscience.kmath.estree
|
package space.kscience.kmath.estree
|
||||||
|
|
||||||
import space.kscience.kmath.ast.mstInField
|
import space.kscience.kmath.ast.MstExtendedField
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestESTreeSpecialization {
|
internal class TestESTreeSpecialization {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryPlus() {
|
fun testUnaryPlus() {
|
||||||
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
|
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||||
assertEquals(2.0, expr("x" to 2.0))
|
assertEquals(2.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryMinus() {
|
fun testUnaryMinus() {
|
||||||
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
|
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||||
assertEquals(-2.0, expr("x" to 2.0))
|
assertEquals(-2.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAdd() {
|
fun testAdd() {
|
||||||
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
val expr = MstExtendedField {
|
||||||
|
binaryOperationFunction("+")(bindSymbol("x"),
|
||||||
|
bindSymbol("x"))
|
||||||
|
}.compileToExpression(DoubleField)
|
||||||
assertEquals(4.0, expr("x" to 2.0))
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSine() {
|
fun testSine() {
|
||||||
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
|
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||||
assertEquals(0.0, expr("x" to 0.0))
|
assertEquals(0.0, expr("x" to 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testMinus() {
|
fun testMinus() {
|
||||||
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
val expr = MstExtendedField {
|
||||||
|
binaryOperationFunction("-")(bindSymbol("x"),
|
||||||
|
bindSymbol("x"))
|
||||||
|
}.compileToExpression(DoubleField)
|
||||||
assertEquals(0.0, expr("x" to 2.0))
|
assertEquals(0.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDivide() {
|
fun testDivide() {
|
||||||
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
val expr = MstExtendedField {
|
||||||
|
binaryOperationFunction("/")(bindSymbol("x"),
|
||||||
|
bindSymbol("x"))
|
||||||
|
}.compileToExpression(DoubleField)
|
||||||
assertEquals(1.0, expr("x" to 2.0))
|
assertEquals(1.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPower() {
|
fun testPower() {
|
||||||
val expr = DoubleField
|
val expr = MstExtendedField {
|
||||||
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
|
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
|
||||||
.compile()
|
}.compileToExpression(DoubleField)
|
||||||
|
|
||||||
assertEquals(4.0, expr("x" to 2.0))
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package space.kscience.kmath.estree
|
package space.kscience.kmath.estree
|
||||||
|
|
||||||
import space.kscience.kmath.ast.mstInRing
|
import space.kscience.kmath.ast.MstRing
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.operations.ByteRing
|
import space.kscience.kmath.operations.ByteRing
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertFailsWith
|
import kotlin.test.assertFailsWith
|
||||||
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
|
|||||||
internal class TestESTreeVariables {
|
internal class TestESTreeVariables {
|
||||||
@Test
|
@Test
|
||||||
fun testVariable() {
|
fun testVariable() {
|
||||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
val expr = MstRing{ bindSymbol("x") }.compileToExpression(ByteRing)
|
||||||
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testUndefinedVariableFails() {
|
fun testUndefinedVariableFails() {
|
||||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
|
||||||
assertFailsWith<NoSuchElementException> { expr() }
|
assertFailsWith<NoSuchElementException> { expr() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,8 +4,9 @@ import space.kscience.kmath.asm.internal.AsmBuilder
|
|||||||
import space.kscience.kmath.asm.internal.buildName
|
import space.kscience.kmath.asm.internal.buildName
|
||||||
import space.kscience.kmath.ast.MST
|
import space.kscience.kmath.ast.MST
|
||||||
import space.kscience.kmath.ast.MST.*
|
import space.kscience.kmath.ast.MST.*
|
||||||
import space.kscience.kmath.ast.MstExpression
|
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.expressions.invoke
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
@ -70,18 +71,22 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
|||||||
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
|
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Compiles an [MST] to ASM using given algebra.
|
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
|
|
||||||
mst.compileWith(T::class.java, this)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimizes performance of an [MstExpression] using ASM codegen.
|
* Create a compiled expression with given [MST] and given [algebra].
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
public inline fun <reified T: Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
|
||||||
mst.compileWith(T::class.java, algebra)
|
compileWith(T::class.java, algebra)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
|
*/
|
||||||
|
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||||
|
compileToExpression(algebra).invoke(arguments)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
|
*/
|
||||||
|
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
|
||||||
|
compileToExpression(algebra).invoke(*arguments)
|
||||||
|
@ -3,16 +3,19 @@ package space.kscience.kmath.asm
|
|||||||
import space.kscience.kmath.ast.*
|
import space.kscience.kmath.ast.*
|
||||||
import space.kscience.kmath.complex.ComplexField
|
import space.kscience.kmath.complex.ComplexField
|
||||||
import space.kscience.kmath.complex.toComplex
|
import space.kscience.kmath.complex.toComplex
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.misc.Symbol.Companion.x
|
||||||
import space.kscience.kmath.operations.ByteRing
|
import space.kscience.kmath.operations.ByteRing
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestAsmConsistencyWithInterpreter {
|
internal class TestAsmConsistencyWithInterpreter {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun mstSpace() {
|
fun mstSpace() {
|
||||||
val res1 = MstGroup.mstInGroup {
|
|
||||||
|
val mst = MstGroup {
|
||||||
binaryOperationFunction("+")(
|
binaryOperationFunction("+")(
|
||||||
unaryOperationFunction("+")(
|
unaryOperationFunction("+")(
|
||||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
number(3.toByte()) - (number(2.toByte()) + (scale(
|
||||||
@ -23,27 +26,17 @@ internal class TestAsmConsistencyWithInterpreter {
|
|||||||
|
|
||||||
number(1)
|
number(1)
|
||||||
) + bindSymbol("x") + zero
|
) + bindSymbol("x") + zero
|
||||||
}("x" to MST.Numeric(2))
|
}
|
||||||
|
|
||||||
val res2 = MstGroup.mstInGroup {
|
assertEquals(
|
||||||
binaryOperationFunction("+")(
|
mst.interpret(MstGroup, x to MST.Numeric(2)),
|
||||||
unaryOperationFunction("+")(
|
mst.compile(MstGroup, x to MST.Numeric(2))
|
||||||
number(3.toByte()) - (number(2.toByte()) + (scale(
|
)
|
||||||
add(number(1), number(1)),
|
|
||||||
2.0
|
|
||||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
|
||||||
),
|
|
||||||
|
|
||||||
number(1)
|
|
||||||
) + bindSymbol("x") + zero
|
|
||||||
}.compile()("x" to MST.Numeric(2))
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun byteRing() {
|
fun byteRing() {
|
||||||
val res1 = ByteRing.mstInRing {
|
val mst = MstRing {
|
||||||
binaryOperationFunction("+")(
|
binaryOperationFunction("+")(
|
||||||
unaryOperationFunction("+")(
|
unaryOperationFunction("+")(
|
||||||
(bindSymbol("x") - (2.toByte() + (scale(
|
(bindSymbol("x") - (2.toByte() + (scale(
|
||||||
@ -54,62 +47,43 @@ internal class TestAsmConsistencyWithInterpreter {
|
|||||||
|
|
||||||
number(1)
|
number(1)
|
||||||
) * number(2)
|
) * number(2)
|
||||||
}("x" to 3.toByte())
|
}
|
||||||
|
|
||||||
val res2 = ByteRing.mstInRing {
|
assertEquals(
|
||||||
binaryOperationFunction("+")(
|
mst.interpret(ByteRing, x to 3.toByte()),
|
||||||
unaryOperationFunction("+")(
|
mst.compile(ByteRing, x to 3.toByte())
|
||||||
(bindSymbol("x") - (2.toByte() + (scale(
|
)
|
||||||
add(number(1), number(1)),
|
|
||||||
2.0
|
|
||||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
|
||||||
),
|
|
||||||
number(1)
|
|
||||||
) * number(2)
|
|
||||||
}.compile()("x" to 3.toByte())
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun realField() {
|
fun realField() {
|
||||||
val res1 = DoubleField.mstInField {
|
val mst = MstField {
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
number(1) / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
) + zero
|
) + zero
|
||||||
}("x" to 2.0)
|
}
|
||||||
|
|
||||||
val res2 = DoubleField.mstInField {
|
assertEquals(
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
mst.interpret(DoubleField, x to 2.0),
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
mst.compile(DoubleField, x to 2.0)
|
||||||
+ number(1),
|
)
|
||||||
number(1) / 2 + number(2.0) * one
|
|
||||||
) + zero
|
|
||||||
}.compile()("x" to 2.0)
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun complexField() {
|
fun complexField() {
|
||||||
val res1 = ComplexField.mstInField {
|
val mst = MstField {
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
||||||
+ number(1),
|
+ number(1),
|
||||||
number(1) / 2 + number(2.0) * one
|
number(1) / 2 + number(2.0) * one
|
||||||
) + zero
|
) + zero
|
||||||
}("x" to 2.0.toComplex())
|
}
|
||||||
|
|
||||||
val res2 = ComplexField.mstInField {
|
assertEquals(
|
||||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
mst.interpret(ComplexField, x to 2.0.toComplex()),
|
||||||
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
|
mst.compile(ComplexField, x to 2.0.toComplex())
|
||||||
+ number(1),
|
)
|
||||||
number(1) / 2 + number(2.0) * one
|
|
||||||
) + zero
|
|
||||||
}.compile()("x" to 2.0.toComplex())
|
|
||||||
|
|
||||||
assertEquals(res1, res2)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
package space.kscience.kmath.asm
|
package space.kscience.kmath.asm
|
||||||
|
|
||||||
import space.kscience.kmath.ast.mstInExtendedField
|
import space.kscience.kmath.ast.MstExtendedField
|
||||||
import space.kscience.kmath.ast.mstInField
|
import space.kscience.kmath.ast.MstField
|
||||||
import space.kscience.kmath.ast.mstInGroup
|
import space.kscience.kmath.ast.MstGroup
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
@ -12,29 +13,29 @@ import kotlin.test.assertEquals
|
|||||||
internal class TestAsmOperationsSupport {
|
internal class TestAsmOperationsSupport {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryOperationInvocation() {
|
fun testUnaryOperationInvocation() {
|
||||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
|
val expression = MstGroup { -bindSymbol("x") }.compileToExpression(DoubleField)
|
||||||
val res = expression("x" to 2.0)
|
val res = expression("x" to 2.0)
|
||||||
assertEquals(-2.0, res)
|
assertEquals(-2.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBinaryOperationInvocation() {
|
fun testBinaryOperationInvocation() {
|
||||||
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
|
val expression = MstGroup { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
|
||||||
val res = expression("x" to 2.0)
|
val res = expression("x" to 2.0)
|
||||||
assertEquals(-1.0, res)
|
assertEquals(-1.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testConstProductInvocation() {
|
fun testConstProductInvocation() {
|
||||||
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
|
val res = MstField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
|
||||||
assertEquals(4.0, res)
|
assertEquals(4.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testMultipleCalls() {
|
fun testMultipleCalls() {
|
||||||
val e =
|
val e =
|
||||||
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
|
||||||
.compile()
|
.compileToExpression(DoubleField)
|
||||||
val r = Random(0)
|
val r = Random(0)
|
||||||
var s = 0.0
|
var s = 0.0
|
||||||
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
||||||
|
@ -1,53 +1,63 @@
|
|||||||
package space.kscience.kmath.asm
|
package space.kscience.kmath.asm
|
||||||
|
|
||||||
import space.kscience.kmath.ast.mstInField
|
import space.kscience.kmath.ast.MstExtendedField
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestAsmSpecialization {
|
internal class TestAsmSpecialization {
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryPlus() {
|
fun testUnaryPlus() {
|
||||||
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
|
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||||
assertEquals(2.0, expr("x" to 2.0))
|
assertEquals(2.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testUnaryMinus() {
|
fun testUnaryMinus() {
|
||||||
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
|
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||||
assertEquals(-2.0, expr("x" to 2.0))
|
assertEquals(-2.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAdd() {
|
fun testAdd() {
|
||||||
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
val expr = MstExtendedField {
|
||||||
|
binaryOperationFunction("+")(bindSymbol("x"),
|
||||||
|
bindSymbol("x"))
|
||||||
|
}.compileToExpression(DoubleField)
|
||||||
assertEquals(4.0, expr("x" to 2.0))
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSine() {
|
fun testSine() {
|
||||||
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
|
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
|
||||||
assertEquals(0.0, expr("x" to 0.0))
|
assertEquals(0.0, expr("x" to 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testMinus() {
|
fun testMinus() {
|
||||||
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
val expr = MstExtendedField {
|
||||||
|
binaryOperationFunction("-")(bindSymbol("x"),
|
||||||
|
bindSymbol("x"))
|
||||||
|
}.compileToExpression(DoubleField)
|
||||||
assertEquals(0.0, expr("x" to 2.0))
|
assertEquals(0.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDivide() {
|
fun testDivide() {
|
||||||
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
|
val expr = MstExtendedField {
|
||||||
|
binaryOperationFunction("/")(bindSymbol("x"),
|
||||||
|
bindSymbol("x"))
|
||||||
|
}.compileToExpression(DoubleField)
|
||||||
assertEquals(1.0, expr("x" to 2.0))
|
assertEquals(1.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPower() {
|
fun testPower() {
|
||||||
val expr = DoubleField
|
val expr = MstExtendedField {
|
||||||
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
|
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
|
||||||
.compile()
|
}.compileToExpression(DoubleField)
|
||||||
|
|
||||||
assertEquals(4.0, expr("x" to 2.0))
|
assertEquals(4.0, expr("x" to 2.0))
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package space.kscience.kmath.asm
|
package space.kscience.kmath.asm
|
||||||
|
|
||||||
import space.kscience.kmath.ast.mstInRing
|
import space.kscience.kmath.ast.MstRing
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.operations.ByteRing
|
import space.kscience.kmath.operations.ByteRing
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertFailsWith
|
import kotlin.test.assertFailsWith
|
||||||
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
|
|||||||
internal class TestAsmVariables {
|
internal class TestAsmVariables {
|
||||||
@Test
|
@Test
|
||||||
fun testVariable() {
|
fun testVariable() {
|
||||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
|
||||||
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testUndefinedVariableFails() {
|
fun testUndefinedVariableFails() {
|
||||||
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
|
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
|
||||||
assertFailsWith<NoSuchElementException> { expr() }
|
assertFailsWith<NoSuchElementException> { expr() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,9 @@ package space.kscience.kmath.ast
|
|||||||
|
|
||||||
import space.kscience.kmath.complex.Complex
|
import space.kscience.kmath.complex.Complex
|
||||||
import space.kscience.kmath.complex.ComplexField
|
import space.kscience.kmath.complex.ComplexField
|
||||||
import space.kscience.kmath.expressions.invoke
|
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ internal class ParserTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `evaluate MSTExpression`() {
|
fun `evaluate MSTExpression`() {
|
||||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
val res = MstField.invoke { number(2) + number(2) * (number(2) + number(2)) }.interpret(ComplexField)
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,8 +3,9 @@ package space.kscience.kmath.kotlingrad
|
|||||||
import edu.umontreal.kotlingrad.api.SFun
|
import edu.umontreal.kotlingrad.api.SFun
|
||||||
import space.kscience.kmath.ast.MST
|
import space.kscience.kmath.ast.MST
|
||||||
import space.kscience.kmath.ast.MstAlgebra
|
import space.kscience.kmath.ast.MstAlgebra
|
||||||
import space.kscience.kmath.ast.MstExpression
|
import space.kscience.kmath.ast.interpret
|
||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.misc.Symbol
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
@ -18,38 +19,26 @@ import space.kscience.kmath.operations.NumericAlgebra
|
|||||||
* @param A the [NumericAlgebra] of [T].
|
* @param A the [NumericAlgebra] of [T].
|
||||||
* @property expr the underlying [MstExpression].
|
* @property expr the underlying [MstExpression].
|
||||||
*/
|
*/
|
||||||
public inline class DifferentiableMstExpression<T: Number, A>(
|
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
|
||||||
public val expr: MstExpression<T, A>,
|
public val algebra: A,
|
||||||
) : DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T> {
|
public val mst: MST,
|
||||||
|
) : DifferentiableExpression<T, Expression<T>> {
|
||||||
|
|
||||||
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
||||||
|
|
||||||
/**
|
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
|
||||||
* The [MstExpression.algebra] of [expr].
|
DifferentiableMstExpression(
|
||||||
*/
|
algebra,
|
||||||
public val algebra: A
|
symbols.map(Symbol::identity)
|
||||||
get() = expr.algebra
|
.map(MstAlgebra::bindSymbol)
|
||||||
|
.map { it.toSVar<KMathNumber<T, A>>() }
|
||||||
/**
|
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
||||||
* The [MstExpression.mst] of [expr].
|
.toMst(),
|
||||||
*/
|
)
|
||||||
public val mst: MST
|
|
||||||
get() = expr.mst
|
|
||||||
|
|
||||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
|
||||||
|
|
||||||
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
|
|
||||||
algebra,
|
|
||||||
symbols.map(Symbol::identity)
|
|
||||||
.map(MstAlgebra::bindSymbol)
|
|
||||||
.map { it.toSVar<KMathNumber<T, A>>() }
|
|
||||||
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
|
||||||
.toMst(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
* Wraps this [MST] into [DifferentiableMstExpression].
|
||||||
*/
|
*/
|
||||||
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
|
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
|
||||||
DifferentiableMstExpression(this)
|
DifferentiableMstExpression(algebra, this)
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
package space.kscience.kmath.kotlingrad
|
package space.kscience.kmath.kotlingrad
|
||||||
|
|
||||||
import edu.umontreal.kotlingrad.api.*
|
import edu.umontreal.kotlingrad.api.*
|
||||||
import space.kscience.kmath.asm.compile
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.ast.MstAlgebra
|
import space.kscience.kmath.ast.MstAlgebra
|
||||||
import space.kscience.kmath.ast.MstExpression
|
|
||||||
import space.kscience.kmath.ast.parseMath
|
import space.kscience.kmath.ast.parseMath
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
@ -43,8 +42,8 @@ internal class AdaptingTests {
|
|||||||
fun simpleFunctionDerivative() {
|
fun simpleFunctionDerivative() {
|
||||||
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
|
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
|
||||||
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
|
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
|
||||||
val actualDerivative = MstExpression(DoubleField, quadratic.d(x).toMst()).compile()
|
val actualDerivative = quadratic.d(x).toMst().compileToExpression(DoubleField)
|
||||||
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
|
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
|
||||||
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,12 +51,11 @@ internal class AdaptingTests {
|
|||||||
fun moreComplexDerivative() {
|
fun moreComplexDerivative() {
|
||||||
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
|
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
|
||||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
|
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
|
||||||
val actualDerivative = MstExpression(DoubleField, composition.d(x).toMst()).compile()
|
val actualDerivative = composition.d(x).toMst().compileToExpression(DoubleField)
|
||||||
|
|
||||||
|
val expectedDerivative =
|
||||||
|
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath().compileToExpression(DoubleField)
|
||||||
|
|
||||||
val expectedDerivative = MstExpression(
|
|
||||||
DoubleField,
|
|
||||||
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
|
|
||||||
).compile()
|
|
||||||
|
|
||||||
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
|
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user