Remove Algebra.evaluate(MST) by inlining in into interpret

This commit is contained in:
Iaroslav Postovalov 2022-02-08 11:43:50 +07:00
parent 53ab8334dd
commit e094f6e8ee
3 changed files with 65 additions and 76 deletions

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast
import space.kscience.kmath.complex.Complex
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.expressions.evaluate
import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test
@ -17,14 +17,14 @@ internal class TestParser {
@Test
fun evaluateParsedMst() {
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.evaluate(mst)
val res = mst.interpret(ComplexField)
assertEquals(Complex(10.0, 0.0), res)
}
@Test
fun evaluateMstSymbol() {
val mst = "i".parseMath()
val res = ComplexField.evaluate(mst)
val res = mst.interpret(ComplexField)
assertEquals(ComplexField.i, res)
}
@ -32,7 +32,7 @@ internal class TestParser {
@Test
fun evaluateMstUnary() {
val mst = "sin(0)".parseMath()
val res = DoubleField.evaluate(mst)
val res = mst.interpret(DoubleField)
assertEquals(0.0, res)
}
@ -53,7 +53,7 @@ internal class TestParser {
}
val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst)
val res = mst.interpret(magicalAlgebra)
assertEquals("a ★ b", res)
}
}

View File

@ -5,35 +5,35 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.evaluate
import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestParserPrecedence {
@Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
fun test1(): Unit = assertEquals(6.0, "2*2+2".parseMath().interpret(f))
@Test
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
fun test2(): Unit = assertEquals(6.0, "2+2*2".parseMath().interpret(f))
@Test
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
fun test3(): Unit = assertEquals(10.0, "2^3+2".parseMath().interpret(f))
@Test
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
fun test4(): Unit = assertEquals(10.0, "2+2^3".parseMath().interpret(f))
@Test
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
fun test5(): Unit = assertEquals(16.0, "2^3*2".parseMath().interpret(f))
@Test
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
fun test6(): Unit = assertEquals(16.0, "2*2^3".parseMath().interpret(f))
@Test
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
fun test7(): Unit = assertEquals(18.0, "2+2^3*2".parseMath().interpret(f))
@Test
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
fun test8(): Unit = assertEquals(18.0, "2*2^3+2".parseMath().interpret(f))
private companion object {
private val f = DoubleField

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.expressions
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.bindSymbolOrNull
/**
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
@ -43,66 +43,50 @@ public sealed interface MST {
// TODO add a function with named arguments
/**
* Interprets the [MST] node with this [Algebra].
*
* @receiver the algebra that provides operations.
* @param node the node to evaluate.
* @return the value of expression.
* @author Alexander Nozik
*/
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is Symbol -> bindSymbol(node)
is MST.Unary -> when {
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
}
is MST.Binary -> when {
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
this is NumericAlgebra && node.left is MST.Numeric ->
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
this is NumericAlgebra && node.right is MST.Numeric ->
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
}
}
internal class InnerAlgebra<T>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)]
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
/**
* Interprets the [MST] node with this [Algebra] and optional [arguments]
*/
public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
InnerAlgebra(algebra, arguments).evaluate(this)
public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
is MST.Numeric -> (algebra as NumericAlgebra<T>?)?.number(value)
?: error("Numeric nodes are not supported by $algebra")
is Symbol -> algebra.bindSymbolOrNull(this) ?: arguments.getValue(this)
is MST.Unary -> when {
algebra is NumericAlgebra && this.value is MST.Numeric -> algebra.unaryOperation(
this.operation,
algebra.number(this.value.value),
)
else -> algebra.unaryOperationFunction(this.operation)(this.value.interpret(algebra, arguments))
}
is MST.Binary -> when {
algebra is NumericAlgebra && this.left is MST.Numeric && this.right is MST.Numeric -> algebra.binaryOperation(
this.operation,
algebra.number(this.left.value),
algebra.number(this.right.value),
)
algebra is NumericAlgebra && this.left is MST.Numeric -> algebra.leftSideNumberOperation(
this.operation,
this.left.value,
this.right.interpret(algebra, arguments),
)
algebra is NumericAlgebra && this.right is MST.Numeric -> algebra.rightSideNumberOperation(
this.operation,
left.interpret(algebra, arguments),
right.value,
)
else -> algebra.binaryOperation(
this.operation,
this.left.interpret(algebra, arguments),
this.right.interpret(algebra, arguments),
)
}
}
/**
* Interprets the [MST] node with this [Algebra] and optional [arguments]
@ -111,12 +95,17 @@ public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T
* @param algebra the algebra that provides operations.
* @return the value of expression.
*/
public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
interpret(algebra, mapOf(*arguments))
public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
algebra,
when (arguments.size) {
0 -> emptyMap()
1 -> mapOf(arguments[0])
else -> hashMapOf(*arguments)
},
)
/**
* Interpret this [MST] as expression.
*/
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
interpret(algebra, arguments)
}
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> =
Expression { arguments -> interpret(algebra, arguments) }