Fix #172, add constant folding for unary operations from numeric nodes #173

Merged
CommanderTvis merged 1 commits from commandertvis/ast-valid-constantfolding into dev 2021-01-07 19:20:32 +03:00
4 changed files with 48 additions and 39 deletions

View File

@ -2,10 +2,9 @@ package kscience.kmath.ast
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/** /**
* A Mathematical Syntax Tree node for mathematical expressions. * A Mathematical Syntax Tree (MST) node for mathematical expressions.
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
@ -57,21 +56,22 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
?: error("Numeric nodes are not supported by $this") ?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value))
is MST.Binary -> when { is MST.Unary -> when {
this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField
.binaryOperationFunction(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
number(number)
} }
node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right)) is MST.Binary -> when {
node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value) this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
this is NumericAlgebra && node.left is MST.Numeric ->
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
this is NumericAlgebra && node.right is MST.Numeric ->
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
} }
} }

View File

@ -1,18 +1,18 @@
package kscience.kmath.estree package kscience.kmath.estree
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.ast.MST.*
import kscience.kmath.ast.MstExpression import kscience.kmath.ast.MstExpression
import kscience.kmath.estree.internal.ESTreeBuilder import kscience.kmath.estree.internal.ESTreeBuilder
import kscience.kmath.estree.internal.estree.BaseExpression import kscience.kmath.estree.internal.estree.BaseExpression
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
@PublishedApi @PublishedApi
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> { internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) { fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
is MST.Symbolic -> { is Symbolic -> {
val symbol = try { val symbol = try {
algebra.symbol(node.value) algebra.symbol(node.value)
} catch (ignored: IllegalStateException) { } catch (ignored: IllegalStateException) {
@ -25,25 +25,29 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
variable(node.value) variable(node.value)
} }
is MST.Numeric -> constant(node.value) is Numeric -> constant(node.value)
is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
is MST.Binary -> when { is Unary -> when {
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> constant( algebra is NumericAlgebra && node.value is Numeric -> constant(
algebra.number( algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
RealField
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
algebra
.binaryOperationFunction(node.operation) .binaryOperationFunction(node.operation)
.invoke(node.left.value.toDouble(), node.right.value.toDouble()) .invoke(algebra.number(node.left.value), algebra.number(node.right.value))
)
) )
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> call( algebra is NumericAlgebra && node.left is Numeric -> call(
algebra.leftSideNumberOperationFunction(node.operation), algebra.leftSideNumberOperationFunction(node.operation),
visit(node.left), visit(node.left),
visit(node.right), visit(node.right),
) )
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> call( algebra is NumericAlgebra && node.right is Numeric -> call(
algebra.rightSideNumberOperationFunction(node.operation), algebra.rightSideNumberOperationFunction(node.operation),
visit(node.left), visit(node.left),
visit(node.right), visit(node.right),

View File

@ -3,11 +3,11 @@ package kscience.kmath.asm
import kscience.kmath.asm.internal.AsmBuilder import kscience.kmath.asm.internal.AsmBuilder
import kscience.kmath.asm.internal.buildName import kscience.kmath.asm.internal.buildName
import kscience.kmath.ast.MST import kscience.kmath.ast.MST
import kscience.kmath.ast.MST.*
import kscience.kmath.ast.MstExpression import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kscience.kmath.operations.NumericAlgebra import kscience.kmath.operations.NumericAlgebra
import kscience.kmath.operations.RealField
/** /**
* Compiles given MST to an Expression using AST compiler. * Compiles given MST to an Expression using AST compiler.
@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField
@PublishedApi @PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> { internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) { fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
is MST.Symbolic -> { is Symbolic -> {
val symbol = try { val symbol = try {
algebra.symbol(node.value) algebra.symbol(node.value)
} catch (ignored: IllegalStateException) { } catch (ignored: IllegalStateException) {
@ -33,24 +33,29 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
loadVariable(node.value) loadVariable(node.value)
} }
is MST.Numeric -> loadNumberConstant(node.value) is Numeric -> loadNumberConstant(node.value)
is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
is MST.Binary -> when { is Unary -> when {
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant( algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
algebra.number( algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
RealField
.binaryOperationFunction(node.operation) else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
.invoke(node.left.value.toDouble(), node.right.value.toDouble()) }
)
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
algebra.binaryOperationFunction(node.operation)
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
) )
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) { algebra is NumericAlgebra && node.left is Numeric -> buildCall(
algebra.leftSideNumberOperationFunction(node.operation)) {
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
} }
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) { algebra is NumericAlgebra && node.right is Numeric -> buildCall(
algebra.rightSideNumberOperationFunction(node.operation)) {
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
} }

View File

@ -191,7 +191,7 @@ internal class AsmBuilder<T>(
} }
val cls = classLoader.defineClass(className, classWriter.toByteArray()) val cls = classLoader.defineClass(className, classWriter.toByteArray())
java.io.File("dump.class").writeBytes(classWriter.toByteArray()) // java.io.File("dump.class").writeBytes(classWriter.toByteArray())
val l = MethodHandles.publicLookup() val l = MethodHandles.publicLookup()
if (hasConstants) if (hasConstants)