Add complete constant folding in kmath-ast by introducing TypedMst, some minor changes

This commit is contained in:
Iaroslav Postovalov 2022-02-05 04:27:10 +07:00 committed by Iaroslav Postovalov
parent ef747f642f
commit 745a7ad66e
16 changed files with 622 additions and 312 deletions

View File

@ -247,6 +247,12 @@ One can still use generic algebras though.
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
<hr/> <hr/>
* ### [kmath-tensorflow](kmath-tensorflow)
>
>
> **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-tensors](kmath-tensors) * ### [kmath-tensors](kmath-tensors)
> >
> >

View File

@ -1,6 +1,6 @@
# Module kmath-ast # Module kmath-ast
Performance and visualization extensions to MST API. Extensions to MST API: transformations, dynamic compilation and visualization.
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser - [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler - [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
@ -35,6 +35,26 @@ dependencies {
} }
``` ```
## Parsing expressions
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
Supported literals:
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`&mdash;all parsed either as `kotlin.Long` or `kotlin.Double`.
Supported binary operators (from the highest precedence to the lowest one):
1. `^`
2. `*`, `/`
3. `+`, `-`
Supported unary operator:
1. `-`, e.&nbsp;g. `-x`
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
1. `sin(x)`
2. `add(x, y)`
## Dynamic expression code generation ## Dynamic expression code generation
### On JVM ### On JVM
@ -42,48 +62,41 @@ dependencies {
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression<T>` with implemented `invoke` function. special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder: For example, the following code:
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.* import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.*
MstField { x + 2 }.compileToExpression(DoubleField) "x+2".parseMath().compileToExpression(ComplexField)
``` ```
... leads to generation of bytecode, which can be decompiled to the following Java class: &mldr; leads to generation of bytecode, which can be decompiled to the following Java class:
```java ```java
package space.kscience.kmath.asm.generated;
import java.util.Map; import java.util.Map;
import kotlin.jvm.functions.Function2; import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics; import space.kscience.kmath.asm.internal.MapIntrinsics;
import space.kscience.kmath.complex.Complex;
import space.kscience.kmath.expressions.Expression; import space.kscience.kmath.expressions.Expression;
import space.kscience.kmath.expressions.Symbol; import space.kscience.kmath.expressions.Symbol;
public final class AsmCompiledExpression_45045_0 implements Expression<Double> { public final class CompiledExpression_45045_0 implements Expression<Complex> {
private final Object[] constants; private final Object[] constants;
public final Double invoke(Map<Symbol, ? extends Double> arguments) { public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2); Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
} return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
public AsmCompiledExpression_45045_0(Object[] constants) {
this.constants = constants;
} }
} }
``` ```
#### Known issues Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class #### Limitations
loading overhead.
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. - The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS ### On JS
@ -129,7 +142,7 @@ An example of emitted Wasm IR in the form of WAT:
) )
``` ```
#### Known issues #### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments. - ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/). - WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).

View File

@ -1,11 +1,31 @@
# Module kmath-ast # Module kmath-ast
Performance and visualization extensions to MST API. Extensions to MST API: transformations, dynamic compilation and visualization.
${features} ${features}
${artifact} ${artifact}
## Parsing expressions
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
Supported literals:
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`&mdash;all parsed either as `kotlin.Long` or `kotlin.Double`.
Supported binary operators (from the highest precedence to the lowest one):
1. `^`
2. `*`, `/`
3. `+`, `-`
Supported unary operator:
1. `-`, e.&nbsp;g. `-x`
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
1. `sin(x)`
2. `add(x, y)`
## Dynamic expression code generation ## Dynamic expression code generation
### On JVM ### On JVM
@ -13,48 +33,66 @@ ${artifact}
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression<T>` with implemented `invoke` function. special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder: For example, the following code:
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.* import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.*
MstField { x + 2 }.compileToExpression(DoubleField) "x^3-x+3".parseMath().compileToExpression(DoubleField)
``` ```
... leads to generation of bytecode, which can be decompiled to the following Java class: &mldr; leads to generation of bytecode, which can be decompiled to the following Java class:
```java ```java
package space.kscience.kmath.asm.generated; import java.util.*;
import kotlin.jvm.functions.*;
import space.kscience.kmath.asm.internal.*;
import space.kscience.kmath.complex.*;
import space.kscience.kmath.expressions.*;
import java.util.Map; public final class CompiledExpression_45045_0 implements Expression<Complex> {
import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics;
import space.kscience.kmath.expressions.Expression;
import space.kscience.kmath.expressions.Symbol;
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
private final Object[] constants; private final Object[] constants;
public final Double invoke(Map<Symbol, ? extends Double> arguments) { public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2); Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
} return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
public AsmCompiledExpression_45045_0(Object[] constants) {
this.constants = constants;
} }
} }
``` ```
#### Known issues For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class ```java
loading overhead. import java.util.*;
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. import space.kscience.kmath.asm.internal.*;
import space.kscience.kmath.expressions.*;
public final class CompiledExpression_-386104628_0 implements DoubleExpression {
private final SymbolIndexer indexer;
public SymbolIndexer getIndexer() {
return this.indexer;
}
public double invoke(double[] arguments) {
double var2 = arguments[0];
return Math.pow(var2, 3.0D) - var2 + 3.0D;
}
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
return Math.pow(var2, 3.0D) - var2 + 3.0D;
}
}
```
Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
#### Limitations
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS ### On JS
@ -100,7 +138,7 @@ An example of emitted Wasm IR in the form of WAT:
) )
``` ```
#### Known issues #### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments. - ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/). - WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).

View File

@ -0,0 +1,177 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
/**
* MST form where all values belong to the type [T]. It is optimal for constant folding, dynamic compilation, etc.
*
* @param T the type.
*/
@UnstableKMathAPI
public sealed interface TypedMst<T> {
/**
* A node containing a unary operation.
*
* @param T the type.
* @property operation The identifier of operation.
* @property function The function implementing this operation.
* @property value The argument of this operation.
*/
public class Unary<T>(public val operation: String, public val function: (T) -> T, public val value: TypedMst<T>) :
TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Unary<*>
if (operation != other.operation) return false
if (value != other.value) return false
return true
}
override fun hashCode(): Int {
var result = operation.hashCode()
result = 31 * result + value.hashCode()
return result
}
override fun toString(): String = "Unary(operation=$operation, value=$value)"
}
/**
* A node containing binary operation.
*
* @param T the type.
* @property operation The identifier of operation.
* @property function The binary function implementing this operation.
* @property left The left operand.
* @property right The right operand.
*/
public class Binary<T>(
public val operation: String,
public val function: Function<T>,
public val left: TypedMst<T>,
public val right: TypedMst<T>,
) : TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Binary<*>
if (operation != other.operation) return false
if (left != other.left) return false
if (right != other.right) return false
return true
}
override fun hashCode(): Int {
var result = operation.hashCode()
result = 31 * result + left.hashCode()
result = 31 * result + right.hashCode()
return result
}
override fun toString(): String = "Binary(operation=$operation, left=$left, right=$right)"
}
/**
* The non-numeric constant value.
*
* @param T the type.
* @property value The held value.
* @property number The number this value corresponds.
*/
public class Constant<T>(public val value: T, public val number: Number?) : TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Constant<*>
if (value != other.value) return false
if (number != other.number) return false
return true
}
override fun hashCode(): Int {
var result = value?.hashCode() ?: 0
result = 31 * result + (number?.hashCode() ?: 0)
return result
}
override fun toString(): String = "Constant(value=$value, number=$number)"
}
/**
* The node containing a variable
*
* @param T the type.
* @property symbol The symbol of the variable.
*/
public class Variable<T>(public val symbol: Symbol) : TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Variable<*>
if (symbol != other.symbol) return false
return true
}
override fun hashCode(): Int = symbol.hashCode()
override fun toString(): String = "Variable(symbol=$symbol)"
}
}
/**
* Interprets the [TypedMst] node with this [Algebra] and [arguments].
*/
@UnstableKMathAPI
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
is TypedMst.Unary -> algebra.unaryOperation(operation, interpret(algebra, arguments))
is TypedMst.Binary -> when {
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null ->
algebra.leftSideNumberOperation(operation, left.number, right.interpret(algebra, arguments))
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null ->
algebra.rightSideNumberOperation(operation, left.interpret(algebra, arguments), right.number)
else -> algebra.binaryOperation(
operation,
left.interpret(algebra, arguments),
right.interpret(algebra, arguments),
)
}
is TypedMst.Constant -> value
is TypedMst.Variable -> arguments.getValue(symbol)
}
/**
* Interprets the [TypedMst] node with this [Algebra] and optional [arguments].
*/
@UnstableKMathAPI
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
algebra,
when (arguments.size) {
0 -> emptyMap()
1 -> mapOf(arguments[0])
else -> hashMapOf(*arguments)
},
)
/**
* Interpret this [TypedMst] node as expression.
*/
@UnstableKMathAPI
public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
interpret(algebra, arguments)
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.bindSymbolOrNull
/**
* Evaluates constants in given [MST] for given [algebra] at the same time with converting to [TypedMst].
*/
@UnstableKMathAPI
public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) {
is MST.Numeric -> TypedMst.Constant(
(algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
value,
)
is MST.Unary -> when (val arg = value.evaluateConstants(algebra)) {
is TypedMst.Constant<T> -> {
val value = algebra.unaryOperation(
operation,
arg.value,
)
TypedMst.Constant(value, if (value is Number) value else null)
}
else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
}
is MST.Binary -> {
val left = left.evaluateConstants(algebra)
val right = right.evaluateConstants(algebra)
when {
left is TypedMst.Constant<T> && right is TypedMst.Constant<T> -> {
val value = when {
algebra is NumericAlgebra && left.number != null -> algebra.leftSideNumberOperation(
operation,
left.number,
right.value,
)
algebra is NumericAlgebra && right.number != null -> algebra.rightSideNumberOperation(
operation,
left.value,
right.number,
)
else -> algebra.binaryOperation(
operation,
left.value,
right.value,
)
}
TypedMst.Constant(value, if (value is Number) value else null)
}
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
operation,
algebra.leftSideNumberOperationFunction(operation),
left,
right,
)
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> TypedMst.Binary(
operation,
algebra.rightSideNumberOperationFunction(operation),
left,
right,
)
else -> TypedMst.Binary(operation, algebra.binaryOperationFunction(operation), left, right)
}
}
is Symbol -> {
val boundSymbol = algebra.bindSymbolOrNull(this)
if (boundSymbol != null)
TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null)
else
TypedMst.Variable(this)
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.pi
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.fail
internal class TestFolding {
@Test
fun foldUnary() = assertEquals(
-1,
("-(1)".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldDeepUnary() = assertEquals(
1,
("-(-(1))".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldBinary() = assertEquals(
2,
("1*2".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldDeepBinary() = assertEquals(
10,
("1*2*5".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldSymbol() = assertEquals(
DoubleField.pi,
("pi".parseMath().evaluateConstants(DoubleField) as? TypedMst.Constant<Double> ?: fail()).value,
)
@Test
fun foldNumeric() = assertEquals(
42.toByte(),
("42".parseMath().evaluateConstants(ByteRing) as? TypedMst.Constant<Byte> ?: fail()).value,
)
}

View File

@ -5,87 +5,48 @@
package space.kscience.kmath.estree package space.kscience.kmath.estree
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.estree.internal.ESTreeBuilder import space.kscience.kmath.estree.internal.ESTreeBuilder
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.internal.estree.BaseExpression import space.kscience.kmath.internal.estree.BaseExpression
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.bindSymbolOrNull
@PublishedApi
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
is Symbol -> {
val symbol = algebra.bindSymbolOrNull(node)
if (symbol != null)
constant(symbol)
else
variable(node.identity)
}
is Numeric -> constant(
(algebra as? NumericAlgebra<T>)?.number(node.value) ?: error("Numeric nodes are not supported by $this")
)
is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> constant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))
)
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
algebra.binaryOperationFunction(node.operation).invoke(
algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value)
)
)
algebra is NumericAlgebra && node.left is Numeric -> call(
algebra.leftSideNumberOperationFunction(node.operation),
visit(node.left),
visit(node.right),
)
algebra is NumericAlgebra && node.right is Numeric -> call(
algebra.rightSideNumberOperationFunction(node.operation),
visit(node.left),
visit(node.right),
)
else -> call(
algebra.binaryOperationFunction(node.operation),
visit(node.left),
visit(node.right),
)
}
}
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
}
/** /**
* Create a compiled expression with given [MST] and given [algebra]. * Create a compiled expression with given [MST] and given [algebra].
*/ */
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra) @OptIn(UnstableKMathAPI::class)
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
val typed = evaluateConstants(algebra)
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) {
is TypedMst.Constant -> constant(node.value)
is TypedMst.Variable -> variable(node.symbol)
is TypedMst.Unary -> call(node.function, visit(node.value))
is TypedMst.Binary -> call(
node.function,
visit(node.left),
visit(node.right),
)
}
return ESTreeBuilder<T> { visit(typed) }.instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = public fun <T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = public fun <T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)

View File

@ -61,7 +61,7 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
} }
} }
fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name)) fun variable(name: Symbol): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name.identity))
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression( fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
optional = false, optional = false,

View File

@ -5,8 +5,8 @@
package space.kscience.kmath.wasm.internal package space.kscience.kmath.wasm.internal
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.internal.binaryen.* import space.kscience.kmath.internal.binaryen.*
import space.kscience.kmath.internal.webassembly.Instance import space.kscience.kmath.internal.webassembly.Instance
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
@ -16,11 +16,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
private val spreader = eval("(obj, args) => obj(...args)") private val spreader = eval("(obj, args) => obj(...args)")
@OptIn(UnstableKMathAPI::class)
@Suppress("UnsafeCastFromDynamic") @Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder<T : Number, out E : Expression<T>>( internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
protected val binaryenType: Type, protected val binaryenType: Type,
protected val algebra: Algebra<T>, protected val algebra: Algebra<T>,
protected val target: MST, protected val target: TypedMst<T>,
) { ) {
protected val keys: MutableList<Symbol> = mutableListOf() protected val keys: MutableList<Symbol> = mutableListOf()
protected lateinit var ctx: BinaryenModule protected lateinit var ctx: BinaryenModule
@ -51,59 +52,41 @@ internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
Instance(c, js("{}")).exports.executable Instance(c, js("{}")).exports.executable
} }
protected open fun visitSymbol(node: Symbol): ExpressionRef { protected abstract fun visitNumber(number: Number): ExpressionRef
algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
var idx = keys.indexOf(node) protected open fun visitVariable(node: TypedMst.Variable<T>): ExpressionRef {
var idx = keys.indexOf(node.symbol)
if (idx == -1) { if (idx == -1) {
keys += node keys += node.symbol
idx = keys.lastIndex idx = keys.lastIndex
} }
return ctx.local.get(idx, binaryenType) return ctx.local.get(idx, binaryenType)
} }
protected abstract fun visitNumeric(node: Numeric): ExpressionRef protected open fun visitUnary(node: TypedMst.Unary<T>): ExpressionRef =
protected open fun visitUnary(node: Unary): ExpressionRef =
error("Unary operation ${node.operation} not defined in $this") error("Unary operation ${node.operation} not defined in $this")
protected open fun visitBinary(mst: Binary): ExpressionRef = protected open fun visitBinary(mst: TypedMst.Binary<T>): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this") error("Binary operation ${mst.operation} not defined in $this")
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
protected fun visit(node: MST): ExpressionRef = when (node) { protected fun visit(node: TypedMst<T>): ExpressionRef = when (node) {
is Symbol -> visitSymbol(node) is TypedMst.Constant -> visitNumber(
is Numeric -> visitNumeric(node) node.number ?: error("Object constants are not supported by pritimive ASM builder"),
is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> visitNumeric(
Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
) )
else -> visitUnary(node) is TypedMst.Variable -> visitVariable(node)
} is TypedMst.Unary -> visitUnary(node)
is TypedMst.Binary -> visitBinary(node)
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
Numeric(
algebra.binaryOperationFunction(node.operation)
.invoke(
algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value)
)
)
)
else -> visitBinary(node)
}
} }
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) { internal class DoubleWasmBuilder(target: TypedMst<Double>) :
WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
override val instance by lazy { override val instance by lazy {
object : DoubleExpression { object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(keys) override val indexer = SimpleSymbolIndexer(keys)
@ -114,9 +97,9 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
override fun createModule() = readBinary(f64StandardFunctions) override fun createModule() = readBinary(f64StandardFunctions)
override fun visitNumeric(node: Numeric) = ctx.f64.const(node.value.toDouble()) override fun visitNumber(number: Number) = ctx.f64.const(number.toDouble())
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) { override fun visitUnary(node: TypedMst.Unary<Double>): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value)) GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value) GroupOps.PLUS_OPERATION -> visit(node.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value)) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
@ -137,7 +120,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
else -> super.visitUnary(node) else -> super.visitUnary(node)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: TypedMst.Binary<Double>): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
@ -148,7 +131,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) { internal class IntWasmBuilder(target: TypedMst<Int>) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
override val instance by lazy { override val instance by lazy {
object : IntExpression { object : IntExpression {
override val indexer = SimpleSymbolIndexer(keys) override val indexer = SimpleSymbolIndexer(keys)
@ -157,15 +140,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32
} }
} }
override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value.toInt()) override fun visitNumber(number: Number) = ctx.i32.const(number.toInt())
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) { override fun visitUnary(node: TypedMst.Unary<Int>): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value) GroupOps.PLUS_OPERATION -> visit(node.value)
else -> super.visitUnary(node) else -> super.visitUnary(node)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: TypedMst.Binary<Int>): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))

View File

@ -7,7 +7,8 @@
package space.kscience.kmath.wasm package space.kscience.kmath.wasm
import space.kscience.kmath.estree.compileWith import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
@ -21,8 +22,16 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance public fun MST.compileToExpression(algebra: IntRing): IntExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : IntExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: IntArray): Int = typed.value
} else
IntWasmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -31,7 +40,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBui
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
@ -49,7 +58,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleWasmBuilder(this).instance public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: DoubleArray): Double = typed.value
} else
DoubleWasmBuilder(typed).instance
}
/** /**
@ -59,7 +77,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = D
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
@ -69,4 +87,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)

View File

@ -8,10 +8,14 @@
package space.kscience.kmath.asm package space.kscience.kmath.asm
import space.kscience.kmath.asm.internal.* import space.kscience.kmath.asm.internal.*
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.LongRing
/** /**
* Compiles given MST to an Expression using AST compiler. * Compiles given MST to an Expression using AST compiler.
@ -21,102 +25,64 @@ import space.kscience.kmath.operations.*
* @return the compiled expression. * @return the compiled expression.
* @author Alexander Nozik * @author Alexander Nozik
*/ */
@OptIn(UnstableKMathAPI::class)
@PublishedApi @PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> { internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun GenericAsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) { val typed = evaluateConstants(algebra)
is Symbol -> prepareVariable(node.identity) if (typed is TypedMst.Constant<T>) return Expression { typed.value }
is Unary -> variablesVisitor(node.value)
is Binary -> { fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) {
is TypedMst.Unary -> variablesVisitor(node.value)
is TypedMst.Binary -> {
variablesVisitor(node.left) variablesVisitor(node.left)
variablesVisitor(node.right) variablesVisitor(node.right)
} }
else -> Unit is TypedMst.Variable -> prepareVariable(node.symbol)
is TypedMst.Constant -> Unit
} }
fun GenericAsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) { fun GenericAsmBuilder<T>.expressionVisitor(node: TypedMst<T>): Unit = when (node) {
is Symbol -> { is TypedMst.Constant -> if (node.number != null)
val symbol = algebra.bindSymbolOrNull(node) loadNumberConstant(node.number)
if (symbol != null)
loadObjectConstant(symbol as Any)
else else
loadVariable(node.identity) loadObjectConstant(node.value)
}
is Numeric -> if (algebra is NumericAlgebra) { is TypedMst.Variable -> loadVariable(node.symbol)
if (Number::class.java.isAssignableFrom(type)) is TypedMst.Unary -> buildCall(node.function) { expressionVisitor(node.value) }
loadNumberConstant(algebra.number(node.value) as Number)
else
loadObjectConstant(algebra.number(node.value))
} else
error("Numeric nodes are not supported by $this")
is Unary -> when { is TypedMst.Binary -> buildCall(node.function) {
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
)
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
algebra.binaryOperationFunction(node.operation).invoke(
algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value),
)
)
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
algebra.leftSideNumberOperationFunction(node.operation),
) {
expressionVisitor(node.left) expressionVisitor(node.left)
expressionVisitor(node.right) expressionVisitor(node.right)
} }
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
algebra.rightSideNumberOperationFunction(node.operation),
) {
expressionVisitor(node.left)
expressionVisitor(node.right)
}
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
expressionVisitor(node.left)
expressionVisitor(node.right)
}
}
} }
return GenericAsmBuilder<T>( return GenericAsmBuilder<T>(
type, type,
buildName(this), buildName("${typed.hashCode()}_${type.simpleName}"),
{ variablesVisitor(this@compileWith) }, { variablesVisitor(typed) },
{ expressionVisitor(this@compileWith) }, { expressionVisitor(typed) },
).instance ).instance
} }
/** /**
* Create a compiled expression with given [MST] and given [algebra]. * Create a compiled expression with given [MST] and given [algebra].
*/ */
public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
compileWith(T::class.java, algebra) compileWith(T::class.java, algebra)
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)
/** /**
@ -125,7 +91,16 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg argu
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance public fun MST.compileToExpression(algebra: IntRing): IntExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : IntExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: IntArray): Int = typed.value
} else
IntAsmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -134,7 +109,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuil
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -152,8 +127,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance public fun MST.compileToExpression(algebra: LongRing): LongExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant<Long>) object : LongExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: LongArray): Long = typed.value
} else
LongAsmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -162,7 +145,7 @@ public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmB
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long = public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
@ -181,7 +164,17 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>):
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: DoubleArray): Double = typed.value
} else
DoubleAsmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -190,7 +183,7 @@ public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = Dou
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -199,4 +192,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)

View File

@ -56,7 +56,7 @@ internal class GenericAsmBuilder<T>(
/** /**
* Local variables indices are indices of symbols in this list. * Local variables indices are indices of symbols in this list.
*/ */
private val argumentsLocals = mutableListOf<String>() private val argumentsLocals = mutableListOf<Symbol>()
/** /**
* Subclasses, loads and instantiates [Expression] for given parameters. * Subclasses, loads and instantiates [Expression] for given parameters.
@ -253,10 +253,10 @@ internal class GenericAsmBuilder<T>(
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable]. * [loadVariable].
*/ */
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { fun prepareVariable(name: Symbol): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run if (name in argumentsLocals) return@run
load(1, MAP_TYPE) load(1, MAP_TYPE)
aconst(name) aconst(name.identity)
invokestatic( invokestatic(
MAP_INTRINSICS_TYPE.internalName, MAP_INTRINSICS_TYPE.internalName,
@ -280,7 +280,7 @@ internal class GenericAsmBuilder<T>(
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first. * with [prepareVariable] first.
*/ */
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType) fun loadVariable(name: Symbol): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) { inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }

View File

@ -11,6 +11,7 @@ import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type import org.objectweb.asm.Type
import org.objectweb.asm.Type.* import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
@ -25,9 +26,9 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
classOfT: Class<*>, classOfT: Class<*>,
protected val classOfTPrimitive: Class<*>, protected val classOfTPrimitive: Class<*>,
expressionParent: Class<E>, expressionParent: Class<E>,
protected val target: MST, protected val target: TypedMst<T>,
) : AsmBuilder() { ) : AsmBuilder() {
private val className: String = buildName(target) private val className: String = buildName("${target.hashCode()}_${classOfT.simpleName}")
/** /**
* ASM type for [tType]. * ASM type for [tType].
@ -329,63 +330,39 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
} }
private fun visitVariables( private fun visitVariables(
node: MST, node: TypedMst<T>,
arrayMode: Boolean, arrayMode: Boolean,
alreadyLoaded: MutableList<Symbol> = mutableListOf() alreadyLoaded: MutableList<Symbol> = mutableListOf()
): Unit = when (node) { ): Unit = when (node) {
is Symbol -> when (node) { is TypedMst.Variable -> if (node.symbol !in alreadyLoaded) {
!in alreadyLoaded -> { alreadyLoaded += node.symbol
alreadyLoaded += node prepareVariable(node.symbol, arrayMode)
prepareVariable(node, arrayMode) } else Unit
}
else -> {
}
}
is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded) is TypedMst.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
is MST.Binary -> { is TypedMst.Binary -> {
visitVariables(node.left, arrayMode, alreadyLoaded) visitVariables(node.left, arrayMode, alreadyLoaded)
visitVariables(node.right, arrayMode, alreadyLoaded) visitVariables(node.right, arrayMode, alreadyLoaded)
} }
else -> Unit is TypedMst.Constant -> Unit
} }
private fun visitExpression(node: MST): Unit = when (node) { private fun visitExpression(node: TypedMst<T>): Unit = when (node) {
is Symbol -> { is TypedMst.Variable -> loadVariable(node.symbol)
val symbol = algebra.bindSymbolOrNull(node)
if (symbol != null) is TypedMst.Constant -> loadNumberConstant(
loadNumberConstant(symbol) node.number ?: error("Object constants are not supported by pritimive ASM builder"),
else
loadVariable(node)
}
is MST.Numeric -> loadNumberConstant(algebra.number(node.value))
is MST.Unary -> if (node.value is MST.Numeric)
loadNumberConstant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
)
else
visitUnary(node)
is MST.Binary -> when {
node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
algebra.binaryOperationFunction(node.operation)(
algebra.number((node.left as MST.Numeric).value),
algebra.number((node.right as MST.Numeric).value),
),
) )
else -> visitBinary(node) is TypedMst.Unary -> visitUnary(node)
} is TypedMst.Binary -> visitBinary(node)
} }
protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value) protected open fun visitUnary(node: TypedMst.Unary<T>) = visitExpression(node.value)
protected open fun visitBinary(node: MST.Binary) { protected open fun visitBinary(node: TypedMst.Binary<T>) {
visitExpression(node.left) visitExpression(node.left)
visitExpression(node.right) visitExpression(node.right)
} }
@ -404,14 +381,13 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, DoubleExpression>( internal class DoubleAsmBuilder(target: TypedMst<Double>) : PrimitiveAsmBuilder<Double, DoubleExpression>(
DoubleField, DoubleField,
java.lang.Double::class.java, java.lang.Double::class.java,
java.lang.Double.TYPE, java.lang.Double.TYPE,
DoubleExpression::class.java, DoubleExpression::class.java,
target, target,
) { ) {
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic( private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
MATH_TYPE.internalName, MATH_TYPE.internalName,
name, name,
@ -434,7 +410,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
false, false,
) )
override fun visitUnary(node: MST.Unary) { override fun visitUnary(node: TypedMst.Unary<Double>) {
super.visitUnary(node) super.visitUnary(node)
when (node.operation) { when (node.operation) {
@ -459,7 +435,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
} }
} }
override fun visitBinary(node: MST.Binary) { override fun visitBinary(node: TypedMst.Binary<Double>) {
super.visitBinary(node) super.visitBinary(node)
when (node.operation) { when (node.operation) {
@ -479,7 +455,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class IntAsmBuilder(target: MST) : internal class IntAsmBuilder(target: TypedMst<Int>) :
PrimitiveAsmBuilder<Int, IntExpression>( PrimitiveAsmBuilder<Int, IntExpression>(
IntRing, IntRing,
Integer::class.java, Integer::class.java,
@ -487,7 +463,7 @@ internal class IntAsmBuilder(target: MST) :
IntExpression::class.java, IntExpression::class.java,
target target
) { ) {
override fun visitUnary(node: MST.Unary) { override fun visitUnary(node: TypedMst.Unary<Int>) {
super.visitUnary(node) super.visitUnary(node)
when (node.operation) { when (node.operation) {
@ -497,7 +473,7 @@ internal class IntAsmBuilder(target: MST) :
} }
} }
override fun visitBinary(node: MST.Binary) { override fun visitBinary(node: TypedMst.Binary<Int>) {
super.visitBinary(node) super.visitBinary(node)
when (node.operation) { when (node.operation) {
@ -510,14 +486,14 @@ internal class IntAsmBuilder(target: MST) :
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpression>( internal class LongAsmBuilder(target: TypedMst<Long>) : PrimitiveAsmBuilder<Long, LongExpression>(
LongRing, LongRing,
java.lang.Long::class.java, java.lang.Long::class.java,
java.lang.Long.TYPE, java.lang.Long.TYPE,
LongExpression::class.java, LongExpression::class.java,
target, target,
) { ) {
override fun visitUnary(node: MST.Unary) { override fun visitUnary(node: TypedMst.Unary<Long>) {
super.visitUnary(node) super.visitUnary(node)
when (node.operation) { when (node.operation) {
@ -527,7 +503,7 @@ internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpre
} }
} }
override fun visitBinary(node: MST.Binary) { override fun visitBinary(node: TypedMst.Binary<Long>) {
super.visitBinary(node) super.visitBinary(node)
when (node.operation) { when (node.operation) {

View File

@ -55,15 +55,15 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel) internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
/** /**
* Creates a class name for [Expression] subclassed to implement [mst] provided. * Creates a class name for [Expression] based with appending [marker] to reduce collisions.
* *
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there * These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
internal tailrec fun buildName(mst: MST, collision: Int = 0): String { internal tailrec fun buildName(marker: String, collision: Int = 0): String {
val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision" val name = "space.kscience.kmath.asm.generated.CompiledExpression_${marker}_$collision"
try { try {
Class.forName(name) Class.forName(name)
@ -71,7 +71,7 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
return name return name
} }
return buildName(mst, collision + 1) return buildName(marker, collision + 1)
} }
@Suppress("FunctionName") @Suppress("FunctionName")

View File

@ -34,12 +34,12 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> = override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
{ left, right -> { left, right ->
Expression { arguments -> Expression { arguments ->
algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments)) algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
} }
} }
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg -> override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) } Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
} }
} }

View File

@ -24,7 +24,7 @@ public sealed interface MST {
public data class Numeric(val value: Number) : MST public data class Numeric(val value: Number) : MST
/** /**
* A node containing an unary operation. * A node containing a unary operation.
* *
* @property operation the identifier of operation. * @property operation the identifier of operation.
* @property value the argument of this operation. * @property value the argument of this operation.
@ -34,7 +34,7 @@ public sealed interface MST {
/** /**
* A node containing binary operation. * A node containing binary operation.
* *
* @property operation the identifier operation. * @property operation the identifier of operation.
* @property left the left operand. * @property left the left operand.
* @property right the right operand. * @property right the right operand.
*/ */