forked from kscience/kmath
Merge pull request #120 from mipt-npm/extended-grammar
Extend mathematic operations support in the kmath-ast parser
This commit is contained in:
commit
e714c9b808
@ -11,6 +11,7 @@ allprojects {
|
|||||||
repositories {
|
repositories {
|
||||||
jcenter()
|
jcenter()
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
}
|
}
|
||||||
|
|
||||||
group = "scientifik"
|
group = "scientifik"
|
||||||
|
@ -1,10 +1,4 @@
|
|||||||
plugins {
|
plugins { id("scientifik.mpp") }
|
||||||
id("scientifik.mpp")
|
|
||||||
}
|
|
||||||
|
|
||||||
repositories {
|
|
||||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
// all {
|
// all {
|
||||||
@ -15,23 +9,15 @@ kotlin.sourceSets {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3")
|
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
jvmMain {
|
jvmMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
|
|
||||||
implementation("org.ow2.asm:asm:8.0.1")
|
implementation("org.ow2.asm:asm:8.0.1")
|
||||||
implementation("org.ow2.asm:asm-commons:8.0.1")
|
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||||
implementation(kotlin("reflect"))
|
implementation(kotlin("reflect"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
jsMain {
|
|
||||||
dependencies {
|
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
59
kmath-ast/reference/ArithmeticsEvaluator.g4
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
grammar ArithmeticsEvaluator;
|
||||||
|
|
||||||
|
fragment DIGIT: '0'..'9';
|
||||||
|
fragment LETTER: 'a'..'z';
|
||||||
|
fragment CAPITAL_LETTER: 'A'..'Z';
|
||||||
|
fragment UNDERSCORE: '_';
|
||||||
|
|
||||||
|
ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*;
|
||||||
|
NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?;
|
||||||
|
MUL: '*';
|
||||||
|
DIV: '/';
|
||||||
|
PLUS: '+';
|
||||||
|
MINUS: '-';
|
||||||
|
POW: '^';
|
||||||
|
COMMA: ',';
|
||||||
|
LPAR: '(';
|
||||||
|
RPAR: ')';
|
||||||
|
WS: [ \n\t\r]+ -> skip;
|
||||||
|
|
||||||
|
num
|
||||||
|
: NUM
|
||||||
|
;
|
||||||
|
|
||||||
|
singular
|
||||||
|
: ID
|
||||||
|
;
|
||||||
|
|
||||||
|
unaryFunction
|
||||||
|
: ID LPAR subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
binaryFunction
|
||||||
|
: ID LPAR subSumChain COMMA subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
term
|
||||||
|
: num
|
||||||
|
| singular
|
||||||
|
| unaryFunction
|
||||||
|
| binaryFunction
|
||||||
|
| MINUS term
|
||||||
|
| LPAR subSumChain RPAR
|
||||||
|
;
|
||||||
|
|
||||||
|
powChain
|
||||||
|
: term (POW term)*
|
||||||
|
;
|
||||||
|
|
||||||
|
divMulChain
|
||||||
|
: powChain ((DIV | MUL) powChain)*
|
||||||
|
;
|
||||||
|
|
||||||
|
subSumChain
|
||||||
|
: divMulChain ((PLUS | MINUS) divMulChain)*
|
||||||
|
;
|
||||||
|
|
||||||
|
rootParser
|
||||||
|
: subSumChain EOF
|
||||||
|
;
|
@ -41,27 +41,28 @@ sealed class MST {
|
|||||||
|
|
||||||
//TODO add a function with named arguments
|
//TODO add a function with named arguments
|
||||||
|
|
||||||
fun <T> Algebra<T>.evaluate(node: MST): T {
|
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||||
return when (node) {
|
|
||||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||||
?: error("Numeric nodes are not supported by $this")
|
?: error("Numeric nodes are not supported by $this")
|
||||||
is MST.Symbolic -> symbol(node.value)
|
is MST.Symbolic -> symbol(node.value)
|
||||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||||
is MST.Binary -> when {
|
is MST.Binary -> when {
|
||||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||||
|
|
||||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||||
val number = RealField.binaryOperation(
|
val number = RealField.binaryOperation(
|
||||||
node.operation,
|
node.operation,
|
||||||
node.left.value.toDouble(),
|
node.left.value.toDouble(),
|
||||||
node.right.value.toDouble()
|
node.right.value.toDouble()
|
||||||
)
|
)
|
||||||
|
|
||||||
number(number)
|
number(number)
|
||||||
}
|
}
|
||||||
|
|
||||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)
|
@ -5,6 +5,9 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
|
|||||||
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
import com.github.h0tk3y.betterParse.grammar.parseToEnd
|
||||||
import com.github.h0tk3y.betterParse.grammar.parser
|
import com.github.h0tk3y.betterParse.grammar.parser
|
||||||
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.Token
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.TokenMatch
|
||||||
|
import com.github.h0tk3y.betterParse.lexer.regexToken
|
||||||
import com.github.h0tk3y.betterParse.parser.ParseResult
|
import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||||
import com.github.h0tk3y.betterParse.parser.Parser
|
import com.github.h0tk3y.betterParse.parser.Parser
|
||||||
import scientifik.kmath.operations.FieldOperations
|
import scientifik.kmath.operations.FieldOperations
|
||||||
@ -13,44 +16,66 @@ import scientifik.kmath.operations.RingOperations
|
|||||||
import scientifik.kmath.operations.SpaceOperations
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO move to common
|
* TODO move to core
|
||||||
*/
|
*/
|
||||||
private object ArithmeticsEvaluator : Grammar<MST>() {
|
object ArithmeticsEvaluator : Grammar<MST>() {
|
||||||
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
|
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
|
||||||
val lpar by token("\\(".toRegex())
|
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
|
||||||
val rpar by token("\\)".toRegex())
|
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
|
||||||
val mul by token("\\*".toRegex())
|
private val lpar: Token by regexToken("\\(")
|
||||||
val pow by token("\\^".toRegex())
|
private val rpar: Token by regexToken("\\)")
|
||||||
val div by token("/".toRegex())
|
private val comma: Token by regexToken(",")
|
||||||
val minus by token("-".toRegex())
|
private val mul: Token by regexToken("\\*")
|
||||||
val plus by token("\\+".toRegex())
|
private val pow: Token by regexToken("\\^")
|
||||||
val ws by token("\\s+".toRegex(), ignore = true)
|
private val div: Token by regexToken("/")
|
||||||
|
private val minus: Token by regexToken("-")
|
||||||
|
private val plus: Token by regexToken("\\+")
|
||||||
|
private val ws: Token by regexToken("\\s+", ignore = true)
|
||||||
|
|
||||||
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
|
||||||
|
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
|
||||||
|
|
||||||
val term: Parser<MST> by number or
|
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
|
.map { (id, term) -> MST.Unary(id.text, term) }
|
||||||
(skip(lpar) and parser(this::rootParser) and skip(rpar))
|
|
||||||
|
|
||||||
val powChain by leftAssociative(term, pow) { a, _, b ->
|
private val binaryFunction: Parser<MST> by id
|
||||||
|
.and(skip(lpar))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(comma))
|
||||||
|
.and(parser(::subSumChain))
|
||||||
|
.and(skip(rpar))
|
||||||
|
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
|
||||||
|
|
||||||
|
private val term: Parser<MST> by number
|
||||||
|
.or(binaryFunction)
|
||||||
|
.or(unaryFunction)
|
||||||
|
.or(singular)
|
||||||
|
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
|
||||||
|
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
|
||||||
|
|
||||||
|
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||||
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
MST.Binary(PowerOperations.POW_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
|
private val divMulChain: Parser<MST> by leftAssociative(
|
||||||
if (op == div) {
|
term = powChain,
|
||||||
|
operator = div or mul use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == div)
|
||||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||||
} else {
|
else
|
||||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
|
private val subSumChain: Parser<MST> by leftAssociative(
|
||||||
if (op == plus) {
|
term = divMulChain,
|
||||||
|
operator = plus or minus use TokenMatch::type
|
||||||
|
) { a, op, b ->
|
||||||
|
if (op == plus)
|
||||||
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
} else {
|
else
|
||||||
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override val rootParser: Parser<MST> by subSumChain
|
override val rootParser: Parser<MST> by subSumChain
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
|
|||||||
import scientifik.kmath.ast.MstExpression
|
import scientifik.kmath.ast.MstExpression
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
|
|||||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST) {
|
fun AsmBuilder<T>.visit(node: MST) {
|
||||||
when (node) {
|
when (node) {
|
||||||
is MST.Symbolic -> loadVariable(node.value)
|
is MST.Symbolic -> {
|
||||||
|
val symbol = try {
|
||||||
|
algebra.symbol(node.value)
|
||||||
|
} catch (ignored: Throwable) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
|
||||||
|
if (symbol != null)
|
||||||
|
loadTConstant(symbol)
|
||||||
|
else
|
||||||
|
loadVariable(node.value)
|
||||||
|
}
|
||||||
|
|
||||||
is MST.Numeric -> loadNumeric(node.value)
|
is MST.Numeric -> loadNumeric(node.value)
|
||||||
|
|
||||||
is MST.Unary -> buildAlgebraOperationCall(
|
is MST.Unary -> buildAlgebraOperationCall(
|
||||||
|
@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Loads a [T] constant from [constants].
|
* Loads a [T] constant from [constants].
|
||||||
*/
|
*/
|
||||||
private fun loadTConstant(value: T) {
|
internal fun loadTConstant(value: T) {
|
||||||
if (classOfT in INLINABLE_NUMBERS) {
|
if (classOfT in INLINABLE_NUMBERS) {
|
||||||
val expectedType = expectationStack.pop()
|
val expectedType = expectationStack.pop()
|
||||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.evaluate
|
||||||
|
import scientifik.kmath.ast.parseMath
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class ParserPrecedenceTest {
|
||||||
|
private val f: Field<Double> = RealField
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
|
||||||
|
}
|
@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
|
|||||||
import scientifik.kmath.ast.mstInField
|
import scientifik.kmath.ast.mstInField
|
||||||
import scientifik.kmath.ast.parseMath
|
import scientifik.kmath.ast.parseMath
|
||||||
import scientifik.kmath.expressions.invoke
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -22,4 +24,37 @@ internal class ParserTest {
|
|||||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with singular`() {
|
||||||
|
val mst = "i".parseMath()
|
||||||
|
val res = ComplexField.evaluate(mst)
|
||||||
|
assertEquals(ComplexField.i, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with unary function`() {
|
||||||
|
val mst = "sin(0)".parseMath()
|
||||||
|
val res = RealField.evaluate(mst)
|
||||||
|
assertEquals(0.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `evaluate MST with binary function`() {
|
||||||
|
val magicalAlgebra = object : Algebra<String> {
|
||||||
|
override fun symbol(value: String): String = value
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||||
|
"magic" -> "$left ★ $right"
|
||||||
|
else -> throw NotImplementedError()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val mst = "magic(a, b)".parseMath()
|
||||||
|
val res = magicalAlgebra.evaluate(mst)
|
||||||
|
assertEquals("a ★ b", res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.RealField.pow
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.pow as kpow
|
import kotlin.math.pow as kpow
|
||||||
|
|
||||||
@ -85,6 +86,11 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
|||||||
override inline fun Double.times(b: Double) = this * b
|
override inline fun Double.times(b: Double) = this * b
|
||||||
|
|
||||||
override inline fun Double.div(b: Double) = this / b
|
override inline fun Double.div(b: Double) = this / b
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
|
||||||
|
PowerOperations.POW_OPERATION -> left pow right
|
||||||
|
else -> super.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
|
Loading…
Reference in New Issue
Block a user