Extend mathematic operations support in the kmath-ast parser #120
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
|||||||
|
grammar ArithmeticsEvaluator;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
fragment DIGIT: '0'..'9';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
fragment LETTER: 'a'..'z';
|
||||||
What about uppercase letters? What about uppercase letters?
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
Added. Added.
|
|||||||
|
fragment UNDERSCORE: '_';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
ID: (LETTER | UNDERSCORE) (LETTER | UNDERSCORE | DIGIT)*;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
NUM: (DIGIT | '.')+ ([eE] MINUS? DIGIT+)?;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
MUL: '*';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
DIV: '/';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
PLUS: '+';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
MINUS: '-';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
POW: '^';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
COMMA: ',';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
LPAR: '(';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
RPAR: ')';
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
WS: [ \n\t\r]+ -> skip;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
number
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: NUM
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
singular
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: ID
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
unaryFunction
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: ID LPAR subSumChain RPAR
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
binaryFunction
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: ID LPAR subSumChain COMMA subSumChain RPAR
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
term
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: number
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
| singular
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
| unaryFunction
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
| binaryFunction
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
| MINUS term
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
| LPAR subSumChain RPAR
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
powChain
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: term (POW term)*
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
divMulChain
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: powChain ((PLUS | MINUS) powChain)*
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
subSumChain
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Is operator precedence right? Looks suspicious. It would be nice to add test for it Is operator precedence right? Looks suspicious. It would be nice to add test for it
Tests were added Tests were added
Fixed Fixed
|
|||||||
|
: divMulChain ((DIV | MUL) divMulChain)*
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
rootParser
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
: subSumChain EOF
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
|||||||
|
;
|
||||||
`1.23e+3` generally + in exponent is generally allowed
Fixed Fixed
|
@ -41,8 +41,7 @@ sealed class MST {
|
|||||||
|
|
||||||
//TODO add a function with named arguments
|
//TODO add a function with named arguments
|
||||||
|
|
||||||
fun <T> Algebra<T>.evaluate(node: MST): T {
|
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||||
return when (node) {
|
|
||||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||||
?: error("Numeric nodes are not supported by $this")
|
?: error("Numeric nodes are not supported by $this")
|
||||||
is MST.Symbolic -> symbol(node.value)
|
is MST.Symbolic -> symbol(node.value)
|
||||||
@ -61,7 +60,6 @@ fun <T> Algebra<T>.evaluate(node: MST): T {
|
|||||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
@ -5,6 +5,8 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
|
|||||||
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
||||||
import com.github.h0tk3y.betterParse.grammar.parser
|
import com.github.h0tk3y.betterParse.grammar.parser
|
||||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.Token
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||||
import com.github.h0tk3y.betterParse.parser.Parser
|
import com.github.h0tk3y.betterParse.parser.Parser
|
||||||
import scientifik.kmath.operations.FieldOperations
|
import scientifik.kmath.operations.FieldOperations
|
||||||
@ -15,42 +17,63 @@ import scientifik.kmath.operations.SpaceOperations
|
|||||||
/**
|
/**
|
||||||
* TODO move to common
|
* TODO move to common
|
||||||
*/
|
*/
|
||||||
private object ArithmeticsEvaluator : Grammar<MST>() {
|
object ArithmeticsEvaluator : Grammar<MST>() {
|
||||||
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
private val num: Token by token("[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
||||||
val lpar by token("\\(".toRegex())
|
private val id: Token by token("[a-z_][\\da-z_]*".toRegex())
|
||||||
val rpar by token("\\)".toRegex())
|
private val lpar: Token by token("\\(".toRegex())
|
||||||
val mul by token("\\*".toRegex())
|
private val rpar: Token by token("\\)".toRegex())
|
||||||
val pow by token("\\^".toRegex())
|
private val comma: Token by token(",".toRegex())
|
||||||
val div by token("/".toRegex())
|
private val mul: Token by token("\\*".toRegex())
|
||||||
val minus by token("-".toRegex())
|
private val pow: Token by token("\\^".toRegex())
|
||||||
val plus by token("\\+".toRegex())
|
private val div: Token by token("/".toRegex())
|
||||||
val ws by token("\\s+".toRegex(), ignore = true)
|
private val minus: Token by token("-".toRegex())
|
||||||
|
private val plus: Token by token("\\+".toRegex())
|
||||||
|
private val ws: Token by token("\\s+".toRegex(), ignore = true)
|
||||||
|
|
||||||
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||||
|
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
|
||||||
|
|
||||||
val term: Parser<MST> by number or
|
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
|
.map { (id, term) -> MST.Unary(id.text, term) }
|
||||||
(skip(lpar) and parser(this::rootParser) and skip(rpar))
|
|
||||||
|
|
||||||
val powChain by leftAssociative(term, pow) { a, _, b ->
|
private val binaryFunction: Parser<MST> by id
|
||||||
|
.and(skip(lpar))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(comma))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(rpar))
|
||||||
|
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
|
||||||
|
|
||||||
|
private val term: Parser<MST> by number
|
||||||
|
.or(binaryFunction)
|
||||||
|
.or(unaryFunction)
|
||||||
|
.or(singular)
|
||||||
|
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
|
||||||
|
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
|
|
||||||
|
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||||
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
|
private val divMulChain: Parser<MST> by leftAssociative(
|
||||||
if (op == div) {
|
term = powChain,
|
||||||
|
operator = div or mul use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == div)
|
||||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||||
} else {
|
else
|
||||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
|
private val subSumChain: Parser<MST> by leftAssociative(
|
||||||
if (op == plus) {
|
term = divMulChain,
|
||||||
|
operator = plus or minus use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == plus)
|
||||||
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
} else {
|
else
|
||||||
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override val rootParser: Parser<MST> by subSumChain
|
override val rootParser: Parser<MST> by subSumChain
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
|
|||||||
import scientifik.kmath.ast.MstExpression
|
import scientifik.kmath.ast.MstExpression
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
|
|||||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST) {
|
fun AsmBuilder<T>.visit(node: MST) {
|
||||||
when (node) {
|
when (node) {
|
||||||
is MST.Symbolic -> loadVariable(node.value)
|
is MST.Symbolic -> {
|
||||||
|
val symbol = try {
|
||||||
|
algebra.symbol(node.value)
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
|
||||||
|
if (symbol != null)
|
||||||
|
loadTConstant(symbol)
|
||||||
|
else
|
||||||
|
loadVariable(node.value)
|
||||||
|
}
|
||||||
|
|
||||||
is MST.Numeric -> loadNumeric(node.value)
|
is MST.Numeric -> loadNumeric(node.value)
|
||||||
|
|
||||||
is MST.Unary -> buildAlgebraOperationCall(
|
is MST.Unary -> buildAlgebraOperationCall(
|
||||||
|
@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Loads a [T] constant from [constants].
|
* Loads a [T] constant from [constants].
|
||||||
*/
|
*/
|
||||||
private fun loadTConstant(value: T) {
|
internal fun loadTConstant(value: T) {
|
||||||
if (classOfT in INLINABLE_NUMBERS) {
|
if (classOfT in INLINABLE_NUMBERS) {
|
||||||
val expectedType = expectationStack.pop()
|
val expectedType = expectationStack.pop()
|
||||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||||
|
@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
|
|||||||
import scientifik.kmath.ast.mstInField
|
import scientifik.kmath.ast.mstInField
|
||||||
import scientifik.kmath.ast.parseMath
|
import scientifik.kmath.ast.parseMath
|
||||||
import scientifik.kmath.expressions.invoke
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -22,4 +24,37 @@ internal class ParserTest {
|
|||||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with singular`() {
|
||||||
|
val mst = "i".parseMath()
|
||||||
|
val res = ComplexField.evaluate(mst)
|
||||||
|
assertEquals(ComplexField.i, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with unary function`() {
|
||||||
|
val mst = "sin(0)".parseMath()
|
||||||
|
val res = RealField.evaluate(mst)
|
||||||
|
assertEquals(0.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with binary function`() {
|
||||||
|
val magicalAlgebra = object : Algebra<String> {
|
||||||
|
override fun symbol(value: String): String = value
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||||
|
"magic" -> "$left ★ $right"
|
||||||
|
else -> throw NotImplementedError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val mst = "magic(a, b)".parseMath()
|
||||||
|
val res = magicalAlgebra.evaluate(mst)
|
||||||
|
assertEquals("a ★ b", res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
1.23e+3
generally + in exponent is generally allowed1.23e+3
generally + in exponent is generally allowedFixed
Fixed