Remove Algebra.evaluate(MST) by inlining in into interpret
This commit is contained in:
parent
53ab8334dd
commit
e094f6e8ee
@ -7,7 +7,7 @@ package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.complex.Complex
|
||||
import space.kscience.kmath.complex.ComplexField
|
||||
import space.kscience.kmath.expressions.evaluate
|
||||
import space.kscience.kmath.expressions.interpret
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import kotlin.test.Test
|
||||
@ -17,14 +17,14 @@ internal class TestParser {
|
||||
@Test
|
||||
fun evaluateParsedMst() {
|
||||
val mst = "2+2*(2+2)".parseMath()
|
||||
val res = ComplexField.evaluate(mst)
|
||||
val res = mst.interpret(ComplexField)
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun evaluateMstSymbol() {
|
||||
val mst = "i".parseMath()
|
||||
val res = ComplexField.evaluate(mst)
|
||||
val res = mst.interpret(ComplexField)
|
||||
assertEquals(ComplexField.i, res)
|
||||
}
|
||||
|
||||
@ -32,7 +32,7 @@ internal class TestParser {
|
||||
@Test
|
||||
fun evaluateMstUnary() {
|
||||
val mst = "sin(0)".parseMath()
|
||||
val res = DoubleField.evaluate(mst)
|
||||
val res = mst.interpret(DoubleField)
|
||||
assertEquals(0.0, res)
|
||||
}
|
||||
|
||||
@ -53,7 +53,7 @@ internal class TestParser {
|
||||
}
|
||||
|
||||
val mst = "magic(a, b)".parseMath()
|
||||
val res = magicalAlgebra.evaluate(mst)
|
||||
val res = mst.interpret(magicalAlgebra)
|
||||
assertEquals("a ★ b", res)
|
||||
}
|
||||
}
|
||||
|
@ -5,35 +5,35 @@
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.evaluate
|
||||
import space.kscience.kmath.expressions.interpret
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestParserPrecedence {
|
||||
@Test
|
||||
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
|
||||
fun test1(): Unit = assertEquals(6.0, "2*2+2".parseMath().interpret(f))
|
||||
|
||||
@Test
|
||||
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
|
||||
fun test2(): Unit = assertEquals(6.0, "2+2*2".parseMath().interpret(f))
|
||||
|
||||
@Test
|
||||
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
|
||||
fun test3(): Unit = assertEquals(10.0, "2^3+2".parseMath().interpret(f))
|
||||
|
||||
@Test
|
||||
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
|
||||
fun test4(): Unit = assertEquals(10.0, "2+2^3".parseMath().interpret(f))
|
||||
|
||||
@Test
|
||||
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
|
||||
fun test5(): Unit = assertEquals(16.0, "2^3*2".parseMath().interpret(f))
|
||||
|
||||
@Test
|
||||
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
|
||||
fun test6(): Unit = assertEquals(16.0, "2*2^3".parseMath().interpret(f))
|
||||
|
||||
@Test
|
||||
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
|
||||
fun test7(): Unit = assertEquals(18.0, "2+2^3*2".parseMath().interpret(f))
|
||||
|
||||
@Test
|
||||
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
|
||||
fun test8(): Unit = assertEquals(18.0, "2*2^3+2".parseMath().interpret(f))
|
||||
|
||||
private companion object {
|
||||
private val f = DoubleField
|
||||
|
@ -7,7 +7,7 @@ package space.kscience.kmath.expressions
|
||||
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
import space.kscience.kmath.operations.bindSymbol
|
||||
import space.kscience.kmath.operations.bindSymbolOrNull
|
||||
|
||||
/**
|
||||
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
||||
@ -43,66 +43,50 @@ public sealed interface MST {
|
||||
|
||||
// TODO add a function with named arguments
|
||||
|
||||
/**
|
||||
* Interprets the [MST] node with this [Algebra].
|
||||
*
|
||||
* @receiver the algebra that provides operations.
|
||||
* @param node the node to evaluate.
|
||||
* @return the value of expression.
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
|
||||
is Symbol -> bindSymbol(node)
|
||||
|
||||
is MST.Unary -> when {
|
||||
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
|
||||
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||
}
|
||||
|
||||
is MST.Binary -> when {
|
||||
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
|
||||
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
|
||||
|
||||
this is NumericAlgebra && node.left is MST.Numeric ->
|
||||
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||
|
||||
this is NumericAlgebra && node.right is MST.Numeric ->
|
||||
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||
|
||||
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
|
||||
internal class InnerAlgebra<T>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)]
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T =
|
||||
algebra.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
|
||||
algebra.unaryOperationFunction(operation)
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
|
||||
algebra.binaryOperationFunction(operation)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||
(algebra as NumericAlgebra<T>).number(value)
|
||||
else
|
||||
error("Numeric nodes are not supported by $this")
|
||||
}
|
||||
|
||||
/**
|
||||
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||
*/
|
||||
public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||
InnerAlgebra(algebra, arguments).evaluate(this)
|
||||
public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
|
||||
is MST.Numeric -> (algebra as NumericAlgebra<T>?)?.number(value)
|
||||
?: error("Numeric nodes are not supported by $algebra")
|
||||
|
||||
is Symbol -> algebra.bindSymbolOrNull(this) ?: arguments.getValue(this)
|
||||
|
||||
is MST.Unary -> when {
|
||||
algebra is NumericAlgebra && this.value is MST.Numeric -> algebra.unaryOperation(
|
||||
this.operation,
|
||||
algebra.number(this.value.value),
|
||||
)
|
||||
else -> algebra.unaryOperationFunction(this.operation)(this.value.interpret(algebra, arguments))
|
||||
}
|
||||
|
||||
is MST.Binary -> when {
|
||||
algebra is NumericAlgebra && this.left is MST.Numeric && this.right is MST.Numeric -> algebra.binaryOperation(
|
||||
this.operation,
|
||||
algebra.number(this.left.value),
|
||||
algebra.number(this.right.value),
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra && this.left is MST.Numeric -> algebra.leftSideNumberOperation(
|
||||
this.operation,
|
||||
this.left.value,
|
||||
this.right.interpret(algebra, arguments),
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra && this.right is MST.Numeric -> algebra.rightSideNumberOperation(
|
||||
this.operation,
|
||||
left.interpret(algebra, arguments),
|
||||
right.value,
|
||||
)
|
||||
|
||||
else -> algebra.binaryOperation(
|
||||
this.operation,
|
||||
this.left.interpret(algebra, arguments),
|
||||
this.right.interpret(algebra, arguments),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||
@ -111,12 +95,17 @@ public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T
|
||||
* @param algebra the algebra that provides operations.
|
||||
* @return the value of expression.
|
||||
*/
|
||||
public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
||||
interpret(algebra, mapOf(*arguments))
|
||||
public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
|
||||
algebra,
|
||||
when (arguments.size) {
|
||||
0 -> emptyMap()
|
||||
1 -> mapOf(arguments[0])
|
||||
else -> hashMapOf(*arguments)
|
||||
},
|
||||
)
|
||||
|
||||
/**
|
||||
* Interpret this [MST] as expression.
|
||||
*/
|
||||
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
|
||||
interpret(algebra, arguments)
|
||||
}
|
||||
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> =
|
||||
Expression { arguments -> interpret(algebra, arguments) }
|
||||
|
Loading…
Reference in New Issue
Block a user