Extend mathematic operations support in the kmath-ast parser #120
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
@ -0,0 +1,59 @@
|
||||
|
||||
grammar ArithmeticsEvaluator;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
fragment DIGIT: '0'..'9';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
fragment LETTER: 'a'..'z';
|
||||
What about uppercase letters? What about uppercase letters?
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
Added. Added.
|
||||
fragment UNDERSCORE: '_';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
ID: (LETTER | UNDERSCORE) (LETTER | UNDERSCORE | DIGIT)*;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
NUM: (DIGIT | '.')+ ([eE] MINUS? DIGIT+)?;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
MUL: '*';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
DIV: '/';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
PLUS: '+';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
MINUS: '-';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
POW: '^';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
COMMA: ',';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
LPAR: '(';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
RPAR: ')';
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
WS: [ \n\t\r]+ -> skip;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
number
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: NUM
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
singular
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: ID
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
unaryFunction
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: ID LPAR subSumChain RPAR
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
binaryFunction
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: ID LPAR subSumChain COMMA subSumChain RPAR
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
term
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: number
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
| singular
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
| unaryFunction
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
| binaryFunction
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
| MINUS term
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
| LPAR subSumChain RPAR
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
powChain
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: term (POW term)*
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
divMulChain
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: powChain ((PLUS | MINUS) powChain)*
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
subSumChain
|
||||
`1.23e+3` generally + in exponent is generally allowed
Is operator precedence right? Looks suspicious. It would be nice to add test for it Is operator precedence right? Looks suspicious. It would be nice to add test for it
Tests were added Tests were added
Fixed Fixed
|
||||
: divMulChain ((DIV | MUL) divMulChain)*
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
rootParser
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
: subSumChain EOF
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
||||
;
|
||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
@ -41,27 +41,25 @@ sealed class MST {
|
||||
|
||||
//TODO add a function with named arguments
|
||||
|
||||
fun <T> Algebra<T>.evaluate(node: MST): T {
|
||||
return when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
node.left.value.toDouble(),
|
||||
node.right.value.toDouble()
|
||||
)
|
||||
number(number)
|
||||
}
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
node.left.value.toDouble(),
|
||||
node.right.value.toDouble()
|
||||
)
|
||||
number(number)
|
||||
}
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
||||
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
||||
|
@ -5,6 +5,8 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
|
||||
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
||||
import com.github.h0tk3y.betterParse.grammar.parser
|
||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||
import com.github.h0tk3y.betterParse.lexer.Token
|
||||
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||
import com.github.h0tk3y.betterParse.parser.Parser
|
||||
import scientifik.kmath.operations.FieldOperations
|
||||
@ -15,45 +17,66 @@ import scientifik.kmath.operations.SpaceOperations
|
||||
/**
|
||||
* TODO move to common
|
||||
*/
|
||||
private object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
||||
val lpar by token("\\(".toRegex())
|
||||
val rpar by token("\\)".toRegex())
|
||||
val mul by token("\\*".toRegex())
|
||||
val pow by token("\\^".toRegex())
|
||||
val div by token("/".toRegex())
|
||||
val minus by token("-".toRegex())
|
||||
val plus by token("\\+".toRegex())
|
||||
val ws by token("\\s+".toRegex(), ignore = true)
|
||||
object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
private val num: Token by token("[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
||||
private val id: Token by token("[a-z_][\\da-z_]*".toRegex())
|
||||
private val lpar: Token by token("\\(".toRegex())
|
||||
private val rpar: Token by token("\\)".toRegex())
|
||||
private val comma: Token by token(",".toRegex())
|
||||
private val mul: Token by token("\\*".toRegex())
|
||||
private val pow: Token by token("\\^".toRegex())
|
||||
private val div: Token by token("/".toRegex())
|
||||
private val minus: Token by token("-".toRegex())
|
||||
private val plus: Token by token("\\+".toRegex())
|
||||
private val ws: Token by token("\\s+".toRegex(), ignore = true)
|
||||
|
||||
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
|
||||
|
||||
val term: Parser<MST> by number or
|
||||
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
|
||||
(skip(lpar) and parser(this::rootParser) and skip(rpar))
|
||||
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||
.map { (id, term) -> MST.Unary(id.text, term) }
|
||||
|
||||
val powChain by leftAssociative(term, pow) { a, _, b ->
|
||||
private val binaryFunction: Parser<MST> by id
|
||||
.and(skip(lpar))
|
||||
.and(parser(::subSumChain))
|
||||
.and(skip(comma))
|
||||
.and(parser(::subSumChain))
|
||||
.and(skip(rpar))
|
||||
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
|
||||
|
||||
private val term: Parser<MST> by number
|
||||
.or(binaryFunction)
|
||||
.or(unaryFunction)
|
||||
.or(singular)
|
||||
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
|
||||
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||
|
||||
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
||||
}
|
||||
|
||||
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
|
||||
if (op == div) {
|
||||
private val divMulChain: Parser<MST> by leftAssociative(
|
||||
term = powChain,
|
||||
operator = div or mul use TokenMatch::type
|
||||
) { a, op, b ->
|
||||
if (op == div)
|
||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||
} else {
|
||||
else
|
||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
|
||||
if (op == plus) {
|
||||
private val subSumChain: Parser<MST> by leftAssociative(
|
||||
term = divMulChain,
|
||||
operator = plus or minus use TokenMatch::type
|
||||
) { a, op, b ->
|
||||
if (op == plus)
|
||||
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
} else {
|
||||
else
|
||||
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
||||
}
|
||||
}
|
||||
|
||||
override val rootParser: Parser<MST> by subSumChain
|
||||
}
|
||||
|
||||
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
|
||||
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)
|
||||
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)
|
||||
|
@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.ast.MstExpression
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.NumericAlgebra
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
|
||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||
fun AsmBuilder<T>.visit(node: MST) {
|
||||
when (node) {
|
||||
is MST.Symbolic -> loadVariable(node.value)
|
||||
is MST.Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: Throwable) {
|
||||
null
|
||||
}
|
||||
|
||||
if (symbol != null)
|
||||
loadTConstant(symbol)
|
||||
else
|
||||
loadVariable(node.value)
|
||||
}
|
||||
|
||||
is MST.Numeric -> loadNumeric(node.value)
|
||||
|
||||
is MST.Unary -> buildAlgebraOperationCall(
|
||||
|
@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
/**
|
||||
* Loads a [T] constant from [constants].
|
||||
*/
|
||||
private fun loadTConstant(value: T) {
|
||||
internal fun loadTConstant(value: T) {
|
||||
if (classOfT in INLINABLE_NUMBERS) {
|
||||
val expectedType = expectationStack.pop()
|
||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||
|
@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
|
||||
import scientifik.kmath.ast.mstInField
|
||||
import scientifik.kmath.ast.parseMath
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.Complex
|
||||
import scientifik.kmath.operations.ComplexField
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@ -22,4 +24,37 @@ internal class ParserTest {
|
||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `evaluate MST with singular`() {
|
||||
val mst = "i".parseMath()
|
||||
val res = ComplexField.evaluate(mst)
|
||||
assertEquals(ComplexField.i, res)
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `evaluate MST with unary function`() {
|
||||
val mst = "sin(0)".parseMath()
|
||||
val res = RealField.evaluate(mst)
|
||||
assertEquals(0.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `evaluate MST with binary function`() {
|
||||
val magicalAlgebra = object : Algebra<String> {
|
||||
override fun symbol(value: String): String = value
|
||||
|
||||
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||
|
||||
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||
"magic" -> "$left ★ $right"
|
||||
else -> throw NotImplementedError()
|
||||
}
|
||||
}
|
||||
|
||||
val mst = "magic(a, b)".parseMath()
|
||||
val res = magicalAlgebra.evaluate(mst)
|
||||
assertEquals("a ★ b", res)
|
||||
}
|
||||
}
|
||||
|
1.23e+3
generally + in exponent is generally allowed1.23e+3
generally + in exponent is generally allowedFixed
Fixed