forked from kscience/kmath
Add reference ANTLR grammar, implement missing rules and operations in parser, also add support for symbols in ASM
This commit is contained in:
parent
3d85c22497
commit
bce2a8460e
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
grammar ArithmeticsEvaluator;
|
||||||
|
|
||||||
|
fragment DIGIT: '0'..'9';
|
||||||
|
fragment LETTER: 'a'..'z';
|
||||||
|
fragment UNDERSCORE: '_';
|
||||||
|
|
||||||
|
ID: (LETTER | UNDERSCORE) (LETTER | UNDERSCORE | DIGIT)*;
|
||||||
|
NUM: (DIGIT | '.')+ ([eE] MINUS? DIGIT+)?;
|
||||||
|
MUL: '*';
|
||||||
|
DIV: '/';
|
||||||
|
PLUS: '+';
|
||||||
|
MINUS: '-';
|
||||||
|
POW: '^';
|
||||||
|
COMMA: ',';
|
||||||
|
LPAR: '(';
|
||||||
|
RPAR: ')';
|
||||||
|
WS: [ \n\t\r]+ -> skip;
|
||||||
|
|
||||||
|
number
|
||||||
|
: NUM
|
||||||
|
;
|
||||||
|
|
||||||
|
singular
|
||||||
|
: ID
|
||||||
|
;
|
||||||
|
|
||||||
|
unaryFunction
|
||||||
|
: ID LPAR subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
binaryFunction
|
||||||
|
: ID LPAR subSumChain COMMA subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
term
|
||||||
|
: number
|
||||||
|
| singular
|
||||||
|
| unaryFunction
|
||||||
|
| binaryFunction
|
||||||
|
| MINUS term
|
||||||
|
| LPAR subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
powChain
|
||||||
|
: term (POW term)*
|
||||||
|
;
|
||||||
|
|
||||||
|
|
||||||
|
divMulChain
|
||||||
|
: powChain ((PLUS | MINUS) powChain)*
|
||||||
|
;
|
||||||
|
|
||||||
|
subSumChain
|
||||||
|
: divMulChain ((DIV | MUL) divMulChain)*
|
||||||
|
;
|
||||||
|
|
||||||
|
rootParser
|
||||||
|
: subSumChain EOF
|
||||||
|
;
|
@ -41,26 +41,24 @@ sealed class MST {
|
|||||||
|
|
||||||
//TODO add a function with named arguments
|
//TODO add a function with named arguments
|
||||||
|
|
||||||
fun <T> Algebra<T>.evaluate(node: MST): T {
|
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||||
return when (node) {
|
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
?: error("Numeric nodes are not supported by $this")
|
||||||
?: error("Numeric nodes are not supported by $this")
|
is MST.Symbolic -> symbol(node.value)
|
||||||
is MST.Symbolic -> symbol(node.value)
|
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
is MST.Binary -> when {
|
||||||
is MST.Binary -> when {
|
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
val number = RealField.binaryOperation(
|
||||||
val number = RealField.binaryOperation(
|
node.operation,
|
||||||
node.operation,
|
node.left.value.toDouble(),
|
||||||
node.left.value.toDouble(),
|
node.right.value.toDouble()
|
||||||
node.right.value.toDouble()
|
)
|
||||||
)
|
number(number)
|
||||||
number(number)
|
|
||||||
}
|
|
||||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
|
||||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
|
||||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
|
||||||
}
|
}
|
||||||
|
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||||
|
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||||
|
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,8 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
|
|||||||
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
||||||
import com.github.h0tk3y.betterParse.grammar.parser
|
import com.github.h0tk3y.betterParse.grammar.parser
|
||||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.Token
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||||
import com.github.h0tk3y.betterParse.parser.Parser
|
import com.github.h0tk3y.betterParse.parser.Parser
|
||||||
import scientifik.kmath.operations.FieldOperations
|
import scientifik.kmath.operations.FieldOperations
|
||||||
@ -15,41 +17,62 @@ import scientifik.kmath.operations.SpaceOperations
|
|||||||
/**
|
/**
|
||||||
* TODO move to common
|
* TODO move to common
|
||||||
*/
|
*/
|
||||||
private object ArithmeticsEvaluator : Grammar<MST>() {
|
object ArithmeticsEvaluator : Grammar<MST>() {
|
||||||
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
private val num: Token by token("[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
||||||
val lpar by token("\\(".toRegex())
|
private val id: Token by token("[a-z_][\\da-z_]*".toRegex())
|
||||||
val rpar by token("\\)".toRegex())
|
private val lpar: Token by token("\\(".toRegex())
|
||||||
val mul by token("\\*".toRegex())
|
private val rpar: Token by token("\\)".toRegex())
|
||||||
val pow by token("\\^".toRegex())
|
private val comma: Token by token(",".toRegex())
|
||||||
val div by token("/".toRegex())
|
private val mul: Token by token("\\*".toRegex())
|
||||||
val minus by token("-".toRegex())
|
private val pow: Token by token("\\^".toRegex())
|
||||||
val plus by token("\\+".toRegex())
|
private val div: Token by token("/".toRegex())
|
||||||
val ws by token("\\s+".toRegex(), ignore = true)
|
private val minus: Token by token("-".toRegex())
|
||||||
|
private val plus: Token by token("\\+".toRegex())
|
||||||
|
private val ws: Token by token("\\s+".toRegex(), ignore = true)
|
||||||
|
|
||||||
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||||
|
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
|
||||||
|
|
||||||
val term: Parser<MST> by number or
|
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
|
.map { (id, term) -> MST.Unary(id.text, term) }
|
||||||
(skip(lpar) and parser(this::rootParser) and skip(rpar))
|
|
||||||
|
|
||||||
val powChain by leftAssociative(term, pow) { a, _, b ->
|
private val binaryFunction: Parser<MST> by id
|
||||||
|
.and(skip(lpar))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(comma))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(rpar))
|
||||||
|
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
|
||||||
|
|
||||||
|
private val term: Parser<MST> by number
|
||||||
|
.or(binaryFunction)
|
||||||
|
.or(unaryFunction)
|
||||||
|
.or(singular)
|
||||||
|
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
|
||||||
|
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
|
|
||||||
|
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||||
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
|
private val divMulChain: Parser<MST> by leftAssociative(
|
||||||
if (op == div) {
|
term = powChain,
|
||||||
|
operator = div or mul use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == div)
|
||||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||||
} else {
|
else
|
||||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
|
private val subSumChain: Parser<MST> by leftAssociative(
|
||||||
if (op == plus) {
|
term = divMulChain,
|
||||||
|
operator = plus or minus use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == plus)
|
||||||
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
} else {
|
else
|
||||||
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override val rootParser: Parser<MST> by subSumChain
|
override val rootParser: Parser<MST> by subSumChain
|
||||||
|
@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
|
|||||||
import scientifik.kmath.ast.MstExpression
|
import scientifik.kmath.ast.MstExpression
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
|
|||||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST) {
|
fun AsmBuilder<T>.visit(node: MST) {
|
||||||
when (node) {
|
when (node) {
|
||||||
is MST.Symbolic -> loadVariable(node.value)
|
is MST.Symbolic -> {
|
||||||
|
val symbol = try {
|
||||||
|
algebra.symbol(node.value)
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
|
||||||
|
if (symbol != null)
|
||||||
|
loadTConstant(symbol)
|
||||||
|
else
|
||||||
|
loadVariable(node.value)
|
||||||
|
}
|
||||||
|
|
||||||
is MST.Numeric -> loadNumeric(node.value)
|
is MST.Numeric -> loadNumeric(node.value)
|
||||||
|
|
||||||
is MST.Unary -> buildAlgebraOperationCall(
|
is MST.Unary -> buildAlgebraOperationCall(
|
||||||
|
@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Loads a [T] constant from [constants].
|
* Loads a [T] constant from [constants].
|
||||||
*/
|
*/
|
||||||
private fun loadTConstant(value: T) {
|
internal fun loadTConstant(value: T) {
|
||||||
if (classOfT in INLINABLE_NUMBERS) {
|
if (classOfT in INLINABLE_NUMBERS) {
|
||||||
val expectedType = expectationStack.pop()
|
val expectedType = expectationStack.pop()
|
||||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||||
|
@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
|
|||||||
import scientifik.kmath.ast.mstInField
|
import scientifik.kmath.ast.mstInField
|
||||||
import scientifik.kmath.ast.parseMath
|
import scientifik.kmath.ast.parseMath
|
||||||
import scientifik.kmath.expressions.invoke
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -22,4 +24,37 @@ internal class ParserTest {
|
|||||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with singular`() {
|
||||||
|
val mst = "i".parseMath()
|
||||||
|
val res = ComplexField.evaluate(mst)
|
||||||
|
assertEquals(ComplexField.i, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with unary function`() {
|
||||||
|
val mst = "sin(0)".parseMath()
|
||||||
|
val res = RealField.evaluate(mst)
|
||||||
|
assertEquals(0.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with binary function`() {
|
||||||
|
val magicalAlgebra = object : Algebra<String> {
|
||||||
|
override fun symbol(value: String): String = value
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||||
|
"magic" -> "$left ★ $right"
|
||||||
|
else -> throw NotImplementedError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val mst = "magic(a, b)".parseMath()
|
||||||
|
val res = magicalAlgebra.evaluate(mst)
|
||||||
|
assertEquals("a ★ b", res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user