From cdfb3f3fa335884c1417cbc3c878f5e09bcc955b Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sat, 5 Feb 2022 04:27:10 +0700 Subject: [PATCH] Add complete constant folding in kmath-ast by introducing TypedMst, some minor changes --- README.md | 6 + kmath-ast/README.md | 63 ++++--- kmath-ast/docs/README-TEMPLATE.md | 114 +++++++---- .../space/kscience/kmath/ast/TypedMst.kt | 177 ++++++++++++++++++ .../kscience/kmath/ast/evaluateConstants.kt | 93 +++++++++ .../space/kscience/kmath/ast/TestFolding.kt | 52 +++++ .../space/kscience/kmath/estree/estree.kt | 89 +++------ .../kmath/estree/internal/ESTreeBuilder.kt | 2 +- .../kmath/wasm/internal/WasmBuilder.kt | 67 +++---- .../kotlin/space/kscience/kmath/wasm/wasm.kt | 30 ++- .../kotlin/space/kscience/kmath/asm/asm.kt | 135 +++++++------ .../kmath/asm/internal/GenericAsmBuilder.kt | 8 +- .../kmath/asm/internal/PrimitiveAsmBuilder.kt | 82 +++----- .../kmath/asm/internal/codegenUtils.kt | 8 +- .../FunctionalExpressionAlgebra.kt | 4 +- .../space/kscience/kmath/expressions/MST.kt | 4 +- 16 files changed, 622 insertions(+), 312 deletions(-) create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt create mode 100644 kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt diff --git a/README.md b/README.md index 92260716e..99dd6d00f 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,12 @@ One can still use generic algebras though. > **Maturity**: PROTOTYPE
+* ### [kmath-tensorflow](kmath-tensorflow) +> +> +> **Maturity**: PROTOTYPE +
+ * ### [kmath-tensors](kmath-tensors) > > diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 5e3366881..bedf17486 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -1,6 +1,6 @@ # Module kmath-ast -Performance and visualization extensions to MST API. +Extensions to MST API: transformations, dynamic compilation and visualization. - [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser - [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler @@ -35,6 +35,26 @@ dependencies { } ``` +## Parsing expressions + +In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances. + +Supported literals: +1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`. +2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`. + +Supported binary operators (from the highest precedence to the lowest one): +1. `^` +2. `*`, `/` +3. `+`, `-` + +Supported unary operator: +1. `-`, e. g. `-x` + +Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples: +1. `sin(x)` +2. `add(x, y)` + ## Dynamic expression code generation ### On JVM @@ -42,48 +62,41 @@ dependencies { `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a special implementation of `Expression` with implemented `invoke` function. -For example, the following builder: +For example, the following code: ```kotlin -import space.kscience.kmath.expressions.Symbol.Companion.x -import space.kscience.kmath.expressions.* -import space.kscience.kmath.operations.* -import space.kscience.kmath.asm.* +import space.kscience.kmath.asm.compileToExpression +import space.kscience.kmath.complex.ComplexField -MstField { x + 2 }.compileToExpression(DoubleField) -``` +"x+2".parseMath().compileToExpression(ComplexField) +``` -... leads to generation of bytecode, which can be decompiled to the following Java class: +… leads to generation of bytecode, which can be decompiled to the following Java class: ```java -package space.kscience.kmath.asm.generated; - import java.util.Map; - import kotlin.jvm.functions.Function2; import space.kscience.kmath.asm.internal.MapIntrinsics; +import space.kscience.kmath.complex.Complex; import space.kscience.kmath.expressions.Expression; import space.kscience.kmath.expressions.Symbol; -public final class AsmCompiledExpression_45045_0 implements Expression { +public final class CompiledExpression_45045_0 implements Expression { private final Object[] constants; - public final Double invoke(Map arguments) { - return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2); - } - - public AsmCompiledExpression_45045_0(Object[] constants) { - this.constants = constants; + public Complex invoke(Map arguments) { + Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x"); + return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]); } } - ``` -#### Known issues +Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually. -- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class - loading overhead. -- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. +#### Limitations + +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead. +- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders. ### On JS @@ -129,7 +142,7 @@ An example of emitted Wasm IR in the form of WAT: ) ``` -#### Known issues +#### Limitations - ESTree expression compilation uses `eval` which can be unavailable in several environments. - WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/). diff --git a/kmath-ast/docs/README-TEMPLATE.md b/kmath-ast/docs/README-TEMPLATE.md index 9494af63a..e9e22f4d4 100644 --- a/kmath-ast/docs/README-TEMPLATE.md +++ b/kmath-ast/docs/README-TEMPLATE.md @@ -1,11 +1,31 @@ # Module kmath-ast -Performance and visualization extensions to MST API. +Extensions to MST API: transformations, dynamic compilation and visualization. ${features} ${artifact} +## Parsing expressions + +In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances. + +Supported literals: +1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`. +2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`. + +Supported binary operators (from the highest precedence to the lowest one): +1. `^` +2. `*`, `/` +3. `+`, `-` + +Supported unary operator: +1. `-`, e. g. `-x` + +Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples: +1. `sin(x)` +2. `add(x, y)` + ## Dynamic expression code generation ### On JVM @@ -13,48 +33,66 @@ ${artifact} `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a special implementation of `Expression` with implemented `invoke` function. -For example, the following builder: +For example, the following code: ```kotlin -import space.kscience.kmath.expressions.Symbol.Companion.x -import space.kscience.kmath.expressions.* -import space.kscience.kmath.operations.* -import space.kscience.kmath.asm.* - -MstField { x + 2 }.compileToExpression(DoubleField) -``` - -... leads to generation of bytecode, which can be decompiled to the following Java class: - -```java -package space.kscience.kmath.asm.generated; - -import java.util.Map; - -import kotlin.jvm.functions.Function2; -import space.kscience.kmath.asm.internal.MapIntrinsics; -import space.kscience.kmath.expressions.Expression; -import space.kscience.kmath.expressions.Symbol; - -public final class AsmCompiledExpression_45045_0 implements Expression { - private final Object[] constants; - - public final Double invoke(Map arguments) { - return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2); - } - - public AsmCompiledExpression_45045_0(Object[] constants) { - this.constants = constants; - } -} +import space.kscience.kmath.asm.compileToExpression +import space.kscience.kmath.operations.DoubleField +"x^3-x+3".parseMath().compileToExpression(DoubleField) ``` -#### Known issues +… leads to generation of bytecode, which can be decompiled to the following Java class: -- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class - loading overhead. -- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. +```java +import java.util.*; +import kotlin.jvm.functions.*; +import space.kscience.kmath.asm.internal.*; +import space.kscience.kmath.complex.*; +import space.kscience.kmath.expressions.*; + +public final class CompiledExpression_45045_0 implements Expression { + private final Object[] constants; + + public Complex invoke(Map arguments) { + Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x"); + return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]); + } +} +``` + +For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance: + +```java +import java.util.*; +import space.kscience.kmath.asm.internal.*; +import space.kscience.kmath.expressions.*; + +public final class CompiledExpression_-386104628_0 implements DoubleExpression { + private final SymbolIndexer indexer; + + public SymbolIndexer getIndexer() { + return this.indexer; + } + + public double invoke(double[] arguments) { + double var2 = arguments[0]; + return Math.pow(var2, 3.0D) - var2 + 3.0D; + } + + public final Double invoke(Map arguments) { + double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(); + return Math.pow(var2, 3.0D) - var2 + 3.0D; + } +} +``` + +Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually. + +#### Limitations + +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead. +- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders. ### On JS @@ -100,7 +138,7 @@ An example of emitted Wasm IR in the form of WAT: ) ``` -#### Known issues +#### Limitations - ESTree expression compilation uses `eval` which can be unavailable in several environments. - WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/). diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt new file mode 100644 index 000000000..8a8b8797d --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt @@ -0,0 +1,177 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.ast + +import space.kscience.kmath.expressions.Expression +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.operations.NumericAlgebra + +/** + * MST form where all values belong to the type [T]. It is optimal for constant folding, dynamic compilation, etc. + * + * @param T the type. + */ +@UnstableKMathAPI +public sealed interface TypedMst { + /** + * A node containing a unary operation. + * + * @param T the type. + * @property operation The identifier of operation. + * @property function The function implementing this operation. + * @property value The argument of this operation. + */ + public class Unary(public val operation: String, public val function: (T) -> T, public val value: TypedMst) : + TypedMst { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + other as Unary<*> + if (operation != other.operation) return false + if (value != other.value) return false + return true + } + + override fun hashCode(): Int { + var result = operation.hashCode() + result = 31 * result + value.hashCode() + return result + } + + override fun toString(): String = "Unary(operation=$operation, value=$value)" + } + + /** + * A node containing binary operation. + * + * @param T the type. + * @property operation The identifier of operation. + * @property function The binary function implementing this operation. + * @property left The left operand. + * @property right The right operand. + */ + public class Binary( + public val operation: String, + public val function: Function, + public val left: TypedMst, + public val right: TypedMst, + ) : TypedMst { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as Binary<*> + + if (operation != other.operation) return false + if (left != other.left) return false + if (right != other.right) return false + + return true + } + + override fun hashCode(): Int { + var result = operation.hashCode() + result = 31 * result + left.hashCode() + result = 31 * result + right.hashCode() + return result + } + + override fun toString(): String = "Binary(operation=$operation, left=$left, right=$right)" + } + + /** + * The non-numeric constant value. + * + * @param T the type. + * @property value The held value. + * @property number The number this value corresponds. + */ + public class Constant(public val value: T, public val number: Number?) : TypedMst { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + other as Constant<*> + if (value != other.value) return false + if (number != other.number) return false + return true + } + + override fun hashCode(): Int { + var result = value?.hashCode() ?: 0 + result = 31 * result + (number?.hashCode() ?: 0) + return result + } + + override fun toString(): String = "Constant(value=$value, number=$number)" + } + + /** + * The node containing a variable + * + * @param T the type. + * @property symbol The symbol of the variable. + */ + public class Variable(public val symbol: Symbol) : TypedMst { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + other as Variable<*> + if (symbol != other.symbol) return false + return true + } + + override fun hashCode(): Int = symbol.hashCode() + override fun toString(): String = "Variable(symbol=$symbol)" + } +} + +/** + * Interprets the [TypedMst] node with this [Algebra] and [arguments]. + */ +@UnstableKMathAPI +public fun TypedMst.interpret(algebra: Algebra, arguments: Map): T = when (this) { + is TypedMst.Unary -> algebra.unaryOperation(operation, interpret(algebra, arguments)) + + is TypedMst.Binary -> when { + algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> + algebra.leftSideNumberOperation(operation, left.number, right.interpret(algebra, arguments)) + + algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> + algebra.rightSideNumberOperation(operation, left.interpret(algebra, arguments), right.number) + + else -> algebra.binaryOperation( + operation, + left.interpret(algebra, arguments), + right.interpret(algebra, arguments), + ) + } + + is TypedMst.Constant -> value + is TypedMst.Variable -> arguments.getValue(symbol) +} + +/** + * Interprets the [TypedMst] node with this [Algebra] and optional [arguments]. + */ +@UnstableKMathAPI +public fun TypedMst.interpret(algebra: Algebra, vararg arguments: Pair): T = interpret( + algebra, + when (arguments.size) { + 0 -> emptyMap() + 1 -> mapOf(arguments[0]) + else -> hashMapOf(*arguments) + }, +) + +/** + * Interpret this [TypedMst] node as expression. + */ +@UnstableKMathAPI +public fun TypedMst.toExpression(algebra: Algebra): Expression = Expression { arguments -> + interpret(algebra, arguments) +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt new file mode 100644 index 000000000..71fb154c9 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt @@ -0,0 +1,93 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.ast + +import space.kscience.kmath.expressions.MST +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.operations.NumericAlgebra +import space.kscience.kmath.operations.bindSymbolOrNull + +/** + * Evaluates constants in given [MST] for given [algebra] at the same time with converting to [TypedMst]. + */ +@UnstableKMathAPI +public fun MST.evaluateConstants(algebra: Algebra): TypedMst = when (this) { + is MST.Numeric -> TypedMst.Constant( + (algebra as? NumericAlgebra)?.number(value) ?: error("Numeric nodes are not supported by $algebra"), + value, + ) + + is MST.Unary -> when (val arg = value.evaluateConstants(algebra)) { + is TypedMst.Constant -> { + val value = algebra.unaryOperation( + operation, + arg.value, + ) + + TypedMst.Constant(value, if (value is Number) value else null) + } + + else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg) + } + + is MST.Binary -> { + val left = left.evaluateConstants(algebra) + val right = right.evaluateConstants(algebra) + + when { + left is TypedMst.Constant && right is TypedMst.Constant -> { + val value = when { + algebra is NumericAlgebra && left.number != null -> algebra.leftSideNumberOperation( + operation, + left.number, + right.value, + ) + + algebra is NumericAlgebra && right.number != null -> algebra.rightSideNumberOperation( + operation, + left.value, + right.number, + ) + + else -> algebra.binaryOperation( + operation, + left.value, + right.value, + ) + } + + TypedMst.Constant(value, if (value is Number) value else null) + } + + algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary( + operation, + algebra.leftSideNumberOperationFunction(operation), + left, + right, + ) + + algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> TypedMst.Binary( + operation, + algebra.rightSideNumberOperationFunction(operation), + left, + right, + ) + + else -> TypedMst.Binary(operation, algebra.binaryOperationFunction(operation), left, right) + } + } + + is Symbol -> { + val boundSymbol = algebra.bindSymbolOrNull(this) + + if (boundSymbol != null) + TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null) + else + TypedMst.Variable(this) + } +} diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt new file mode 100644 index 000000000..954a0f330 --- /dev/null +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt @@ -0,0 +1,52 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.ast + +import space.kscience.kmath.operations.ByteRing +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.IntRing +import space.kscience.kmath.operations.pi +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.fail + +internal class TestFolding { + @Test + fun foldUnary() = assertEquals( + -1, + ("-(1)".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value, + ) + + @Test + fun foldDeepUnary() = assertEquals( + 1, + ("-(-(1))".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value, + ) + + @Test + fun foldBinary() = assertEquals( + 2, + ("1*2".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value, + ) + + @Test + fun foldDeepBinary() = assertEquals( + 10, + ("1*2*5".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value, + ) + + @Test + fun foldSymbol() = assertEquals( + DoubleField.pi, + ("pi".parseMath().evaluateConstants(DoubleField) as? TypedMst.Constant ?: fail()).value, + ) + + @Test + fun foldNumeric() = assertEquals( + 42.toByte(), + ("42".parseMath().evaluateConstants(ByteRing) as? TypedMst.Constant ?: fail()).value, + ) +} diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt index a6b6e022b..a8b1aa2e1 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt @@ -5,87 +5,48 @@ package space.kscience.kmath.estree +import space.kscience.kmath.ast.TypedMst +import space.kscience.kmath.ast.evaluateConstants import space.kscience.kmath.estree.internal.ESTreeBuilder import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.MST -import space.kscience.kmath.expressions.MST.* import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.invoke import space.kscience.kmath.internal.estree.BaseExpression +import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.Algebra -import space.kscience.kmath.operations.NumericAlgebra -import space.kscience.kmath.operations.bindSymbolOrNull - -@PublishedApi -internal fun MST.compileWith(algebra: Algebra): Expression { - fun ESTreeBuilder.visit(node: MST): BaseExpression = when (node) { - is Symbol -> { - val symbol = algebra.bindSymbolOrNull(node) - - if (symbol != null) - constant(symbol) - else - variable(node.identity) - } - - is Numeric -> constant( - (algebra as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this") - ) - - is Unary -> when { - algebra is NumericAlgebra && node.value is Numeric -> constant( - algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)) - ) - - else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value)) - } - - is Binary -> when { - algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant( - algebra.binaryOperationFunction(node.operation).invoke( - algebra.number((node.left as Numeric).value), - algebra.number((node.right as Numeric).value) - ) - ) - - algebra is NumericAlgebra && node.left is Numeric -> call( - algebra.leftSideNumberOperationFunction(node.operation), - visit(node.left), - visit(node.right), - ) - - algebra is NumericAlgebra && node.right is Numeric -> call( - algebra.rightSideNumberOperationFunction(node.operation), - visit(node.left), - visit(node.right), - ) - - else -> call( - algebra.binaryOperationFunction(node.operation), - visit(node.left), - visit(node.right), - ) - } - } - - return ESTreeBuilder { visit(this@compileWith) }.instance -} /** * Create a compiled expression with given [MST] and given [algebra]. */ -public fun MST.compileToExpression(algebra: Algebra): Expression = compileWith(algebra) +@OptIn(UnstableKMathAPI::class) +public fun MST.compileToExpression(algebra: Algebra): Expression { + val typed = evaluateConstants(algebra) + if (typed is TypedMst.Constant) return Expression { typed.value } + fun ESTreeBuilder.visit(node: TypedMst): BaseExpression = when (node) { + is TypedMst.Constant -> constant(node.value) + is TypedMst.Variable -> variable(node.symbol) + is TypedMst.Unary -> call(node.function, visit(node.value)) + + is TypedMst.Binary -> call( + node.function, + visit(node.left), + visit(node.right), + ) + } + + return ESTreeBuilder { visit(typed) }.instance +} /** * Compile given MST to expression and evaluate it against [arguments] */ -public inline fun MST.compile(algebra: Algebra, arguments: Map): T = - compileToExpression(algebra).invoke(arguments) - +public fun MST.compile(algebra: Algebra, arguments: Map): T = + compileToExpression(algebra)(arguments) /** * Compile given MST to expression and evaluate it against [arguments] */ -public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T = - compileToExpression(algebra).invoke(*arguments) +public fun MST.compile(algebra: Algebra, vararg arguments: Pair): T = + compileToExpression(algebra)(*arguments) diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt index 4907d8225..10a6c4a16 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt @@ -61,7 +61,7 @@ internal class ESTreeBuilder(val bodyCallback: ESTreeBuilder.() -> BaseExp } } - fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name)) + fun variable(name: Symbol): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name.identity)) fun call(function: Function, vararg args: BaseExpression): BaseExpression = SimpleCallExpression( optional = false, diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt index 96090a633..aacb62f36 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt @@ -5,8 +5,8 @@ package space.kscience.kmath.wasm.internal +import space.kscience.kmath.ast.TypedMst import space.kscience.kmath.expressions.* -import space.kscience.kmath.expressions.MST.* import space.kscience.kmath.internal.binaryen.* import space.kscience.kmath.internal.webassembly.Instance import space.kscience.kmath.misc.UnstableKMathAPI @@ -16,11 +16,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule private val spreader = eval("(obj, args) => obj(...args)") +@OptIn(UnstableKMathAPI::class) @Suppress("UnsafeCastFromDynamic") internal sealed class WasmBuilder>( protected val binaryenType: Type, protected val algebra: Algebra, - protected val target: MST, + protected val target: TypedMst, ) { protected val keys: MutableList = mutableListOf() protected lateinit var ctx: BinaryenModule @@ -51,59 +52,41 @@ internal sealed class WasmBuilder>( Instance(c, js("{}")).exports.executable } - protected open fun visitSymbol(node: Symbol): ExpressionRef { - algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) } + protected abstract fun visitNumber(number: Number): ExpressionRef - var idx = keys.indexOf(node) + protected open fun visitVariable(node: TypedMst.Variable): ExpressionRef { + var idx = keys.indexOf(node.symbol) if (idx == -1) { - keys += node + keys += node.symbol idx = keys.lastIndex } return ctx.local.get(idx, binaryenType) } - protected abstract fun visitNumeric(node: Numeric): ExpressionRef - - protected open fun visitUnary(node: Unary): ExpressionRef = + protected open fun visitUnary(node: TypedMst.Unary): ExpressionRef = error("Unary operation ${node.operation} not defined in $this") - protected open fun visitBinary(mst: Binary): ExpressionRef = + protected open fun visitBinary(mst: TypedMst.Binary): ExpressionRef = error("Binary operation ${mst.operation} not defined in $this") protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") - protected fun visit(node: MST): ExpressionRef = when (node) { - is Symbol -> visitSymbol(node) - is Numeric -> visitNumeric(node) + protected fun visit(node: TypedMst): ExpressionRef = when (node) { + is TypedMst.Constant -> visitNumber( + node.number ?: error("Object constants are not supported by pritimive ASM builder"), + ) - is Unary -> when { - algebra is NumericAlgebra && node.value is Numeric -> visitNumeric( - Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))) - ) - - else -> visitUnary(node) - } - - is Binary -> when { - algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric( - Numeric( - algebra.binaryOperationFunction(node.operation) - .invoke( - algebra.number((node.left as Numeric).value), - algebra.number((node.right as Numeric).value) - ) - ) - ) - - else -> visitBinary(node) - } + is TypedMst.Variable -> visitVariable(node) + is TypedMst.Unary -> visitUnary(node) + is TypedMst.Binary -> visitBinary(node) } } @UnstableKMathAPI -internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleField, target) { +internal class DoubleWasmBuilder(target: TypedMst) : + WasmBuilder(f64, DoubleField, target) { override val instance by lazy { object : DoubleExpression { override val indexer = SimpleSymbolIndexer(keys) @@ -114,9 +97,9 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder): ExpressionRef = when (node.operation) { GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value)) GroupOps.PLUS_OPERATION -> visit(node.value) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value)) @@ -137,7 +120,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder super.visitUnary(node) } - override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { + override fun visitBinary(mst: TypedMst.Binary): ExpressionRef = when (mst.operation) { GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) @@ -148,7 +131,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(i32, IntRing, target) { +internal class IntWasmBuilder(target: TypedMst) : WasmBuilder(i32, IntRing, target) { override val instance by lazy { object : IntExpression { override val indexer = SimpleSymbolIndexer(keys) @@ -157,15 +140,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder(i32 } } - override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value.toInt()) + override fun visitNumber(number: Number) = ctx.i32.const(number.toInt()) - override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) { + override fun visitUnary(node: TypedMst.Unary): ExpressionRef = when (node.operation) { GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value)) GroupOps.PLUS_OPERATION -> visit(node.value) else -> super.visitUnary(node) } - override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { + override fun visitBinary(mst: TypedMst.Binary): ExpressionRef = when (mst.operation) { GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt index 12e6b41af..f9540f9db 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt @@ -7,7 +7,8 @@ package space.kscience.kmath.wasm -import space.kscience.kmath.estree.compileWith +import space.kscience.kmath.ast.TypedMst +import space.kscience.kmath.ast.evaluateConstants import space.kscience.kmath.expressions.* import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.DoubleField @@ -21,8 +22,16 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder * @author Iaroslav Postovalov */ @UnstableKMathAPI -public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance +public fun MST.compileToExpression(algebra: IntRing): IntExpression { + val typed = evaluateConstants(algebra) + return if (typed is TypedMst.Constant) object : IntExpression { + override val indexer = SimpleSymbolIndexer(emptyList()) + + override fun invoke(arguments: IntArray): Int = typed.value + } else + IntWasmBuilder(typed).instance +} /** * Compile given MST to expression and evaluate it against [arguments]. @@ -31,7 +40,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBui */ @UnstableKMathAPI public fun MST.compile(algebra: IntRing, arguments: Map): Int = - compileToExpression(algebra).invoke(arguments) + compileToExpression(algebra)(arguments) /** @@ -49,7 +58,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I * @author Iaroslav Postovalov */ @UnstableKMathAPI -public fun MST.compileToExpression(algebra: DoubleField): Expression = DoubleWasmBuilder(this).instance +public fun MST.compileToExpression(algebra: DoubleField): Expression { + val typed = evaluateConstants(algebra) + + return if (typed is TypedMst.Constant) object : DoubleExpression { + override val indexer = SimpleSymbolIndexer(emptyList()) + + override fun invoke(arguments: DoubleArray): Double = typed.value + } else + DoubleWasmBuilder(typed).instance +} /** @@ -59,7 +77,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression = D */ @UnstableKMathAPI public fun MST.compile(algebra: DoubleField, arguments: Map): Double = - compileToExpression(algebra).invoke(arguments) + compileToExpression(algebra)(arguments) /** @@ -69,4 +87,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map): Do */ @UnstableKMathAPI public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double = - compileToExpression(algebra).invoke(*arguments) + compileToExpression(algebra)(*arguments) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 8e426622d..73b9c97a7 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -8,10 +8,14 @@ package space.kscience.kmath.asm import space.kscience.kmath.asm.internal.* +import space.kscience.kmath.ast.TypedMst +import space.kscience.kmath.ast.evaluateConstants import space.kscience.kmath.expressions.* -import space.kscience.kmath.expressions.MST.* import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.* +import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.IntRing +import space.kscience.kmath.operations.LongRing /** * Compiles given MST to an Expression using AST compiler. @@ -21,102 +25,64 @@ import space.kscience.kmath.operations.* * @return the compiled expression. * @author Alexander Nozik */ +@OptIn(UnstableKMathAPI::class) @PublishedApi internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { - fun GenericAsmBuilder.variablesVisitor(node: MST): Unit = when (node) { - is Symbol -> prepareVariable(node.identity) - is Unary -> variablesVisitor(node.value) + val typed = evaluateConstants(algebra) + if (typed is TypedMst.Constant) return Expression { typed.value } - is Binary -> { + fun GenericAsmBuilder.variablesVisitor(node: TypedMst): Unit = when (node) { + is TypedMst.Unary -> variablesVisitor(node.value) + + is TypedMst.Binary -> { variablesVisitor(node.left) variablesVisitor(node.right) } - else -> Unit + is TypedMst.Variable -> prepareVariable(node.symbol) + is TypedMst.Constant -> Unit } - fun GenericAsmBuilder.expressionVisitor(node: MST): Unit = when (node) { - is Symbol -> { - val symbol = algebra.bindSymbolOrNull(node) + fun GenericAsmBuilder.expressionVisitor(node: TypedMst): Unit = when (node) { + is TypedMst.Constant -> if (node.number != null) + loadNumberConstant(node.number) + else + loadObjectConstant(node.value) - if (symbol != null) - loadObjectConstant(symbol as Any) - else - loadVariable(node.identity) - } + is TypedMst.Variable -> loadVariable(node.symbol) + is TypedMst.Unary -> buildCall(node.function) { expressionVisitor(node.value) } - is Numeric -> if (algebra is NumericAlgebra) { - if (Number::class.java.isAssignableFrom(type)) - loadNumberConstant(algebra.number(node.value) as Number) - else - loadObjectConstant(algebra.number(node.value)) - } else - error("Numeric nodes are not supported by $this") - - is Unary -> when { - algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( - algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)), - ) - - else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) } - } - - is Binary -> when { - algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant( - algebra.binaryOperationFunction(node.operation).invoke( - algebra.number((node.left as Numeric).value), - algebra.number((node.right as Numeric).value), - ) - ) - - algebra is NumericAlgebra && node.left is Numeric -> buildCall( - algebra.leftSideNumberOperationFunction(node.operation), - ) { - expressionVisitor(node.left) - expressionVisitor(node.right) - } - - algebra is NumericAlgebra && node.right is Numeric -> buildCall( - algebra.rightSideNumberOperationFunction(node.operation), - ) { - expressionVisitor(node.left) - expressionVisitor(node.right) - } - - else -> buildCall(algebra.binaryOperationFunction(node.operation)) { - expressionVisitor(node.left) - expressionVisitor(node.right) - } + is TypedMst.Binary -> buildCall(node.function) { + expressionVisitor(node.left) + expressionVisitor(node.right) } } return GenericAsmBuilder( type, - buildName(this), - { variablesVisitor(this@compileWith) }, - { expressionVisitor(this@compileWith) }, + buildName("${typed.hashCode()}_${type.simpleName}"), + { variablesVisitor(typed) }, + { expressionVisitor(typed) }, ).instance } - /** * Create a compiled expression with given [MST] and given [algebra]. */ public inline fun MST.compileToExpression(algebra: Algebra): Expression = compileWith(T::class.java, algebra) - /** * Compile given MST to expression and evaluate it against [arguments] */ public inline fun MST.compile(algebra: Algebra, arguments: Map): T = - compileToExpression(algebra).invoke(arguments) + compileToExpression(algebra)(arguments) /** * Compile given MST to expression and evaluate it against [arguments] */ public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T = - compileToExpression(algebra).invoke(*arguments) + compileToExpression(algebra)(*arguments) /** @@ -125,7 +91,16 @@ public inline fun MST.compile(algebra: Algebra, vararg argu * @author Iaroslav Postovalov */ @UnstableKMathAPI -public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance +public fun MST.compileToExpression(algebra: IntRing): IntExpression { + val typed = evaluateConstants(algebra) + + return if (typed is TypedMst.Constant) object : IntExpression { + override val indexer = SimpleSymbolIndexer(emptyList()) + + override fun invoke(arguments: IntArray): Int = typed.value + } else + IntAsmBuilder(typed).instance +} /** * Compile given MST to expression and evaluate it against [arguments]. @@ -134,7 +109,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuil */ @UnstableKMathAPI public fun MST.compile(algebra: IntRing, arguments: Map): Int = - compileToExpression(algebra).invoke(arguments) + compileToExpression(algebra)(arguments) /** * Compile given MST to expression and evaluate it against [arguments]. @@ -152,8 +127,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I * @author Iaroslav Postovalov */ @UnstableKMathAPI -public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance +public fun MST.compileToExpression(algebra: LongRing): LongExpression { + val typed = evaluateConstants(algebra) + return if (typed is TypedMst.Constant) object : LongExpression { + override val indexer = SimpleSymbolIndexer(emptyList()) + + override fun invoke(arguments: LongArray): Long = typed.value + } else + LongAsmBuilder(typed).instance +} /** * Compile given MST to expression and evaluate it against [arguments]. @@ -162,7 +145,7 @@ public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmB */ @UnstableKMathAPI public fun MST.compile(algebra: LongRing, arguments: Map): Long = - compileToExpression(algebra).invoke(arguments) + compileToExpression(algebra)(arguments) /** @@ -181,7 +164,17 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair): * @author Iaroslav Postovalov */ @UnstableKMathAPI -public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance +public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression { + val typed = evaluateConstants(algebra) + + return if (typed is TypedMst.Constant) object : DoubleExpression { + override val indexer = SimpleSymbolIndexer(emptyList()) + + override fun invoke(arguments: DoubleArray): Double = typed.value + } else + DoubleAsmBuilder(typed).instance +} + /** * Compile given MST to expression and evaluate it against [arguments]. @@ -190,7 +183,7 @@ public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = Dou */ @UnstableKMathAPI public fun MST.compile(algebra: DoubleField, arguments: Map): Double = - compileToExpression(algebra).invoke(arguments) + compileToExpression(algebra)(arguments) /** * Compile given MST to expression and evaluate it against [arguments]. @@ -199,4 +192,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map): Do */ @UnstableKMathAPI public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double = - compileToExpression(algebra).invoke(*arguments) + compileToExpression(algebra)(*arguments) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt index 5eb739956..6cf3d8721 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt @@ -56,7 +56,7 @@ internal class GenericAsmBuilder( /** * Local variables indices are indices of symbols in this list. */ - private val argumentsLocals = mutableListOf() + private val argumentsLocals = mutableListOf() /** * Subclasses, loads and instantiates [Expression] for given parameters. @@ -253,10 +253,10 @@ internal class GenericAsmBuilder( * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using * [loadVariable]. */ - fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { + fun prepareVariable(name: Symbol): Unit = invokeMethodVisitor.run { if (name in argumentsLocals) return@run load(1, MAP_TYPE) - aconst(name) + aconst(name.identity) invokestatic( MAP_INTRINSICS_TYPE.internalName, @@ -280,7 +280,7 @@ internal class GenericAsmBuilder( * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored * with [prepareVariable] first. */ - fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType) + fun loadVariable(name: Symbol): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType) inline fun buildCall(function: Function, parameters: GenericAsmBuilder.() -> Unit) { contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt index bf1f42395..01bad83e5 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt @@ -11,6 +11,7 @@ import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type import org.objectweb.asm.Type.* import org.objectweb.asm.commons.InstructionAdapter +import space.kscience.kmath.ast.TypedMst import space.kscience.kmath.expressions.* import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* @@ -25,9 +26,9 @@ internal sealed class PrimitiveAsmBuilder>( classOfT: Class<*>, protected val classOfTPrimitive: Class<*>, expressionParent: Class, - protected val target: MST, + protected val target: TypedMst, ) : AsmBuilder() { - private val className: String = buildName(target) + private val className: String = buildName("${target.hashCode()}_${classOfT.simpleName}") /** * ASM type for [tType]. @@ -329,63 +330,39 @@ internal sealed class PrimitiveAsmBuilder>( } private fun visitVariables( - node: MST, + node: TypedMst, arrayMode: Boolean, alreadyLoaded: MutableList = mutableListOf() ): Unit = when (node) { - is Symbol -> when (node) { - !in alreadyLoaded -> { - alreadyLoaded += node - prepareVariable(node, arrayMode) - } - else -> { - } - } + is TypedMst.Variable -> if (node.symbol !in alreadyLoaded) { + alreadyLoaded += node.symbol + prepareVariable(node.symbol, arrayMode) + } else Unit - is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded) + is TypedMst.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded) - is MST.Binary -> { + is TypedMst.Binary -> { visitVariables(node.left, arrayMode, alreadyLoaded) visitVariables(node.right, arrayMode, alreadyLoaded) } - else -> Unit + is TypedMst.Constant -> Unit } - private fun visitExpression(node: MST): Unit = when (node) { - is Symbol -> { - val symbol = algebra.bindSymbolOrNull(node) + private fun visitExpression(node: TypedMst): Unit = when (node) { + is TypedMst.Variable -> loadVariable(node.symbol) - if (symbol != null) - loadNumberConstant(symbol) - else - loadVariable(node) - } + is TypedMst.Constant -> loadNumberConstant( + node.number ?: error("Object constants are not supported by pritimive ASM builder"), + ) - is MST.Numeric -> loadNumberConstant(algebra.number(node.value)) - - is MST.Unary -> if (node.value is MST.Numeric) - loadNumberConstant( - algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)), - ) - else - visitUnary(node) - - is MST.Binary -> when { - node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant( - algebra.binaryOperationFunction(node.operation)( - algebra.number((node.left as MST.Numeric).value), - algebra.number((node.right as MST.Numeric).value), - ), - ) - - else -> visitBinary(node) - } + is TypedMst.Unary -> visitUnary(node) + is TypedMst.Binary -> visitBinary(node) } - protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value) + protected open fun visitUnary(node: TypedMst.Unary) = visitExpression(node.value) - protected open fun visitBinary(node: MST.Binary) { + protected open fun visitBinary(node: TypedMst.Binary) { visitExpression(node.left) visitExpression(node.right) } @@ -404,14 +381,13 @@ internal sealed class PrimitiveAsmBuilder>( } @UnstableKMathAPI -internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder( +internal class DoubleAsmBuilder(target: TypedMst) : PrimitiveAsmBuilder( DoubleField, java.lang.Double::class.java, java.lang.Double.TYPE, DoubleExpression::class.java, target, ) { - private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic( MATH_TYPE.internalName, name, @@ -434,7 +410,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) { super.visitUnary(node) when (node.operation) { @@ -459,7 +435,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) { super.visitBinary(node) when (node.operation) { @@ -479,7 +455,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) : PrimitiveAsmBuilder( IntRing, Integer::class.java, @@ -487,7 +463,7 @@ internal class IntAsmBuilder(target: MST) : IntExpression::class.java, target ) { - override fun visitUnary(node: MST.Unary) { + override fun visitUnary(node: TypedMst.Unary) { super.visitUnary(node) when (node.operation) { @@ -497,7 +473,7 @@ internal class IntAsmBuilder(target: MST) : } } - override fun visitBinary(node: MST.Binary) { + override fun visitBinary(node: TypedMst.Binary) { super.visitBinary(node) when (node.operation) { @@ -510,14 +486,14 @@ internal class IntAsmBuilder(target: MST) : } @UnstableKMathAPI -internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder( +internal class LongAsmBuilder(target: TypedMst) : PrimitiveAsmBuilder( LongRing, java.lang.Long::class.java, java.lang.Long.TYPE, LongExpression::class.java, target, ) { - override fun visitUnary(node: MST.Unary) { + override fun visitUnary(node: TypedMst.Unary) { super.visitUnary(node) when (node.operation) { @@ -527,7 +503,7 @@ internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder) { super.visitBinary(node) when (node.operation) { diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt index 06e040e93..9e880f4fc 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt @@ -55,15 +55,15 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.( internal fun MethodVisitor.label(): Label = Label().also(::visitLabel) /** - * Creates a class name for [Expression] subclassed to implement [mst] provided. + * Creates a class name for [Expression] based with appending [marker] to reduce collisions. * * These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. * * @author Iaroslav Postovalov */ -internal tailrec fun buildName(mst: MST, collision: Int = 0): String { - val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision" +internal tailrec fun buildName(marker: String, collision: Int = 0): String { + val name = "space.kscience.kmath.asm.generated.CompiledExpression_${marker}_$collision" try { Class.forName(name) @@ -71,7 +71,7 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String { return name } - return buildName(mst, collision + 1) + return buildName(marker, collision + 1) } @Suppress("FunctionName") diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 5f194f2ea..880cf8421 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -34,12 +34,12 @@ public abstract class FunctionalExpressionAlgebra>( override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = { left, right -> Expression { arguments -> - algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments)) + algebra.binaryOperationFunction(operation)(left(arguments), right(arguments)) } } override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg -> - Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) } + Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) } } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt index 24e96e845..18226119b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt @@ -24,7 +24,7 @@ public sealed interface MST { public data class Numeric(val value: Number) : MST /** - * A node containing an unary operation. + * A node containing a unary operation. * * @property operation the identifier of operation. * @property value the argument of this operation. @@ -34,7 +34,7 @@ public sealed interface MST { /** * A node containing binary operation. * - * @property operation the identifier operation. + * @property operation the identifier of operation. * @property left the left operand. * @property right the right operand. */