From bce2a8460e22e85b7a9d38f2c8d40da475e1983b Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 20 Jul 2020 20:40:03 +0700 Subject: [PATCH] Add reference ANTLR grammar, implement missing rules and operations in parser, also add support for symbols in ASM --- kmath-ast/reference/ArithmeticsEvaluator.g4 | 59 +++++++++++++++ .../kotlin/scientifik/kmath/ast/MST.kt | 38 +++++----- .../kotlin/scientifik/kmath/ast/parser.kt | 71 ++++++++++++------- .../kotlin/scientifik/kmath/asm/asm.kt | 15 +++- .../kmath/asm/internal/AsmBuilder.kt | 2 +- .../kotlin/scietifik/kmath/ast/ParserTest.kt | 35 +++++++++ 6 files changed, 173 insertions(+), 47 deletions(-) create mode 100644 kmath-ast/reference/ArithmeticsEvaluator.g4 diff --git a/kmath-ast/reference/ArithmeticsEvaluator.g4 b/kmath-ast/reference/ArithmeticsEvaluator.g4 new file mode 100644 index 000000000..684a9de44 --- /dev/null +++ b/kmath-ast/reference/ArithmeticsEvaluator.g4 @@ -0,0 +1,59 @@ +grammar ArithmeticsEvaluator; + +fragment DIGIT: '0'..'9'; +fragment LETTER: 'a'..'z'; +fragment UNDERSCORE: '_'; + +ID: (LETTER | UNDERSCORE) (LETTER | UNDERSCORE | DIGIT)*; +NUM: (DIGIT | '.')+ ([eE] MINUS? DIGIT+)?; +MUL: '*'; +DIV: '/'; +PLUS: '+'; +MINUS: '-'; +POW: '^'; +COMMA: ','; +LPAR: '('; +RPAR: ')'; +WS: [ \n\t\r]+ -> skip; + +number + : NUM + ; + +singular + : ID + ; + +unaryFunction + : ID LPAR subSumChain RPAR + ; + +binaryFunction + : ID LPAR subSumChain COMMA subSumChain RPAR + ; + +term + : number + | singular + | unaryFunction + | binaryFunction + | MINUS term + | LPAR subSumChain RPAR + ; + +powChain + : term (POW term)* + ; + + +divMulChain + : powChain ((PLUS | MINUS) powChain)* + ; + +subSumChain + : divMulChain ((DIV | MUL) divMulChain)* + ; + +rootParser + : subSumChain EOF + ; diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt index 142d27f93..e87df882f 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -41,27 +41,25 @@ sealed class MST { //TODO add a function with named arguments -fun Algebra.evaluate(node: MST): T { - return when (node) { - is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) - ?: error("Numeric nodes are not supported by $this") - is MST.Symbolic -> symbol(node.value) - is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) - is MST.Binary -> when { - this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) - node.left is MST.Numeric && node.right is MST.Numeric -> { - val number = RealField.binaryOperation( - node.operation, - node.left.value.toDouble(), - node.right.value.toDouble() - ) - number(number) - } - node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) - node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) - else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) +fun Algebra.evaluate(node: MST): T = when (node) { + is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) + ?: error("Numeric nodes are not supported by $this") + is MST.Symbolic -> symbol(node.value) + is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) + is MST.Binary -> when { + this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + node.left is MST.Numeric && node.right is MST.Numeric -> { + val number = RealField.binaryOperation( + node.operation, + node.left.value.toDouble(), + node.right.value.toDouble() + ) + number(number) } + node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) + node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) + else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) } } -fun MST.compile(algebra: Algebra): T = algebra.evaluate(this) \ No newline at end of file +fun MST.compile(algebra: Algebra): T = algebra.evaluate(this) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt index 30a92c5ae..ae0b6656d 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt @@ -5,6 +5,8 @@ import com.github.h0tk3y.betterParse.grammar.Grammar import com.github.h0tk3y.betterParse.grammar.parseToEnd import com.github.h0tk3y.betterParse.grammar.parser import com.github.h0tk3y.betterParse.grammar.tryParseToEnd +import com.github.h0tk3y.betterParse.lexer.Token +import com.github.h0tk3y.betterParse.lexer.TokenMatch import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.Parser import scientifik.kmath.operations.FieldOperations @@ -15,45 +17,66 @@ import scientifik.kmath.operations.SpaceOperations /** * TODO move to common */ -private object ArithmeticsEvaluator : Grammar() { - val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) - val lpar by token("\\(".toRegex()) - val rpar by token("\\)".toRegex()) - val mul by token("\\*".toRegex()) - val pow by token("\\^".toRegex()) - val div by token("/".toRegex()) - val minus by token("-".toRegex()) - val plus by token("\\+".toRegex()) - val ws by token("\\s+".toRegex(), ignore = true) +object ArithmeticsEvaluator : Grammar() { + private val num: Token by token("[\\d.]+(?:[eE]-?\\d+)?".toRegex()) + private val id: Token by token("[a-z_][\\da-z_]*".toRegex()) + private val lpar: Token by token("\\(".toRegex()) + private val rpar: Token by token("\\)".toRegex()) + private val comma: Token by token(",".toRegex()) + private val mul: Token by token("\\*".toRegex()) + private val pow: Token by token("\\^".toRegex()) + private val div: Token by token("/".toRegex()) + private val minus: Token by token("-".toRegex()) + private val plus: Token by token("\\+".toRegex()) + private val ws: Token by token("\\s+".toRegex(), ignore = true) - val number: Parser by num use { MST.Numeric(text.toDouble()) } + private val number: Parser by num use { MST.Numeric(text.toDouble()) } + private val singular: Parser by id use { MST.Symbolic(text) } - val term: Parser by number or - (skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or - (skip(lpar) and parser(this::rootParser) and skip(rpar)) + private val unaryFunction: Parser by (id and skip(lpar) and parser(::subSumChain) and skip(rpar)) + .map { (id, term) -> MST.Unary(id.text, term) } - val powChain by leftAssociative(term, pow) { a, _, b -> + private val binaryFunction: Parser by id + .and(skip(lpar)) + .and(parser(::subSumChain)) + .and(skip(comma)) + .and(parser(::subSumChain)) + .and(skip(rpar)) + .map { (id, left, right) -> MST.Binary(id.text, left, right) } + + private val term: Parser by number + .or(binaryFunction) + .or(unaryFunction) + .or(singular) + .or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) + .or(skip(lpar) and parser(::subSumChain) and skip(rpar)) + + private val powChain: Parser by leftAssociative(term = term, operator = pow) { a, _, b -> MST.Binary(PowerOperations.POW_OPERATION, a, b) } - val divMulChain: Parser by leftAssociative(powChain, div or mul use { type }) { a, op, b -> - if (op == div) { + private val divMulChain: Parser by leftAssociative( + term = powChain, + operator = div or mul use TokenMatch::type + ) { a, op, b -> + if (op == div) MST.Binary(FieldOperations.DIV_OPERATION, a, b) - } else { + else MST.Binary(RingOperations.TIMES_OPERATION, a, b) - } } - val subSumChain: Parser by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b -> - if (op == plus) { + private val subSumChain: Parser by leftAssociative( + term = divMulChain, + operator = plus or minus use TokenMatch::type + ) { a, op, b -> + if (op == plus) MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) - } else { + else MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) - } } override val rootParser: Parser by subSumChain } fun String.tryParseMath(): ParseResult = ArithmeticsEvaluator.tryParseToEnd(this) -fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) \ No newline at end of file +fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index f0022a43a..ee0ea15ff 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.NumericAlgebra import kotlin.reflect.KClass /** @@ -17,7 +16,19 @@ import kotlin.reflect.KClass fun MST.compileWith(type: KClass, algebra: Algebra): Expression { fun AsmBuilder.visit(node: MST) { when (node) { - is MST.Symbolic -> loadVariable(node.value) + is MST.Symbolic -> { + val symbol = try { + algebra.symbol(node.value) + } catch (ignored: Throwable) { + null + } + + if (symbol != null) + loadTConstant(symbol) + else + loadVariable(node.value) + } + is MST.Numeric -> loadNumeric(node.value) is MST.Unary -> buildAlgebraOperationCall( diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 5531fd5dc..a947b3478 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -288,7 +288,7 @@ internal class AsmBuilder internal constructor( /** * Loads a [T] constant from [constants]. */ - private fun loadTConstant(value: T) { + internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop() val mustBeBoxed = expectedType.sort == Type.OBJECT diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt index 5394a4b00..9179c3428 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.parseMath import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField +import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals @@ -22,4 +24,37 @@ internal class ParserTest { val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() assertEquals(Complex(10.0, 0.0), res) } + + @Test + fun `evaluate MST with singular`() { + val mst = "i".parseMath() + val res = ComplexField.evaluate(mst) + assertEquals(ComplexField.i, res) + } + + + @Test + fun `evaluate MST with unary function`() { + val mst = "sin(0)".parseMath() + val res = RealField.evaluate(mst) + assertEquals(0.0, res) + } + + @Test + fun `evaluate MST with binary function`() { + val magicalAlgebra = object : Algebra { + override fun symbol(value: String): String = value + + override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() + + override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) { + "magic" -> "$left ★ $right" + else -> throw NotImplementedError() + } + } + + val mst = "magic(a, b)".parseMath() + val res = magicalAlgebra.evaluate(mst) + assertEquals("a ★ b", res) + } }