Fix #172, add constant folding for unary operations from numeric nodes #173

Merged
CommanderTvis merged 1 commits from commandertvis/ast-valid-constantfolding into dev 2021-01-07 19:20:32 +03:00
4 changed files with 48 additions and 39 deletions
Showing only changes of commit 2012d2c3f1 - Show all commits

View File

@ -2,10 +2,9 @@ package kscience.kmath.ast
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/**
* A Mathematical Syntax Tree node for mathematical expressions.
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
*
* @author Alexander Nozik
*/
@ -57,21 +56,22 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value))
is MST.Unary -> when {
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
}
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField
.binaryOperationFunction(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
this is NumericAlgebra && node.left is MST.Numeric ->
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
number(number)
}
this is NumericAlgebra && node.right is MST.Numeric ->
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
}
}

View File

@ -1,18 +1,18 @@
package kscience.kmath.estree
import kscience.kmath.ast.MST
import kscience.kmath.ast.MST.*
import kscience.kmath.ast.MstExpression
import kscience.kmath.estree.internal.ESTreeBuilder
import kscience.kmath.estree.internal.estree.BaseExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
@PublishedApi
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
is MST.Symbolic -> {
is Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: IllegalStateException) {
@ -25,25 +25,29 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
variable(node.value)
}
is MST.Numeric -> constant(node.value)
is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
is Numeric -> constant(node.value)
is MST.Binary -> when {
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> constant(
algebra.number(
RealField
.binaryOperationFunction(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
)
is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> constant(
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
algebra
.binaryOperationFunction(node.operation)
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
)
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> call(
algebra is NumericAlgebra && node.left is Numeric -> call(
algebra.leftSideNumberOperationFunction(node.operation),
visit(node.left),
visit(node.right),
)
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> call(
algebra is NumericAlgebra && node.right is Numeric -> call(
algebra.rightSideNumberOperationFunction(node.operation),
visit(node.left),
visit(node.right),

View File

@ -3,11 +3,11 @@ package kscience.kmath.asm
import kscience.kmath.asm.internal.AsmBuilder
import kscience.kmath.asm.internal.buildName
import kscience.kmath.ast.MST
import kscience.kmath.ast.MST.*
import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/**
* Compiles given MST to an Expression using AST compiler.
@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField
@PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
is MST.Symbolic -> {
is Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: IllegalStateException) {
@ -33,24 +33,29 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
loadVariable(node.value)
}
is MST.Numeric -> loadNumberConstant(node.value)
is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
is Numeric -> loadNumberConstant(node.value)
is MST.Binary -> when {
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
algebra.number(
RealField
.binaryOperationFunction(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
)
is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
algebra.binaryOperationFunction(node.operation)
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
)
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) {
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
algebra.leftSideNumberOperationFunction(node.operation)) {
visit(node.left)
visit(node.right)
}
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) {
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
algebra.rightSideNumberOperationFunction(node.operation)) {
visit(node.left)
visit(node.right)
}

View File

@ -191,7 +191,7 @@ internal class AsmBuilder<T>(
}
val cls = classLoader.defineClass(className, classWriter.toByteArray())
java.io.File("dump.class").writeBytes(classWriter.toByteArray())
// java.io.File("dump.class").writeBytes(classWriter.toByteArray())
val l = MethodHandles.publicLookup()
if (hasConstants)