Merge pull request #173 from mipt-npm/commandertvis/ast-valid-constantfolding
Fix #172, add constant folding for unary operations from numeric nodes
This commit is contained in:
commit
a5e8c971ba
@ -2,10 +2,9 @@ package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* A Mathematical Syntax Tree node for mathematical expressions.
|
||||
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
@ -57,21 +56,22 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||
|
||||
number(number)
|
||||
is MST.Unary -> when {
|
||||
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
|
||||
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||
}
|
||||
|
||||
node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||
is MST.Binary -> when {
|
||||
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
|
||||
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
|
||||
|
||||
this is NumericAlgebra && node.left is MST.Numeric ->
|
||||
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||
|
||||
this is NumericAlgebra && node.right is MST.Numeric ->
|
||||
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||
|
||||
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
|
@ -1,18 +1,18 @@
|
||||
package kscience.kmath.estree
|
||||
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MST.*
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.estree.internal.ESTreeBuilder
|
||||
import kscience.kmath.estree.internal.estree.BaseExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
@PublishedApi
|
||||
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
||||
is MST.Symbolic -> {
|
||||
is Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
@ -25,25 +25,29 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||
variable(node.value)
|
||||
}
|
||||
|
||||
is MST.Numeric -> constant(node.value)
|
||||
is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
||||
is Numeric -> constant(node.value)
|
||||
|
||||
is MST.Binary -> when {
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> constant(
|
||||
algebra.number(
|
||||
RealField
|
||||
is Unary -> when {
|
||||
algebra is NumericAlgebra && node.value is Numeric -> constant(
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||
|
||||
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
||||
}
|
||||
|
||||
is Binary -> when {
|
||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
|
||||
algebra
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||
)
|
||||
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> call(
|
||||
algebra is NumericAlgebra && node.left is Numeric -> call(
|
||||
algebra.leftSideNumberOperationFunction(node.operation),
|
||||
visit(node.left),
|
||||
visit(node.right),
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> call(
|
||||
algebra is NumericAlgebra && node.right is Numeric -> call(
|
||||
algebra.rightSideNumberOperationFunction(node.operation),
|
||||
visit(node.left),
|
||||
visit(node.right),
|
||||
|
@ -3,11 +3,11 @@ package kscience.kmath.asm
|
||||
import kscience.kmath.asm.internal.AsmBuilder
|
||||
import kscience.kmath.asm.internal.buildName
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MST.*
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* Compiles given MST to an Expression using AST compiler.
|
||||
@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField
|
||||
@PublishedApi
|
||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||
is MST.Symbolic -> {
|
||||
is Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
@ -33,24 +33,29 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
||||
loadVariable(node.value)
|
||||
}
|
||||
|
||||
is MST.Numeric -> loadNumberConstant(node.value)
|
||||
is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||
is Numeric -> loadNumberConstant(node.value)
|
||||
|
||||
is MST.Binary -> when {
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
|
||||
algebra.number(
|
||||
RealField
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||
)
|
||||
is Unary -> when {
|
||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||
|
||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||
}
|
||||
|
||||
is Binary -> when {
|
||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
||||
algebra.binaryOperationFunction(node.operation)
|
||||
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
||||
algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
|
||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
||||
algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
|
@ -191,7 +191,7 @@ internal class AsmBuilder<T>(
|
||||
}
|
||||
|
||||
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
||||
java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
||||
// java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
||||
val l = MethodHandles.publicLookup()
|
||||
|
||||
if (hasConstants)
|
||||
|
Loading…
Reference in New Issue
Block a user