forked from kscience/kmath
Merge pull request #173 from mipt-npm/commandertvis/ast-valid-constantfolding
Fix #172, add constant folding for unary operations from numeric nodes
This commit is contained in:
commit
a5e8c971ba
@ -2,10 +2,9 @@ package kscience.kmath.ast
|
|||||||
|
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A Mathematical Syntax Tree node for mathematical expressions.
|
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
@ -57,21 +56,22 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
|||||||
?: error("Numeric nodes are not supported by $this")
|
?: error("Numeric nodes are not supported by $this")
|
||||||
|
|
||||||
is MST.Symbolic -> symbol(node.value)
|
is MST.Symbolic -> symbol(node.value)
|
||||||
is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
|
||||||
|
is MST.Unary -> when {
|
||||||
|
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
|
||||||
|
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||||
|
}
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is MST.Binary -> when {
|
||||||
this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
|
||||||
|
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
|
||||||
|
|
||||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
this is NumericAlgebra && node.left is MST.Numeric ->
|
||||||
val number = RealField
|
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||||
.binaryOperationFunction(node.operation)
|
|
||||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
|
||||||
|
|
||||||
number(number)
|
this is NumericAlgebra && node.right is MST.Numeric ->
|
||||||
}
|
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||||
|
|
||||||
node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
|
||||||
node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
|
||||||
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
package kscience.kmath.estree
|
package kscience.kmath.estree
|
||||||
|
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MST.*
|
||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.estree.internal.ESTreeBuilder
|
import kscience.kmath.estree.internal.ESTreeBuilder
|
||||||
import kscience.kmath.estree.internal.estree.BaseExpression
|
import kscience.kmath.estree.internal.estree.BaseExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||||
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
||||||
is MST.Symbolic -> {
|
is Symbolic -> {
|
||||||
val symbol = try {
|
val symbol = try {
|
||||||
algebra.symbol(node.value)
|
algebra.symbol(node.value)
|
||||||
} catch (ignored: IllegalStateException) {
|
} catch (ignored: IllegalStateException) {
|
||||||
@ -25,25 +25,29 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
|||||||
variable(node.value)
|
variable(node.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Numeric -> constant(node.value)
|
is Numeric -> constant(node.value)
|
||||||
is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is Unary -> when {
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> constant(
|
algebra is NumericAlgebra && node.value is Numeric -> constant(
|
||||||
algebra.number(
|
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||||
RealField
|
|
||||||
.binaryOperationFunction(node.operation)
|
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
||||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
}
|
||||||
)
|
|
||||||
|
is Binary -> when {
|
||||||
|
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
|
||||||
|
algebra
|
||||||
|
.binaryOperationFunction(node.operation)
|
||||||
|
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||||
)
|
)
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> call(
|
algebra is NumericAlgebra && node.left is Numeric -> call(
|
||||||
algebra.leftSideNumberOperationFunction(node.operation),
|
algebra.leftSideNumberOperationFunction(node.operation),
|
||||||
visit(node.left),
|
visit(node.left),
|
||||||
visit(node.right),
|
visit(node.right),
|
||||||
)
|
)
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> call(
|
algebra is NumericAlgebra && node.right is Numeric -> call(
|
||||||
algebra.rightSideNumberOperationFunction(node.operation),
|
algebra.rightSideNumberOperationFunction(node.operation),
|
||||||
visit(node.left),
|
visit(node.left),
|
||||||
visit(node.right),
|
visit(node.right),
|
||||||
|
@ -3,11 +3,11 @@ package kscience.kmath.asm
|
|||||||
import kscience.kmath.asm.internal.AsmBuilder
|
import kscience.kmath.asm.internal.AsmBuilder
|
||||||
import kscience.kmath.asm.internal.buildName
|
import kscience.kmath.asm.internal.buildName
|
||||||
import kscience.kmath.ast.MST
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MST.*
|
||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles given MST to an Expression using AST compiler.
|
* Compiles given MST to an Expression using AST compiler.
|
||||||
@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField
|
|||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||||
is MST.Symbolic -> {
|
is Symbolic -> {
|
||||||
val symbol = try {
|
val symbol = try {
|
||||||
algebra.symbol(node.value)
|
algebra.symbol(node.value)
|
||||||
} catch (ignored: IllegalStateException) {
|
} catch (ignored: IllegalStateException) {
|
||||||
@ -33,24 +33,29 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
|||||||
loadVariable(node.value)
|
loadVariable(node.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
is MST.Numeric -> loadNumberConstant(node.value)
|
is Numeric -> loadNumberConstant(node.value)
|
||||||
is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is Unary -> when {
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
|
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||||
algebra.number(
|
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||||
RealField
|
|
||||||
.binaryOperationFunction(node.operation)
|
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
}
|
||||||
)
|
|
||||||
|
is Binary -> when {
|
||||||
|
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
||||||
|
algebra.binaryOperationFunction(node.operation)
|
||||||
|
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||||
)
|
)
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) {
|
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
||||||
|
algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) {
|
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
||||||
|
algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
}
|
}
|
||||||
|
@ -191,7 +191,7 @@ internal class AsmBuilder<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
||||||
java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
// java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
||||||
val l = MethodHandles.publicLookup()
|
val l = MethodHandles.publicLookup()
|
||||||
|
|
||||||
if (hasConstants)
|
if (hasConstants)
|
||||||
|
Loading…
Reference in New Issue
Block a user