diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt index b838245e1..d0c3a789e 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt @@ -7,7 +7,7 @@ package space.kscience.kmath.ast import space.kscience.kmath.complex.Complex import space.kscience.kmath.complex.ComplexField -import space.kscience.kmath.expressions.evaluate +import space.kscience.kmath.expressions.interpret import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.DoubleField import kotlin.test.Test @@ -17,14 +17,14 @@ internal class TestParser { @Test fun evaluateParsedMst() { val mst = "2+2*(2+2)".parseMath() - val res = ComplexField.evaluate(mst) + val res = mst.interpret(ComplexField) assertEquals(Complex(10.0, 0.0), res) } @Test fun evaluateMstSymbol() { val mst = "i".parseMath() - val res = ComplexField.evaluate(mst) + val res = mst.interpret(ComplexField) assertEquals(ComplexField.i, res) } @@ -32,7 +32,7 @@ internal class TestParser { @Test fun evaluateMstUnary() { val mst = "sin(0)".parseMath() - val res = DoubleField.evaluate(mst) + val res = mst.interpret(DoubleField) assertEquals(0.0, res) } @@ -53,7 +53,7 @@ internal class TestParser { } val mst = "magic(a, b)".parseMath() - val res = magicalAlgebra.evaluate(mst) + val res = mst.interpret(magicalAlgebra) assertEquals("a ★ b", res) } } diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt index bb6bb3ce1..42cf5ce58 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt @@ -5,35 +5,35 @@ package space.kscience.kmath.ast -import space.kscience.kmath.expressions.evaluate +import space.kscience.kmath.expressions.interpret import space.kscience.kmath.operations.DoubleField import kotlin.test.Test import kotlin.test.assertEquals internal class TestParserPrecedence { @Test - fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath())) + fun test1(): Unit = assertEquals(6.0, "2*2+2".parseMath().interpret(f)) @Test - fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath())) + fun test2(): Unit = assertEquals(6.0, "2+2*2".parseMath().interpret(f)) @Test - fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath())) + fun test3(): Unit = assertEquals(10.0, "2^3+2".parseMath().interpret(f)) @Test - fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath())) + fun test4(): Unit = assertEquals(10.0, "2+2^3".parseMath().interpret(f)) @Test - fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath())) + fun test5(): Unit = assertEquals(16.0, "2^3*2".parseMath().interpret(f)) @Test - fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath())) + fun test6(): Unit = assertEquals(16.0, "2*2^3".parseMath().interpret(f)) @Test - fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath())) + fun test7(): Unit = assertEquals(18.0, "2+2^3*2".parseMath().interpret(f)) @Test - fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath())) + fun test8(): Unit = assertEquals(18.0, "2*2^3+2".parseMath().interpret(f)) private companion object { private val f = DoubleField diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt index 7533024a1..24e96e845 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt @@ -7,7 +7,7 @@ package space.kscience.kmath.expressions import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.NumericAlgebra -import space.kscience.kmath.operations.bindSymbol +import space.kscience.kmath.operations.bindSymbolOrNull /** * A Mathematical Syntax Tree (MST) node for mathematical expressions. @@ -43,66 +43,50 @@ public sealed interface MST { // TODO add a function with named arguments -/** - * Interprets the [MST] node with this [Algebra]. - * - * @receiver the algebra that provides operations. - * @param node the node to evaluate. - * @return the value of expression. - * @author Alexander Nozik - */ -public fun Algebra.evaluate(node: MST): T = when (node) { - is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) - ?: error("Numeric nodes are not supported by $this") - - is Symbol -> bindSymbol(node) - - is MST.Unary -> when { - this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value)) - else -> unaryOperationFunction(node.operation)(evaluate(node.value)) - } - - is MST.Binary -> when { - this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> - binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value)) - - this is NumericAlgebra && node.left is MST.Numeric -> - leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right)) - - this is NumericAlgebra && node.right is MST.Numeric -> - rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value) - - else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) - } -} - -internal class InnerAlgebra(val algebra: Algebra, val arguments: Map) : NumericAlgebra { - override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)] - - override fun unaryOperation(operation: String, arg: T): T = - algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T = - algebra.binaryOperation(operation, left, right) - - override fun unaryOperationFunction(operation: String): (arg: T) -> T = - algebra.unaryOperationFunction(operation) - - override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = - algebra.binaryOperationFunction(operation) - - @Suppress("UNCHECKED_CAST") - override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) - (algebra as NumericAlgebra).number(value) - else - error("Numeric nodes are not supported by $this") -} /** * Interprets the [MST] node with this [Algebra] and optional [arguments] */ -public fun MST.interpret(algebra: Algebra, arguments: Map): T = - InnerAlgebra(algebra, arguments).evaluate(this) +public fun MST.interpret(algebra: Algebra, arguments: Map): T = when (this) { + is MST.Numeric -> (algebra as NumericAlgebra?)?.number(value) + ?: error("Numeric nodes are not supported by $algebra") + + is Symbol -> algebra.bindSymbolOrNull(this) ?: arguments.getValue(this) + + is MST.Unary -> when { + algebra is NumericAlgebra && this.value is MST.Numeric -> algebra.unaryOperation( + this.operation, + algebra.number(this.value.value), + ) + else -> algebra.unaryOperationFunction(this.operation)(this.value.interpret(algebra, arguments)) + } + + is MST.Binary -> when { + algebra is NumericAlgebra && this.left is MST.Numeric && this.right is MST.Numeric -> algebra.binaryOperation( + this.operation, + algebra.number(this.left.value), + algebra.number(this.right.value), + ) + + algebra is NumericAlgebra && this.left is MST.Numeric -> algebra.leftSideNumberOperation( + this.operation, + this.left.value, + this.right.interpret(algebra, arguments), + ) + + algebra is NumericAlgebra && this.right is MST.Numeric -> algebra.rightSideNumberOperation( + this.operation, + left.interpret(algebra, arguments), + right.value, + ) + + else -> algebra.binaryOperation( + this.operation, + this.left.interpret(algebra, arguments), + this.right.interpret(algebra, arguments), + ) + } +} /** * Interprets the [MST] node with this [Algebra] and optional [arguments] @@ -111,12 +95,17 @@ public fun MST.interpret(algebra: Algebra, arguments: Map): T * @param algebra the algebra that provides operations. * @return the value of expression. */ -public fun MST.interpret(algebra: Algebra, vararg arguments: Pair): T = - interpret(algebra, mapOf(*arguments)) +public fun MST.interpret(algebra: Algebra, vararg arguments: Pair): T = interpret( + algebra, + when (arguments.size) { + 0 -> emptyMap() + 1 -> mapOf(arguments[0]) + else -> hashMapOf(*arguments) + }, +) /** * Interpret this [MST] as expression. */ -public fun MST.toExpression(algebra: Algebra): Expression = Expression { arguments -> - interpret(algebra, arguments) -} +public fun MST.toExpression(algebra: Algebra): Expression = + Expression { arguments -> interpret(algebra, arguments) }