Remove Algebra.evaluate(MST) by inlining in into interpret

This commit is contained in:
Iaroslav Postovalov 2022-02-08 11:43:50 +07:00
parent 53ab8334dd
commit e094f6e8ee
3 changed files with 65 additions and 76 deletions

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast
import space.kscience.kmath.complex.Complex import space.kscience.kmath.complex.Complex
import space.kscience.kmath.complex.ComplexField import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.expressions.evaluate import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test import kotlin.test.Test
@ -17,14 +17,14 @@ internal class TestParser {
@Test @Test
fun evaluateParsedMst() { fun evaluateParsedMst() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.evaluate(mst) val res = mst.interpret(ComplexField)
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
@Test @Test
fun evaluateMstSymbol() { fun evaluateMstSymbol() {
val mst = "i".parseMath() val mst = "i".parseMath()
val res = ComplexField.evaluate(mst) val res = mst.interpret(ComplexField)
assertEquals(ComplexField.i, res) assertEquals(ComplexField.i, res)
} }
@ -32,7 +32,7 @@ internal class TestParser {
@Test @Test
fun evaluateMstUnary() { fun evaluateMstUnary() {
val mst = "sin(0)".parseMath() val mst = "sin(0)".parseMath()
val res = DoubleField.evaluate(mst) val res = mst.interpret(DoubleField)
assertEquals(0.0, res) assertEquals(0.0, res)
} }
@ -53,7 +53,7 @@ internal class TestParser {
} }
val mst = "magic(a, b)".parseMath() val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst) val res = mst.interpret(magicalAlgebra)
assertEquals("a ★ b", res) assertEquals("a ★ b", res)
} }
} }

View File

@ -5,35 +5,35 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.expressions.evaluate import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestParserPrecedence { internal class TestParserPrecedence {
@Test @Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath())) fun test1(): Unit = assertEquals(6.0, "2*2+2".parseMath().interpret(f))
@Test @Test
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath())) fun test2(): Unit = assertEquals(6.0, "2+2*2".parseMath().interpret(f))
@Test @Test
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath())) fun test3(): Unit = assertEquals(10.0, "2^3+2".parseMath().interpret(f))
@Test @Test
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath())) fun test4(): Unit = assertEquals(10.0, "2+2^3".parseMath().interpret(f))
@Test @Test
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath())) fun test5(): Unit = assertEquals(16.0, "2^3*2".parseMath().interpret(f))
@Test @Test
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath())) fun test6(): Unit = assertEquals(16.0, "2*2^3".parseMath().interpret(f))
@Test @Test
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath())) fun test7(): Unit = assertEquals(18.0, "2+2^3*2".parseMath().interpret(f))
@Test @Test
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath())) fun test8(): Unit = assertEquals(18.0, "2*2^3+2".parseMath().interpret(f))
private companion object { private companion object {
private val f = DoubleField private val f = DoubleField

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.expressions
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbolOrNull
/** /**
* A Mathematical Syntax Tree (MST) node for mathematical expressions. * A Mathematical Syntax Tree (MST) node for mathematical expressions.
@ -43,66 +43,50 @@ public sealed interface MST {
// TODO add a function with named arguments // TODO add a function with named arguments
/**
* Interprets the [MST] node with this [Algebra].
*
* @receiver the algebra that provides operations.
* @param node the node to evaluate.
* @return the value of expression.
* @author Alexander Nozik
*/
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is Symbol -> bindSymbol(node)
is MST.Unary -> when {
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
}
is MST.Binary -> when {
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
this is NumericAlgebra && node.left is MST.Numeric ->
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
this is NumericAlgebra && node.right is MST.Numeric ->
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
}
}
internal class InnerAlgebra<T>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)]
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
/** /**
* Interprets the [MST] node with this [Algebra] and optional [arguments] * Interprets the [MST] node with this [Algebra] and optional [arguments]
*/ */
public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
InnerAlgebra(algebra, arguments).evaluate(this) is MST.Numeric -> (algebra as NumericAlgebra<T>?)?.number(value)
?: error("Numeric nodes are not supported by $algebra")
is Symbol -> algebra.bindSymbolOrNull(this) ?: arguments.getValue(this)
is MST.Unary -> when {
algebra is NumericAlgebra && this.value is MST.Numeric -> algebra.unaryOperation(
this.operation,
algebra.number(this.value.value),
)
else -> algebra.unaryOperationFunction(this.operation)(this.value.interpret(algebra, arguments))
}
is MST.Binary -> when {
algebra is NumericAlgebra && this.left is MST.Numeric && this.right is MST.Numeric -> algebra.binaryOperation(
this.operation,
algebra.number(this.left.value),
algebra.number(this.right.value),
)
algebra is NumericAlgebra && this.left is MST.Numeric -> algebra.leftSideNumberOperation(
this.operation,
this.left.value,
this.right.interpret(algebra, arguments),
)
algebra is NumericAlgebra && this.right is MST.Numeric -> algebra.rightSideNumberOperation(
this.operation,
left.interpret(algebra, arguments),
right.value,
)
else -> algebra.binaryOperation(
this.operation,
this.left.interpret(algebra, arguments),
this.right.interpret(algebra, arguments),
)
}
}
/** /**
* Interprets the [MST] node with this [Algebra] and optional [arguments] * Interprets the [MST] node with this [Algebra] and optional [arguments]
@ -111,12 +95,17 @@ public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T
* @param algebra the algebra that provides operations. * @param algebra the algebra that provides operations.
* @return the value of expression. * @return the value of expression.
*/ */
public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
interpret(algebra, mapOf(*arguments)) algebra,
when (arguments.size) {
0 -> emptyMap()
1 -> mapOf(arguments[0])
else -> hashMapOf(*arguments)
},
)
/** /**
* Interpret this [MST] as expression. * Interpret this [MST] as expression.
*/ */
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments -> public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> =
interpret(algebra, arguments) Expression { arguments -> interpret(algebra, arguments) }
}