Merge pull request #120 from mipt-npm/extended-grammar

Extend mathematic operations support in the kmath-ast parser
This commit is contained in:
Alexander Nozik 2020-07-27 19:37:08 +03:00 committed by GitHub
commit e714c9b808
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 224 additions and 64 deletions

View File

@ -11,6 +11,7 @@ allprojects {
repositories { repositories {
jcenter() jcenter()
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/hotkeytlt/maven")
} }
group = "scientifik" group = "scientifik"

View File

@ -1,10 +1,4 @@
plugins { plugins { id("scientifik.mpp") }
id("scientifik.mpp")
}
repositories {
maven("https://dl.bintray.com/hotkeytlt/maven")
}
kotlin.sourceSets { kotlin.sourceSets {
// all { // all {
@ -15,23 +9,15 @@ kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
} }
} }
jvmMain { jvmMain {
dependencies { dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
implementation("org.ow2.asm:asm:8.0.1") implementation("org.ow2.asm:asm:8.0.1")
implementation("org.ow2.asm:asm-commons:8.0.1") implementation("org.ow2.asm:asm-commons:8.0.1")
implementation(kotlin("reflect")) implementation(kotlin("reflect"))
} }
} }
jsMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
}
}
} }

View File

@ -0,0 +1,59 @@
grammar ArithmeticsEvaluator;
fragment DIGIT: '0'..'9';
fragment LETTER: 'a'..'z';
fragment CAPITAL_LETTER: 'A'..'Z';
fragment UNDERSCORE: '_';
ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*;
NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?;
MUL: '*';
DIV: '/';
PLUS: '+';
MINUS: '-';
POW: '^';
COMMA: ',';
LPAR: '(';
RPAR: ')';
WS: [ \n\t\r]+ -> skip;
num
: NUM
;
singular
: ID
;
unaryFunction
: ID LPAR subSumChain RPAR
;
binaryFunction
: ID LPAR subSumChain COMMA subSumChain RPAR
;
term
: num
| singular
| unaryFunction
| binaryFunction
| MINUS term
| LPAR subSumChain RPAR
;
powChain
: term (POW term)*
;
divMulChain
: powChain ((DIV | MUL) powChain)*
;
subSumChain
: divMulChain ((PLUS | MINUS) divMulChain)*
;
rootParser
: subSumChain EOF
;

View File

@ -41,27 +41,28 @@ sealed class MST {
//TODO add a function with named arguments //TODO add a function with named arguments
fun <T> Algebra<T>.evaluate(node: MST): T { fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
return when (node) {
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when { is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation( val number = RealField.binaryOperation(
node.operation, node.operation,
node.left.value.toDouble(), node.left.value.toDouble(),
node.right.value.toDouble() node.right.value.toDouble()
) )
number(number) number(number)
} }
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
} }
}
} }
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this) fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -5,6 +5,9 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.lexer.Token
import com.github.h0tk3y.betterParse.lexer.TokenMatch
import com.github.h0tk3y.betterParse.lexer.regexToken
import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations import scientifik.kmath.operations.FieldOperations
@ -13,44 +16,66 @@ import scientifik.kmath.operations.RingOperations
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
/** /**
* TODO move to common * TODO move to core
*/ */
private object ArithmeticsEvaluator : Grammar<MST>() { object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) // TODO replace with "...".toRegex() when better-parse 0.4.1 is released
val lpar by token("\\(".toRegex()) private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
val rpar by token("\\)".toRegex()) private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
val mul by token("\\*".toRegex()) private val lpar: Token by regexToken("\\(")
val pow by token("\\^".toRegex()) private val rpar: Token by regexToken("\\)")
val div by token("/".toRegex()) private val comma: Token by regexToken(",")
val minus by token("-".toRegex()) private val mul: Token by regexToken("\\*")
val plus by token("\\+".toRegex()) private val pow: Token by regexToken("\\^")
val ws by token("\\s+".toRegex(), ignore = true) private val div: Token by regexToken("/")
private val minus: Token by regexToken("-")
private val plus: Token by regexToken("\\+")
private val ws: Token by regexToken("\\s+", ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) } private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
val term: Parser<MST> by number or private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or .map { (id, term) -> MST.Unary(id.text, term) }
(skip(lpar) and parser(this::rootParser) and skip(rpar))
val powChain by leftAssociative(term, pow) { a, _, b -> private val binaryFunction: Parser<MST> by id
.and(skip(lpar))
.and(parser(::subSumChain))
.and(skip(comma))
.and(parser(::subSumChain))
.and(skip(rpar))
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
private val term: Parser<MST> by number
.or(binaryFunction)
.or(unaryFunction)
.or(singular)
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b) MST.Binary(PowerOperations.POW_OPERATION, a, b)
} }
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b -> private val divMulChain: Parser<MST> by leftAssociative(
if (op == div) { term = powChain,
operator = div or mul use TokenMatch::type
) { a, op, b ->
if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b) MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else { else
MST.Binary(RingOperations.TIMES_OPERATION, a, b) MST.Binary(RingOperations.TIMES_OPERATION, a, b)
} }
}
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b -> private val subSumChain: Parser<MST> by leftAssociative(
if (op == plus) { term = divMulChain,
operator = plus or minus use TokenMatch::type
) { a, op, b ->
if (op == plus)
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else { else
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
} }
}
override val rootParser: Parser<MST> by subSumChain override val rootParser: Parser<MST> by subSumChain
} }

View File

@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MstExpression import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> { fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: Throwable) {
null
}
if (symbol != null)
loadTConstant(symbol)
else
loadVariable(node.value)
}
is MST.Numeric -> loadNumeric(node.value) is MST.Numeric -> loadNumeric(node.value)
is MST.Unary -> buildAlgebraOperationCall( is MST.Unary -> buildAlgebraOperationCall(

View File

@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Loads a [T] constant from [constants]. * Loads a [T] constant from [constants].
*/ */
private fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop() val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT val mustBeBoxed = expectedType.sort == Type.OBJECT

View File

@ -0,0 +1,36 @@
package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.parseMath
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserPrecedenceTest {
private val f: Field<Double> = RealField
@Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
@Test
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
@Test
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
@Test
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
@Test
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
@Test
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
@Test
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
@Test
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
}

View File

@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -22,4 +24,37 @@ internal class ParserTest {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
@Test
fun `evaluate MST with singular`() {
val mst = "i".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(ComplexField.i, res)
}
@Test
fun `evaluate MST with unary function`() {
val mst = "sin(0)".parseMath()
val res = RealField.evaluate(mst)
assertEquals(0.0, res)
}
@Test
fun `evaluate MST with binary function`() {
val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
}
}
val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst)
assertEquals("a ★ b", res)
}
} }

View File

@ -1,5 +1,6 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
@ -85,6 +86,11 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
override inline fun Double.times(b: Double) = this * b override inline fun Double.times(b: Double) = this * b
override inline fun Double.div(b: Double) = this / b override inline fun Double.div(b: Double) = this / b
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
} }
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")