Merge pull request #120 from mipt-npm/extended-grammar

Extend mathematic operations support in the kmath-ast parser
This commit is contained in:
Alexander Nozik 2020-07-27 19:37:08 +03:00 committed by GitHub
commit e714c9b808
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 224 additions and 64 deletions

View File

@ -11,6 +11,7 @@ allprojects {
repositories {
jcenter()
maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/hotkeytlt/maven")
}
group = "scientifik"

View File

@ -1,10 +1,4 @@
plugins {
id("scientifik.mpp")
}
repositories {
maven("https://dl.bintray.com/hotkeytlt/maven")
}
plugins { id("scientifik.mpp") }
kotlin.sourceSets {
// all {
@ -15,23 +9,15 @@ kotlin.sourceSets {
commonMain {
dependencies {
api(project(":kmath-core"))
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3")
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
}
}
jvmMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
implementation("org.ow2.asm:asm:8.0.1")
implementation("org.ow2.asm:asm-commons:8.0.1")
implementation(kotlin("reflect"))
}
}
jsMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
}
}
}

View File

@ -0,0 +1,59 @@
grammar ArithmeticsEvaluator;
fragment DIGIT: '0'..'9';
fragment LETTER: 'a'..'z';
fragment CAPITAL_LETTER: 'A'..'Z';
fragment UNDERSCORE: '_';
ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*;
NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?;
MUL: '*';
DIV: '/';
PLUS: '+';
MINUS: '-';
POW: '^';
COMMA: ',';
LPAR: '(';
RPAR: ')';
WS: [ \n\t\r]+ -> skip;
num
: NUM
;
singular
: ID
;
unaryFunction
: ID LPAR subSumChain RPAR
;
binaryFunction
: ID LPAR subSumChain COMMA subSumChain RPAR
;
term
: num
| singular
| unaryFunction
| binaryFunction
| MINUS term
| LPAR subSumChain RPAR
;
powChain
: term (POW term)*
;
divMulChain
: powChain ((DIV | MUL) powChain)*
;
subSumChain
: divMulChain ((PLUS | MINUS) divMulChain)*
;
rootParser
: subSumChain EOF
;

View File

@ -41,26 +41,27 @@ sealed class MST {
//TODO add a function with named arguments
fun <T> Algebra<T>.evaluate(node: MST): T {
return when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
node.left.value.toDouble(),
node.right.value.toDouble()
)
number(number)
}
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
node.left.value.toDouble(),
node.right.value.toDouble()
)
number(number)
}
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
}
}

View File

@ -5,6 +5,9 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.lexer.Token
import com.github.h0tk3y.betterParse.lexer.TokenMatch
import com.github.h0tk3y.betterParse.lexer.regexToken
import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations
@ -13,43 +16,65 @@ import scientifik.kmath.operations.RingOperations
import scientifik.kmath.operations.SpaceOperations
/**
* TODO move to common
* TODO move to core
*/
private object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex())
val lpar by token("\\(".toRegex())
val rpar by token("\\)".toRegex())
val mul by token("\\*".toRegex())
val pow by token("\\^".toRegex())
val div by token("/".toRegex())
val minus by token("-".toRegex())
val plus by token("\\+".toRegex())
val ws by token("\\s+".toRegex(), ignore = true)
object ArithmeticsEvaluator : Grammar<MST>() {
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released
private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
private val lpar: Token by regexToken("\\(")
private val rpar: Token by regexToken("\\)")
private val comma: Token by regexToken(",")
private val mul: Token by regexToken("\\*")
private val pow: Token by regexToken("\\^")
private val div: Token by regexToken("/")
private val minus: Token by regexToken("-")
private val plus: Token by regexToken("\\+")
private val ws: Token by regexToken("\\s+", ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
val term: Parser<MST> by number or
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or
(skip(lpar) and parser(this::rootParser) and skip(rpar))
private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
.map { (id, term) -> MST.Unary(id.text, term) }
val powChain by leftAssociative(term, pow) { a, _, b ->
private val binaryFunction: Parser<MST> by id
.and(skip(lpar))
.and(parser(::subSumChain))
.and(skip(comma))
.and(parser(::subSumChain))
.and(skip(rpar))
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
private val term: Parser<MST> by number
.or(binaryFunction)
.or(unaryFunction)
.or(singular)
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b)
}
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b ->
if (op == div) {
private val divMulChain: Parser<MST> by leftAssociative(
term = powChain,
operator = div or mul use TokenMatch::type
) { a, op, b ->
if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else {
else
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
}
}
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b ->
if (op == plus) {
private val subSumChain: Parser<MST> by leftAssociative(
term = divMulChain,
operator = plus or minus use TokenMatch::type
) { a, op, b ->
if (op == plus)
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else {
else
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
}
}
override val rootParser: Parser<MST> by subSumChain

View File

@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass
/**
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) {
when (node) {
is MST.Symbolic -> loadVariable(node.value)
is MST.Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: Throwable) {
null
}
if (symbol != null)
loadTConstant(symbol)
else
loadVariable(node.value)
}
is MST.Numeric -> loadNumeric(node.value)
is MST.Unary -> buildAlgebraOperationCall(

View File

@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
/**
* Loads a [T] constant from [constants].
*/
private fun loadTConstant(value: T) {
internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT

View File

@ -0,0 +1,36 @@
package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.parseMath
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserPrecedenceTest {
private val f: Field<Double> = RealField
@Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
@Test
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
@Test
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
@Test
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
@Test
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
@Test
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
@Test
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
@Test
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
}

View File

@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
@ -22,4 +24,37 @@ internal class ParserTest {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res)
}
@Test
fun `evaluate MST with singular`() {
val mst = "i".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(ComplexField.i, res)
}
@Test
fun `evaluate MST with unary function`() {
val mst = "sin(0)".parseMath()
val res = RealField.evaluate(mst)
assertEquals(0.0, res)
}
@Test
fun `evaluate MST with binary function`() {
val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
}
}
val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst)
assertEquals("a ★ b", res)
}
}

View File

@ -1,5 +1,6 @@
package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs
import kotlin.math.pow as kpow
@ -85,6 +86,11 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
override inline fun Double.times(b: Double) = this * b
override inline fun Double.div(b: Double) = this / b
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
}
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")