Add reference ANTLR grammar, implement missing rules and operations in parser, also add support for symbols in ASM

This commit is contained in:
Iaroslav Postovalov 2020-07-20 20:40:03 +07:00
parent 3d85c22497
commit bce2a8460e
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
6 changed files with 173 additions and 47 deletions

View File

@ -0,0 +1,59 @@
grammar ArithmeticsEvaluator;
fragment DIGIT: '0'..'9';
fragment LETTER: 'a'..'z';
fragment UNDERSCORE: '_';
ID: (LETTER | UNDERSCORE) (LETTER | UNDERSCORE | DIGIT)*;
NUM: (DIGIT | '.')+ ([eE] MINUS? DIGIT+)?;
MUL: '*';
DIV: '/';
PLUS: '+';
MINUS: '-';
POW: '^';
COMMA: ',';
LPAR: '(';
RPAR: ')';
WS: [ \n\t\r]+ -> skip;
number
: NUM
;
singular
: ID
;
unaryFunction
: ID LPAR subSumChain RPAR
;
binaryFunction
: ID LPAR subSumChain COMMA subSumChain RPAR
;
term
: number
| singular
| unaryFunction
| binaryFunction
| MINUS term
| LPAR subSumChain RPAR
;
powChain
: term (POW term)*
;
divMulChain
: powChain ((PLUS | MINUS) powChain)*
;
subSumChain
: divMulChain ((DIV | MUL) divMulChain)*
;
rootParser
: subSumChain EOF
;

View File

@ -41,26 +41,24 @@ sealed class MST {
//TODO add a function with named arguments
fun <T> Algebra<T>.evaluate(node: MST): T {
return when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
node.left.value.toDouble(),
node.right.value.toDouble()
)
number(number)
}
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
node.left.value.toDouble(),
node.right.value.toDouble()
)
number(number)
}
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
}
}

View File

@ -5,6 +5,8 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.lexer.Token
import com.github.h0tk3y.betterParse.lexer.TokenMatch
import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations
@ -15,41 +17,62 @@ import scientifik.kmath.operations.SpaceOperations
/**
* TODO move to common
*/
private object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
val lpar by token("\\(".toRegex())
val rpar by token("\\)".toRegex())
val mul by token("\\*".toRegex())
val pow by token("\\^".toRegex())
val div by token("/".toRegex())
val minus by token("-".toRegex())
val plus by token("\\+".toRegex())
val ws by token("\\s+".toRegex(), ignore = true)
object ArithmeticsEvaluator : Grammar<MST>() {
private val num: Token by token("[\\d.]+(?:[eE]-?\\d+)?".toRegex())
private val id: Token by token("[a-z_][\\da-z_]*".toRegex())
private val lpar: Token by token("\\(".toRegex())
private val rpar: Token by token("\\)".toRegex())
private val comma: Token by token(",".toRegex())
private val mul: Token by token("\\*".toRegex())
private val pow: Token by token("\\^".toRegex())
private val div: Token by token("/".toRegex())
private val minus: Token by token("-".toRegex())
private val plus: Token by token("\\+".toRegex())
private val ws: Token by token("\\s+".toRegex(), ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
val term: Parser<MST> by number or
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
(skip(lpar) and parser(this::rootParser) and skip(rpar))
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
.map { (id, term) -> MST.Unary(id.text, term) }
val powChain by leftAssociative(term, pow) { a, _, b ->
private val binaryFunction: Parser<MST> by id
.and(skip(lpar))
.and(parser(::subSumChain))
.and(skip(comma))
.and(parser(::subSumChain))
.and(skip(rpar))
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
private val term: Parser<MST> by number
.or(binaryFunction)
.or(unaryFunction)
.or(singular)
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b)
}
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
if (op == div) {
private val divMulChain: Parser<MST> by leftAssociative(
term = powChain,
operator = div or mul use TokenMatch::type
) { a, op, b ->
if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else {
else
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
}
}
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
if (op == plus) {
private val subSumChain: Parser<MST> by leftAssociative(
term = divMulChain,
operator = plus or minus use TokenMatch::type
) { a, op, b ->
if (op == plus)
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else {
else
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
}
}
override val rootParser: Parser<MST> by subSumChain

View File

@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass
/**
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) {
when (node) {
is MST.Symbolic -> loadVariable(node.value)
is MST.Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: Throwable) {
null
}
if (symbol != null)
loadTConstant(symbol)
else
loadVariable(node.value)
}
is MST.Numeric -> loadNumeric(node.value)
is MST.Unary -> buildAlgebraOperationCall(

View File

@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
/**
* Loads a [T] constant from [constants].
*/
private fun loadTConstant(value: T) {
internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT

View File

@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
@ -22,4 +24,37 @@ internal class ParserTest {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res)
}
@Test
fun `evaluate MST with singular`() {
val mst = "i".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(ComplexField.i, res)
}
@Test
fun `evaluate MST with unary function`() {
val mst = "sin(0)".parseMath()
val res = RealField.evaluate(mst)
assertEquals(0.0, res)
}
@Test
fun `evaluate MST with binary function`() {
val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
}
}
val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst)
assertEquals("a ★ b", res)
}
}