Add reference ANTLR grammar, implement missing rules and operations in parser, also add support for symbols in ASM

This commit is contained in:
Iaroslav Postovalov 2020-07-20 20:40:03 +07:00
parent 3d85c22497
commit bce2a8460e
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
6 changed files with 173 additions and 47 deletions

View File

@ -0,0 +1,59 @@
grammar ArithmeticsEvaluator;
fragment DIGIT: '0'..'9';
fragment LETTER: 'a'..'z';
fragment UNDERSCORE: '_';
ID: (LETTER | UNDERSCORE) (LETTER | UNDERSCORE | DIGIT)*;
NUM: (DIGIT | '.')+ ([eE] MINUS? DIGIT+)?;
MUL: '*';
DIV: '/';
PLUS: '+';
MINUS: '-';
POW: '^';
COMMA: ',';
LPAR: '(';
RPAR: ')';
WS: [ \n\t\r]+ -> skip;
number
: NUM
;
singular
: ID
;
unaryFunction
: ID LPAR subSumChain RPAR
;
binaryFunction
: ID LPAR subSumChain COMMA subSumChain RPAR
;
term
: number
| singular
| unaryFunction
| binaryFunction
| MINUS term
| LPAR subSumChain RPAR
;
powChain
: term (POW term)*
;
divMulChain
: powChain ((PLUS | MINUS) powChain)*
;
subSumChain
: divMulChain ((DIV | MUL) divMulChain)*
;
rootParser
: subSumChain EOF
;

View File

@ -41,8 +41,7 @@ sealed class MST {
//TODO add a function with named arguments //TODO add a function with named arguments
fun <T> Algebra<T>.evaluate(node: MST): T { fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
return when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
@ -61,7 +60,6 @@ fun <T> Algebra<T>.evaluate(node: MST): T {
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
} }
}
} }
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this) fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -5,6 +5,8 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.lexer.Token
import com.github.h0tk3y.betterParse.lexer.TokenMatch
import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations import scientifik.kmath.operations.FieldOperations
@ -15,42 +17,63 @@ import scientifik.kmath.operations.SpaceOperations
/** /**
* TODO move to common * TODO move to common
*/ */
private object ArithmeticsEvaluator : Grammar<MST>() { object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) private val num: Token by token("[\\d.]+(?:[eE]-?\\d+)?".toRegex())
val lpar by token("\\(".toRegex()) private val id: Token by token("[a-z_][\\da-z_]*".toRegex())
val rpar by token("\\)".toRegex()) private val lpar: Token by token("\\(".toRegex())
val mul by token("\\*".toRegex()) private val rpar: Token by token("\\)".toRegex())
val pow by token("\\^".toRegex()) private val comma: Token by token(",".toRegex())
val div by token("/".toRegex()) private val mul: Token by token("\\*".toRegex())
val minus by token("-".toRegex()) private val pow: Token by token("\\^".toRegex())
val plus by token("\\+".toRegex()) private val div: Token by token("/".toRegex())
val ws by token("\\s+".toRegex(), ignore = true) private val minus: Token by token("-".toRegex())
private val plus: Token by token("\\+".toRegex())
private val ws: Token by token("\\s+".toRegex(), ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) } private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
val term: Parser<MST> by number or private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or .map { (id, term) -> MST.Unary(id.text, term) }
(skip(lpar) and parser(this::rootParser) and skip(rpar))
val powChain by leftAssociative(term, pow) { a, _, b -> private val binaryFunction: Parser<MST> by id
.and(skip(lpar))
.and(parser(::subSumChain))
.and(skip(comma))
.and(parser(::subSumChain))
.and(skip(rpar))
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
private val term: Parser<MST> by number
.or(binaryFunction)
.or(unaryFunction)
.or(singular)
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b) MST.Binary(PowerOperations.POW_OPERATION, a, b)
} }
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b -> private val divMulChain: Parser<MST> by leftAssociative(
if (op == div) { term = powChain,
operator = div or mul use TokenMatch::type
) { a, op, b ->
if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b) MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else { else
MST.Binary(RingOperations.TIMES_OPERATION, a, b) MST.Binary(RingOperations.TIMES_OPERATION, a, b)
} }
}
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b -> private val subSumChain: Parser<MST> by leftAssociative(
if (op == plus) { term = divMulChain,
operator = plus or minus use TokenMatch::type
) { a, op, b ->
if (op == plus)
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else { else
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
} }
}
override val rootParser: Parser<MST> by subSumChain override val rootParser: Parser<MST> by subSumChain
} }

View File

@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MstExpression import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> { fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: Throwable) {
null
}
if (symbol != null)
loadTConstant(symbol)
else
loadVariable(node.value)
}
is MST.Numeric -> loadNumeric(node.value) is MST.Numeric -> loadNumeric(node.value)
is MST.Unary -> buildAlgebraOperationCall( is MST.Unary -> buildAlgebraOperationCall(

View File

@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Loads a [T] constant from [constants]. * Loads a [T] constant from [constants].
*/ */
private fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop() val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT val mustBeBoxed = expectedType.sort == Type.OBJECT

View File

@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -22,4 +24,37 @@ internal class ParserTest {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
@Test
fun `evaluate MST with singular`() {
val mst = "i".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(ComplexField.i, res)
}
@Test
fun `evaluate MST with unary function`() {
val mst = "sin(0)".parseMath()
val res = RealField.evaluate(mst)
assertEquals(0.0, res)
}
@Test
fun `evaluate MST with binary function`() {
val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
}
}
val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst)
assertEquals("a ★ b", res)
}
} }