diff --git a/README.md b/README.md
index 92260716e..99dd6d00f 100644
--- a/README.md
+++ b/README.md
@@ -247,6 +247,12 @@ One can still use generic algebras though.
> **Maturity**: PROTOTYPE
+* ### [kmath-tensorflow](kmath-tensorflow)
+>
+>
+> **Maturity**: PROTOTYPE
+
+
* ### [kmath-tensors](kmath-tensors)
>
>
diff --git a/kmath-ast/README.md b/kmath-ast/README.md
index 5e3366881..bedf17486 100644
--- a/kmath-ast/README.md
+++ b/kmath-ast/README.md
@@ -1,6 +1,6 @@
# Module kmath-ast
-Performance and visualization extensions to MST API.
+Extensions to MST API: transformations, dynamic compilation and visualization.
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
@@ -35,6 +35,26 @@ dependencies {
}
```
+## Parsing expressions
+
+In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
+
+Supported literals:
+1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
+2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
+
+Supported binary operators (from the highest precedence to the lowest one):
+1. `^`
+2. `*`, `/`
+3. `+`, `-`
+
+Supported unary operator:
+1. `-`, e. g. `-x`
+
+Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
+1. `sin(x)`
+2. `add(x, y)`
+
## Dynamic expression code generation
### On JVM
@@ -42,48 +62,41 @@ dependencies {
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression` with implemented `invoke` function.
-For example, the following builder:
+For example, the following code:
```kotlin
-import space.kscience.kmath.expressions.Symbol.Companion.x
-import space.kscience.kmath.expressions.*
-import space.kscience.kmath.operations.*
-import space.kscience.kmath.asm.*
+import space.kscience.kmath.asm.compileToExpression
+import space.kscience.kmath.complex.ComplexField
-MstField { x + 2 }.compileToExpression(DoubleField)
-```
+"x+2".parseMath().compileToExpression(ComplexField)
+```
-... leads to generation of bytecode, which can be decompiled to the following Java class:
+… leads to generation of bytecode, which can be decompiled to the following Java class:
```java
-package space.kscience.kmath.asm.generated;
-
import java.util.Map;
-
import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics;
+import space.kscience.kmath.complex.Complex;
import space.kscience.kmath.expressions.Expression;
import space.kscience.kmath.expressions.Symbol;
-public final class AsmCompiledExpression_45045_0 implements Expression {
+public final class CompiledExpression_45045_0 implements Expression {
private final Object[] constants;
- public final Double invoke(Map arguments) {
- return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
- }
-
- public AsmCompiledExpression_45045_0(Object[] constants) {
- this.constants = constants;
+ public Complex invoke(Map arguments) {
+ Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
+ return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
}
}
-
```
-#### Known issues
+Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
-- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
- loading overhead.
-- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
+#### Limitations
+
+- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
+- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS
@@ -129,7 +142,7 @@ An example of emitted Wasm IR in the form of WAT:
)
```
-#### Known issues
+#### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
diff --git a/kmath-ast/docs/README-TEMPLATE.md b/kmath-ast/docs/README-TEMPLATE.md
index 9494af63a..e9e22f4d4 100644
--- a/kmath-ast/docs/README-TEMPLATE.md
+++ b/kmath-ast/docs/README-TEMPLATE.md
@@ -1,11 +1,31 @@
# Module kmath-ast
-Performance and visualization extensions to MST API.
+Extensions to MST API: transformations, dynamic compilation and visualization.
${features}
${artifact}
+## Parsing expressions
+
+In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
+
+Supported literals:
+1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
+2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
+
+Supported binary operators (from the highest precedence to the lowest one):
+1. `^`
+2. `*`, `/`
+3. `+`, `-`
+
+Supported unary operator:
+1. `-`, e. g. `-x`
+
+Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
+1. `sin(x)`
+2. `add(x, y)`
+
## Dynamic expression code generation
### On JVM
@@ -13,48 +33,66 @@ ${artifact}
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression` with implemented `invoke` function.
-For example, the following builder:
+For example, the following code:
```kotlin
-import space.kscience.kmath.expressions.Symbol.Companion.x
-import space.kscience.kmath.expressions.*
-import space.kscience.kmath.operations.*
-import space.kscience.kmath.asm.*
-
-MstField { x + 2 }.compileToExpression(DoubleField)
-```
-
-... leads to generation of bytecode, which can be decompiled to the following Java class:
-
-```java
-package space.kscience.kmath.asm.generated;
-
-import java.util.Map;
-
-import kotlin.jvm.functions.Function2;
-import space.kscience.kmath.asm.internal.MapIntrinsics;
-import space.kscience.kmath.expressions.Expression;
-import space.kscience.kmath.expressions.Symbol;
-
-public final class AsmCompiledExpression_45045_0 implements Expression {
- private final Object[] constants;
-
- public final Double invoke(Map arguments) {
- return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
- }
-
- public AsmCompiledExpression_45045_0(Object[] constants) {
- this.constants = constants;
- }
-}
+import space.kscience.kmath.asm.compileToExpression
+import space.kscience.kmath.operations.DoubleField
+"x^3-x+3".parseMath().compileToExpression(DoubleField)
```
-#### Known issues
+… leads to generation of bytecode, which can be decompiled to the following Java class:
-- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
- loading overhead.
-- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
+```java
+import java.util.*;
+import kotlin.jvm.functions.*;
+import space.kscience.kmath.asm.internal.*;
+import space.kscience.kmath.complex.*;
+import space.kscience.kmath.expressions.*;
+
+public final class CompiledExpression_45045_0 implements Expression {
+ private final Object[] constants;
+
+ public Complex invoke(Map arguments) {
+ Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
+ return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
+ }
+}
+```
+
+For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
+
+```java
+import java.util.*;
+import space.kscience.kmath.asm.internal.*;
+import space.kscience.kmath.expressions.*;
+
+public final class CompiledExpression_-386104628_0 implements DoubleExpression {
+ private final SymbolIndexer indexer;
+
+ public SymbolIndexer getIndexer() {
+ return this.indexer;
+ }
+
+ public double invoke(double[] arguments) {
+ double var2 = arguments[0];
+ return Math.pow(var2, 3.0D) - var2 + 3.0D;
+ }
+
+ public final Double invoke(Map arguments) {
+ double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
+ return Math.pow(var2, 3.0D) - var2 + 3.0D;
+ }
+}
+```
+
+Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
+
+#### Limitations
+
+- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
+- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS
@@ -100,7 +138,7 @@ An example of emitted Wasm IR in the form of WAT:
)
```
-#### Known issues
+#### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt
new file mode 100644
index 000000000..8a8b8797d
--- /dev/null
+++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.ast
+
+import space.kscience.kmath.expressions.Expression
+import space.kscience.kmath.expressions.Symbol
+import space.kscience.kmath.misc.UnstableKMathAPI
+import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.operations.NumericAlgebra
+
+/**
+ * MST form where all values belong to the type [T]. It is optimal for constant folding, dynamic compilation, etc.
+ *
+ * @param T the type.
+ */
+@UnstableKMathAPI
+public sealed interface TypedMst {
+ /**
+ * A node containing a unary operation.
+ *
+ * @param T the type.
+ * @property operation The identifier of operation.
+ * @property function The function implementing this operation.
+ * @property value The argument of this operation.
+ */
+ public class Unary(public val operation: String, public val function: (T) -> T, public val value: TypedMst) :
+ TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+ other as Unary<*>
+ if (operation != other.operation) return false
+ if (value != other.value) return false
+ return true
+ }
+
+ override fun hashCode(): Int {
+ var result = operation.hashCode()
+ result = 31 * result + value.hashCode()
+ return result
+ }
+
+ override fun toString(): String = "Unary(operation=$operation, value=$value)"
+ }
+
+ /**
+ * A node containing binary operation.
+ *
+ * @param T the type.
+ * @property operation The identifier of operation.
+ * @property function The binary function implementing this operation.
+ * @property left The left operand.
+ * @property right The right operand.
+ */
+ public class Binary(
+ public val operation: String,
+ public val function: Function,
+ public val left: TypedMst,
+ public val right: TypedMst,
+ ) : TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+
+ other as Binary<*>
+
+ if (operation != other.operation) return false
+ if (left != other.left) return false
+ if (right != other.right) return false
+
+ return true
+ }
+
+ override fun hashCode(): Int {
+ var result = operation.hashCode()
+ result = 31 * result + left.hashCode()
+ result = 31 * result + right.hashCode()
+ return result
+ }
+
+ override fun toString(): String = "Binary(operation=$operation, left=$left, right=$right)"
+ }
+
+ /**
+ * The non-numeric constant value.
+ *
+ * @param T the type.
+ * @property value The held value.
+ * @property number The number this value corresponds.
+ */
+ public class Constant(public val value: T, public val number: Number?) : TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+ other as Constant<*>
+ if (value != other.value) return false
+ if (number != other.number) return false
+ return true
+ }
+
+ override fun hashCode(): Int {
+ var result = value?.hashCode() ?: 0
+ result = 31 * result + (number?.hashCode() ?: 0)
+ return result
+ }
+
+ override fun toString(): String = "Constant(value=$value, number=$number)"
+ }
+
+ /**
+ * The node containing a variable
+ *
+ * @param T the type.
+ * @property symbol The symbol of the variable.
+ */
+ public class Variable(public val symbol: Symbol) : TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+ other as Variable<*>
+ if (symbol != other.symbol) return false
+ return true
+ }
+
+ override fun hashCode(): Int = symbol.hashCode()
+ override fun toString(): String = "Variable(symbol=$symbol)"
+ }
+}
+
+/**
+ * Interprets the [TypedMst] node with this [Algebra] and [arguments].
+ */
+@UnstableKMathAPI
+public fun TypedMst.interpret(algebra: Algebra, arguments: Map): T = when (this) {
+ is TypedMst.Unary -> algebra.unaryOperation(operation, interpret(algebra, arguments))
+
+ is TypedMst.Binary -> when {
+ algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null ->
+ algebra.leftSideNumberOperation(operation, left.number, right.interpret(algebra, arguments))
+
+ algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null ->
+ algebra.rightSideNumberOperation(operation, left.interpret(algebra, arguments), right.number)
+
+ else -> algebra.binaryOperation(
+ operation,
+ left.interpret(algebra, arguments),
+ right.interpret(algebra, arguments),
+ )
+ }
+
+ is TypedMst.Constant -> value
+ is TypedMst.Variable -> arguments.getValue(symbol)
+}
+
+/**
+ * Interprets the [TypedMst] node with this [Algebra] and optional [arguments].
+ */
+@UnstableKMathAPI
+public fun TypedMst.interpret(algebra: Algebra, vararg arguments: Pair): T = interpret(
+ algebra,
+ when (arguments.size) {
+ 0 -> emptyMap()
+ 1 -> mapOf(arguments[0])
+ else -> hashMapOf(*arguments)
+ },
+)
+
+/**
+ * Interpret this [TypedMst] node as expression.
+ */
+@UnstableKMathAPI
+public fun TypedMst.toExpression(algebra: Algebra): Expression = Expression { arguments ->
+ interpret(algebra, arguments)
+}
diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt
new file mode 100644
index 000000000..71fb154c9
--- /dev/null
+++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.ast
+
+import space.kscience.kmath.expressions.MST
+import space.kscience.kmath.expressions.Symbol
+import space.kscience.kmath.misc.UnstableKMathAPI
+import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.operations.NumericAlgebra
+import space.kscience.kmath.operations.bindSymbolOrNull
+
+/**
+ * Evaluates constants in given [MST] for given [algebra] at the same time with converting to [TypedMst].
+ */
+@UnstableKMathAPI
+public fun MST.evaluateConstants(algebra: Algebra): TypedMst = when (this) {
+ is MST.Numeric -> TypedMst.Constant(
+ (algebra as? NumericAlgebra)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
+ value,
+ )
+
+ is MST.Unary -> when (val arg = value.evaluateConstants(algebra)) {
+ is TypedMst.Constant -> {
+ val value = algebra.unaryOperation(
+ operation,
+ arg.value,
+ )
+
+ TypedMst.Constant(value, if (value is Number) value else null)
+ }
+
+ else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
+ }
+
+ is MST.Binary -> {
+ val left = left.evaluateConstants(algebra)
+ val right = right.evaluateConstants(algebra)
+
+ when {
+ left is TypedMst.Constant && right is TypedMst.Constant -> {
+ val value = when {
+ algebra is NumericAlgebra && left.number != null -> algebra.leftSideNumberOperation(
+ operation,
+ left.number,
+ right.value,
+ )
+
+ algebra is NumericAlgebra && right.number != null -> algebra.rightSideNumberOperation(
+ operation,
+ left.value,
+ right.number,
+ )
+
+ else -> algebra.binaryOperation(
+ operation,
+ left.value,
+ right.value,
+ )
+ }
+
+ TypedMst.Constant(value, if (value is Number) value else null)
+ }
+
+ algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
+ operation,
+ algebra.leftSideNumberOperationFunction(operation),
+ left,
+ right,
+ )
+
+ algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> TypedMst.Binary(
+ operation,
+ algebra.rightSideNumberOperationFunction(operation),
+ left,
+ right,
+ )
+
+ else -> TypedMst.Binary(operation, algebra.binaryOperationFunction(operation), left, right)
+ }
+ }
+
+ is Symbol -> {
+ val boundSymbol = algebra.bindSymbolOrNull(this)
+
+ if (boundSymbol != null)
+ TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null)
+ else
+ TypedMst.Variable(this)
+ }
+}
diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt
new file mode 100644
index 000000000..954a0f330
--- /dev/null
+++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.ast
+
+import space.kscience.kmath.operations.ByteRing
+import space.kscience.kmath.operations.DoubleField
+import space.kscience.kmath.operations.IntRing
+import space.kscience.kmath.operations.pi
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.fail
+
+internal class TestFolding {
+ @Test
+ fun foldUnary() = assertEquals(
+ -1,
+ ("-(1)".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldDeepUnary() = assertEquals(
+ 1,
+ ("-(-(1))".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldBinary() = assertEquals(
+ 2,
+ ("1*2".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldDeepBinary() = assertEquals(
+ 10,
+ ("1*2*5".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldSymbol() = assertEquals(
+ DoubleField.pi,
+ ("pi".parseMath().evaluateConstants(DoubleField) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldNumeric() = assertEquals(
+ 42.toByte(),
+ ("42".parseMath().evaluateConstants(ByteRing) as? TypedMst.Constant ?: fail()).value,
+ )
+}
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt
index a6b6e022b..a8b1aa2e1 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt
@@ -5,87 +5,48 @@
package space.kscience.kmath.estree
+import space.kscience.kmath.ast.TypedMst
+import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.estree.internal.ESTreeBuilder
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
-import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.internal.estree.BaseExpression
+import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra
-import space.kscience.kmath.operations.NumericAlgebra
-import space.kscience.kmath.operations.bindSymbolOrNull
-
-@PublishedApi
-internal fun MST.compileWith(algebra: Algebra): Expression {
- fun ESTreeBuilder.visit(node: MST): BaseExpression = when (node) {
- is Symbol -> {
- val symbol = algebra.bindSymbolOrNull(node)
-
- if (symbol != null)
- constant(symbol)
- else
- variable(node.identity)
- }
-
- is Numeric -> constant(
- (algebra as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this")
- )
-
- is Unary -> when {
- algebra is NumericAlgebra && node.value is Numeric -> constant(
- algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))
- )
-
- else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
- }
-
- is Binary -> when {
- algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
- algebra.binaryOperationFunction(node.operation).invoke(
- algebra.number((node.left as Numeric).value),
- algebra.number((node.right as Numeric).value)
- )
- )
-
- algebra is NumericAlgebra && node.left is Numeric -> call(
- algebra.leftSideNumberOperationFunction(node.operation),
- visit(node.left),
- visit(node.right),
- )
-
- algebra is NumericAlgebra && node.right is Numeric -> call(
- algebra.rightSideNumberOperationFunction(node.operation),
- visit(node.left),
- visit(node.right),
- )
-
- else -> call(
- algebra.binaryOperationFunction(node.operation),
- visit(node.left),
- visit(node.right),
- )
- }
- }
-
- return ESTreeBuilder { visit(this@compileWith) }.instance
-}
/**
* Create a compiled expression with given [MST] and given [algebra].
*/
-public fun MST.compileToExpression(algebra: Algebra): Expression = compileWith(algebra)
+@OptIn(UnstableKMathAPI::class)
+public fun MST.compileToExpression(algebra: Algebra): Expression {
+ val typed = evaluateConstants(algebra)
+ if (typed is TypedMst.Constant) return Expression { typed.value }
+ fun ESTreeBuilder.visit(node: TypedMst): BaseExpression = when (node) {
+ is TypedMst.Constant -> constant(node.value)
+ is TypedMst.Variable -> variable(node.symbol)
+ is TypedMst.Unary -> call(node.function, visit(node.value))
+
+ is TypedMst.Binary -> call(
+ node.function,
+ visit(node.left),
+ visit(node.right),
+ )
+ }
+
+ return ESTreeBuilder { visit(typed) }.instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
-public inline fun MST.compile(algebra: Algebra, arguments: Map): T =
- compileToExpression(algebra).invoke(arguments)
-
+public fun MST.compile(algebra: Algebra, arguments: Map): T =
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
-public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T =
- compileToExpression(algebra).invoke(*arguments)
+public fun MST.compile(algebra: Algebra, vararg arguments: Pair): T =
+ compileToExpression(algebra)(*arguments)
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt
index 4907d8225..10a6c4a16 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt
@@ -61,7 +61,7 @@ internal class ESTreeBuilder(val bodyCallback: ESTreeBuilder.() -> BaseExp
}
}
- fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name))
+ fun variable(name: Symbol): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name.identity))
fun call(function: Function, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
optional = false,
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt
index 96090a633..aacb62f36 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt
@@ -5,8 +5,8 @@
package space.kscience.kmath.wasm.internal
+import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.*
-import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.internal.binaryen.*
import space.kscience.kmath.internal.webassembly.Instance
import space.kscience.kmath.misc.UnstableKMathAPI
@@ -16,11 +16,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
private val spreader = eval("(obj, args) => obj(...args)")
+@OptIn(UnstableKMathAPI::class)
@Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder>(
protected val binaryenType: Type,
protected val algebra: Algebra,
- protected val target: MST,
+ protected val target: TypedMst,
) {
protected val keys: MutableList = mutableListOf()
protected lateinit var ctx: BinaryenModule
@@ -51,59 +52,41 @@ internal sealed class WasmBuilder>(
Instance(c, js("{}")).exports.executable
}
- protected open fun visitSymbol(node: Symbol): ExpressionRef {
- algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
+ protected abstract fun visitNumber(number: Number): ExpressionRef
- var idx = keys.indexOf(node)
+ protected open fun visitVariable(node: TypedMst.Variable): ExpressionRef {
+ var idx = keys.indexOf(node.symbol)
if (idx == -1) {
- keys += node
+ keys += node.symbol
idx = keys.lastIndex
}
return ctx.local.get(idx, binaryenType)
}
- protected abstract fun visitNumeric(node: Numeric): ExpressionRef
-
- protected open fun visitUnary(node: Unary): ExpressionRef =
+ protected open fun visitUnary(node: TypedMst.Unary): ExpressionRef =
error("Unary operation ${node.operation} not defined in $this")
- protected open fun visitBinary(mst: Binary): ExpressionRef =
+ protected open fun visitBinary(mst: TypedMst.Binary): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this")
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
- protected fun visit(node: MST): ExpressionRef = when (node) {
- is Symbol -> visitSymbol(node)
- is Numeric -> visitNumeric(node)
+ protected fun visit(node: TypedMst): ExpressionRef = when (node) {
+ is TypedMst.Constant -> visitNumber(
+ node.number ?: error("Object constants are not supported by pritimive ASM builder"),
+ )
- is Unary -> when {
- algebra is NumericAlgebra && node.value is Numeric -> visitNumeric(
- Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
- )
-
- else -> visitUnary(node)
- }
-
- is Binary -> when {
- algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
- Numeric(
- algebra.binaryOperationFunction(node.operation)
- .invoke(
- algebra.number((node.left as Numeric).value),
- algebra.number((node.right as Numeric).value)
- )
- )
- )
-
- else -> visitBinary(node)
- }
+ is TypedMst.Variable -> visitVariable(node)
+ is TypedMst.Unary -> visitUnary(node)
+ is TypedMst.Binary -> visitBinary(node)
}
}
@UnstableKMathAPI
-internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleField, target) {
+internal class DoubleWasmBuilder(target: TypedMst) :
+ WasmBuilder(f64, DoubleField, target) {
override val instance by lazy {
object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(keys)
@@ -114,9 +97,9 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
@@ -137,7 +120,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder super.visitUnary(node)
}
- override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
+ override fun visitBinary(mst: TypedMst.Binary): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
@@ -148,7 +131,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(i32, IntRing, target) {
+internal class IntWasmBuilder(target: TypedMst) : WasmBuilder(i32, IntRing, target) {
override val instance by lazy {
object : IntExpression {
override val indexer = SimpleSymbolIndexer(keys)
@@ -157,15 +140,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder(i32
}
}
- override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value.toInt())
+ override fun visitNumber(number: Number) = ctx.i32.const(number.toInt())
- override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
+ override fun visitUnary(node: TypedMst.Unary): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value)
else -> super.visitUnary(node)
}
- override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
+ override fun visitBinary(mst: TypedMst.Binary): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt
index 12e6b41af..f9540f9db 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt
@@ -7,7 +7,8 @@
package space.kscience.kmath.wasm
-import space.kscience.kmath.estree.compileWith
+import space.kscience.kmath.ast.TypedMst
+import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
@@ -21,8 +22,16 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance
+public fun MST.compileToExpression(algebra: IntRing): IntExpression {
+ val typed = evaluateConstants(algebra)
+ return if (typed is TypedMst.Constant) object : IntExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: IntArray): Int = typed.value
+ } else
+ IntWasmBuilder(typed).instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -31,7 +40,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBui
*/
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map): Int =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
@@ -49,7 +58,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: DoubleField): Expression = DoubleWasmBuilder(this).instance
+public fun MST.compileToExpression(algebra: DoubleField): Expression {
+ val typed = evaluateConstants(algebra)
+
+ return if (typed is TypedMst.Constant) object : DoubleExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: DoubleArray): Double = typed.value
+ } else
+ DoubleWasmBuilder(typed).instance
+}
/**
@@ -59,7 +77,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression = D
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map): Double =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
@@ -69,4 +87,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map): Do
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double =
- compileToExpression(algebra).invoke(*arguments)
+ compileToExpression(algebra)(*arguments)
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt
index 8e426622d..73b9c97a7 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt
@@ -8,10 +8,14 @@
package space.kscience.kmath.asm
import space.kscience.kmath.asm.internal.*
+import space.kscience.kmath.ast.TypedMst
+import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.*
-import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.misc.UnstableKMathAPI
-import space.kscience.kmath.operations.*
+import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.operations.DoubleField
+import space.kscience.kmath.operations.IntRing
+import space.kscience.kmath.operations.LongRing
/**
* Compiles given MST to an Expression using AST compiler.
@@ -21,102 +25,64 @@ import space.kscience.kmath.operations.*
* @return the compiled expression.
* @author Alexander Nozik
*/
+@OptIn(UnstableKMathAPI::class)
@PublishedApi
internal fun MST.compileWith(type: Class, algebra: Algebra): Expression {
- fun GenericAsmBuilder.variablesVisitor(node: MST): Unit = when (node) {
- is Symbol -> prepareVariable(node.identity)
- is Unary -> variablesVisitor(node.value)
+ val typed = evaluateConstants(algebra)
+ if (typed is TypedMst.Constant) return Expression { typed.value }
- is Binary -> {
+ fun GenericAsmBuilder.variablesVisitor(node: TypedMst): Unit = when (node) {
+ is TypedMst.Unary -> variablesVisitor(node.value)
+
+ is TypedMst.Binary -> {
variablesVisitor(node.left)
variablesVisitor(node.right)
}
- else -> Unit
+ is TypedMst.Variable -> prepareVariable(node.symbol)
+ is TypedMst.Constant -> Unit
}
- fun GenericAsmBuilder.expressionVisitor(node: MST): Unit = when (node) {
- is Symbol -> {
- val symbol = algebra.bindSymbolOrNull(node)
+ fun GenericAsmBuilder.expressionVisitor(node: TypedMst): Unit = when (node) {
+ is TypedMst.Constant -> if (node.number != null)
+ loadNumberConstant(node.number)
+ else
+ loadObjectConstant(node.value)
- if (symbol != null)
- loadObjectConstant(symbol as Any)
- else
- loadVariable(node.identity)
- }
+ is TypedMst.Variable -> loadVariable(node.symbol)
+ is TypedMst.Unary -> buildCall(node.function) { expressionVisitor(node.value) }
- is Numeric -> if (algebra is NumericAlgebra) {
- if (Number::class.java.isAssignableFrom(type))
- loadNumberConstant(algebra.number(node.value) as Number)
- else
- loadObjectConstant(algebra.number(node.value))
- } else
- error("Numeric nodes are not supported by $this")
-
- is Unary -> when {
- algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
- algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
- )
-
- else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
- }
-
- is Binary -> when {
- algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
- algebra.binaryOperationFunction(node.operation).invoke(
- algebra.number((node.left as Numeric).value),
- algebra.number((node.right as Numeric).value),
- )
- )
-
- algebra is NumericAlgebra && node.left is Numeric -> buildCall(
- algebra.leftSideNumberOperationFunction(node.operation),
- ) {
- expressionVisitor(node.left)
- expressionVisitor(node.right)
- }
-
- algebra is NumericAlgebra && node.right is Numeric -> buildCall(
- algebra.rightSideNumberOperationFunction(node.operation),
- ) {
- expressionVisitor(node.left)
- expressionVisitor(node.right)
- }
-
- else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
- expressionVisitor(node.left)
- expressionVisitor(node.right)
- }
+ is TypedMst.Binary -> buildCall(node.function) {
+ expressionVisitor(node.left)
+ expressionVisitor(node.right)
}
}
return GenericAsmBuilder(
type,
- buildName(this),
- { variablesVisitor(this@compileWith) },
- { expressionVisitor(this@compileWith) },
+ buildName("${typed.hashCode()}_${type.simpleName}"),
+ { variablesVisitor(typed) },
+ { expressionVisitor(typed) },
).instance
}
-
/**
* Create a compiled expression with given [MST] and given [algebra].
*/
public inline fun MST.compileToExpression(algebra: Algebra): Expression =
compileWith(T::class.java, algebra)
-
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun MST.compile(algebra: Algebra, arguments: Map): T =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T =
- compileToExpression(algebra).invoke(*arguments)
+ compileToExpression(algebra)(*arguments)
/**
@@ -125,7 +91,16 @@ public inline fun MST.compile(algebra: Algebra, vararg argu
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance
+public fun MST.compileToExpression(algebra: IntRing): IntExpression {
+ val typed = evaluateConstants(algebra)
+
+ return if (typed is TypedMst.Constant) object : IntExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: IntArray): Int = typed.value
+ } else
+ IntAsmBuilder(typed).instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -134,7 +109,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuil
*/
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map): Int =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -152,8 +127,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance
+public fun MST.compileToExpression(algebra: LongRing): LongExpression {
+ val typed = evaluateConstants(algebra)
+ return if (typed is TypedMst.Constant) object : LongExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: LongArray): Long = typed.value
+ } else
+ LongAsmBuilder(typed).instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -162,7 +145,7 @@ public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmB
*/
@UnstableKMathAPI
public fun MST.compile(algebra: LongRing, arguments: Map): Long =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
@@ -181,7 +164,17 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair):
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance
+public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression {
+ val typed = evaluateConstants(algebra)
+
+ return if (typed is TypedMst.Constant) object : DoubleExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: DoubleArray): Double = typed.value
+ } else
+ DoubleAsmBuilder(typed).instance
+}
+
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -190,7 +183,7 @@ public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = Dou
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map): Double =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -199,4 +192,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map): Do
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double =
- compileToExpression(algebra).invoke(*arguments)
+ compileToExpression(algebra)(*arguments)
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt
index 5eb739956..6cf3d8721 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt
@@ -56,7 +56,7 @@ internal class GenericAsmBuilder(
/**
* Local variables indices are indices of symbols in this list.
*/
- private val argumentsLocals = mutableListOf()
+ private val argumentsLocals = mutableListOf()
/**
* Subclasses, loads and instantiates [Expression] for given parameters.
@@ -253,10 +253,10 @@ internal class GenericAsmBuilder(
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable].
*/
- fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
+ fun prepareVariable(name: Symbol): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run
load(1, MAP_TYPE)
- aconst(name)
+ aconst(name.identity)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
@@ -280,7 +280,7 @@ internal class GenericAsmBuilder(
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first.
*/
- fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
+ fun loadVariable(name: Symbol): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
inline fun buildCall(function: Function, parameters: GenericAsmBuilder.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt
index bf1f42395..01bad83e5 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt
@@ -11,6 +11,7 @@ import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter
+import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
@@ -25,9 +26,9 @@ internal sealed class PrimitiveAsmBuilder>(
classOfT: Class<*>,
protected val classOfTPrimitive: Class<*>,
expressionParent: Class,
- protected val target: MST,
+ protected val target: TypedMst,
) : AsmBuilder() {
- private val className: String = buildName(target)
+ private val className: String = buildName("${target.hashCode()}_${classOfT.simpleName}")
/**
* ASM type for [tType].
@@ -329,63 +330,39 @@ internal sealed class PrimitiveAsmBuilder>(
}
private fun visitVariables(
- node: MST,
+ node: TypedMst,
arrayMode: Boolean,
alreadyLoaded: MutableList = mutableListOf()
): Unit = when (node) {
- is Symbol -> when (node) {
- !in alreadyLoaded -> {
- alreadyLoaded += node
- prepareVariable(node, arrayMode)
- }
- else -> {
- }
- }
+ is TypedMst.Variable -> if (node.symbol !in alreadyLoaded) {
+ alreadyLoaded += node.symbol
+ prepareVariable(node.symbol, arrayMode)
+ } else Unit
- is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
+ is TypedMst.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
- is MST.Binary -> {
+ is TypedMst.Binary -> {
visitVariables(node.left, arrayMode, alreadyLoaded)
visitVariables(node.right, arrayMode, alreadyLoaded)
}
- else -> Unit
+ is TypedMst.Constant -> Unit
}
- private fun visitExpression(node: MST): Unit = when (node) {
- is Symbol -> {
- val symbol = algebra.bindSymbolOrNull(node)
+ private fun visitExpression(node: TypedMst): Unit = when (node) {
+ is TypedMst.Variable -> loadVariable(node.symbol)
- if (symbol != null)
- loadNumberConstant(symbol)
- else
- loadVariable(node)
- }
+ is TypedMst.Constant -> loadNumberConstant(
+ node.number ?: error("Object constants are not supported by pritimive ASM builder"),
+ )
- is MST.Numeric -> loadNumberConstant(algebra.number(node.value))
-
- is MST.Unary -> if (node.value is MST.Numeric)
- loadNumberConstant(
- algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
- )
- else
- visitUnary(node)
-
- is MST.Binary -> when {
- node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
- algebra.binaryOperationFunction(node.operation)(
- algebra.number((node.left as MST.Numeric).value),
- algebra.number((node.right as MST.Numeric).value),
- ),
- )
-
- else -> visitBinary(node)
- }
+ is TypedMst.Unary -> visitUnary(node)
+ is TypedMst.Binary -> visitBinary(node)
}
- protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value)
+ protected open fun visitUnary(node: TypedMst.Unary) = visitExpression(node.value)
- protected open fun visitBinary(node: MST.Binary) {
+ protected open fun visitBinary(node: TypedMst.Binary) {
visitExpression(node.left)
visitExpression(node.right)
}
@@ -404,14 +381,13 @@ internal sealed class PrimitiveAsmBuilder>(
}
@UnstableKMathAPI
-internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder(
+internal class DoubleAsmBuilder(target: TypedMst) : PrimitiveAsmBuilder(
DoubleField,
java.lang.Double::class.java,
java.lang.Double.TYPE,
DoubleExpression::class.java,
target,
) {
-
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
MATH_TYPE.internalName,
name,
@@ -434,7 +410,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) {
super.visitUnary(node)
when (node.operation) {
@@ -459,7 +435,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) {
super.visitBinary(node)
when (node.operation) {
@@ -479,7 +455,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) :
PrimitiveAsmBuilder(
IntRing,
Integer::class.java,
@@ -487,7 +463,7 @@ internal class IntAsmBuilder(target: MST) :
IntExpression::class.java,
target
) {
- override fun visitUnary(node: MST.Unary) {
+ override fun visitUnary(node: TypedMst.Unary) {
super.visitUnary(node)
when (node.operation) {
@@ -497,7 +473,7 @@ internal class IntAsmBuilder(target: MST) :
}
}
- override fun visitBinary(node: MST.Binary) {
+ override fun visitBinary(node: TypedMst.Binary) {
super.visitBinary(node)
when (node.operation) {
@@ -510,14 +486,14 @@ internal class IntAsmBuilder(target: MST) :
}
@UnstableKMathAPI
-internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder(
+internal class LongAsmBuilder(target: TypedMst) : PrimitiveAsmBuilder(
LongRing,
java.lang.Long::class.java,
java.lang.Long.TYPE,
LongExpression::class.java,
target,
) {
- override fun visitUnary(node: MST.Unary) {
+ override fun visitUnary(node: TypedMst.Unary) {
super.visitUnary(node)
when (node.operation) {
@@ -527,7 +503,7 @@ internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder) {
super.visitBinary(node)
when (node.operation) {
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt
index 06e040e93..9e880f4fc 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt
@@ -55,15 +55,15 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
/**
- * Creates a class name for [Expression] subclassed to implement [mst] provided.
+ * Creates a class name for [Expression] based with appending [marker] to reduce collisions.
*
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
*
* @author Iaroslav Postovalov
*/
-internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
- val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision"
+internal tailrec fun buildName(marker: String, collision: Int = 0): String {
+ val name = "space.kscience.kmath.asm.generated.CompiledExpression_${marker}_$collision"
try {
Class.forName(name)
@@ -71,7 +71,7 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
return name
}
- return buildName(mst, collision + 1)
+ return buildName(marker, collision + 1)
}
@Suppress("FunctionName")
diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
index 5f194f2ea..880cf8421 100644
--- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
+++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
@@ -34,12 +34,12 @@ public abstract class FunctionalExpressionAlgebra>(
override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression =
{ left, right ->
Expression { arguments ->
- algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments))
+ algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
}
}
override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg ->
- Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) }
+ Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
}
}
diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt
index 24e96e845..18226119b 100644
--- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt
+++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt
@@ -24,7 +24,7 @@ public sealed interface MST {
public data class Numeric(val value: Number) : MST
/**
- * A node containing an unary operation.
+ * A node containing a unary operation.
*
* @property operation the identifier of operation.
* @property value the argument of this operation.
@@ -34,7 +34,7 @@ public sealed interface MST {
/**
* A node containing binary operation.
*
- * @property operation the identifier operation.
+ * @property operation the identifier of operation.
* @property left the left operand.
* @property right the right operand.
*/