diff --git a/README.md b/README.md
index 92260716e..99dd6d00f 100644
--- a/README.md
+++ b/README.md
@@ -247,6 +247,12 @@ One can still use generic algebras though.
> **Maturity**: PROTOTYPE
+* ### [kmath-tensorflow](kmath-tensorflow)
+>
+>
+> **Maturity**: PROTOTYPE
+
+
* ### [kmath-tensors](kmath-tensors)
>
>
diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts
index 90ec5dfbe..f8d39b9c5 100644
--- a/benchmarks/build.gradle.kts
+++ b/benchmarks/build.gradle.kts
@@ -52,6 +52,8 @@ kotlin {
implementation(project(":kmath-viktor"))
implementation(project(":kmath-jafama"))
implementation(project(":kmath-multik"))
+ implementation(projects.kmath.kmathTensorflow)
+ implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
implementation("org.nd4j:nd4j-native:1.0.0-M1")
// uncomment if your system supports AVX2
// val os = System.getProperty("os.name")
@@ -122,6 +124,11 @@ benchmark {
include("JafamaBenchmark")
}
+ configurations.register("tensorAlgebra") {
+ commonConfiguration()
+ include("TensorAlgebraBenchmark")
+ }
+
configurations.register("viktor") {
commonConfiguration()
include("ViktorBenchmark")
diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt
index 63165baaa..16fd544a8 100644
--- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt
+++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt
@@ -16,6 +16,8 @@ import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.Buffer
+import space.kscience.kmath.tensorflow.produceWithTF
+import space.kscience.kmath.tensors.core.tensorAlgebra
import kotlin.random.Random
@State(Scope.Benchmark)
@@ -39,6 +41,16 @@ internal class DotBenchmark {
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
}
+
+ @Benchmark
+ fun tfDot(blackhole: Blackhole) {
+ blackhole.consume(
+ DoubleField.produceWithTF {
+ matrix1 dot matrix1
+ }
+ )
+ }
+
@Benchmark
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
blackhole.consume(matrix1 dot matrix2)
@@ -59,13 +71,13 @@ internal class DotBenchmark {
blackhole.consume(matrix1 dot matrix2)
}
-// @Benchmark
-// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
-// blackhole.consume(matrix1 dot matrix2)
-// }
+ @Benchmark
+ fun tensorDot(blackhole: Blackhole) = with(DoubleField.tensorAlgebra) {
+ blackhole.consume(matrix1 dot matrix2)
+ }
@Benchmark
- fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
+ fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
blackhole.consume(matrix1 dot matrix2)
}
diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/TensorAlgebraBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/TensorAlgebraBenchmark.kt
new file mode 100644
index 000000000..38e064e53
--- /dev/null
+++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/TensorAlgebraBenchmark.kt
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.benchmarks
+
+import kotlinx.benchmark.Benchmark
+import kotlinx.benchmark.Blackhole
+import kotlinx.benchmark.Scope
+import kotlinx.benchmark.State
+import space.kscience.kmath.linear.linearSpace
+import space.kscience.kmath.linear.matrix
+import space.kscience.kmath.linear.symmetric
+import space.kscience.kmath.operations.DoubleField
+import space.kscience.kmath.tensors.core.tensorAlgebra
+import kotlin.random.Random
+
+@State(Scope.Benchmark)
+internal class TensorAlgebraBenchmark {
+ companion object {
+ private val random = Random(12224)
+ private const val dim = 30
+
+ private val matrix = DoubleField.linearSpace.matrix(dim, dim).symmetric { _, _ -> random.nextDouble() }
+ }
+
+ @Benchmark
+ fun tensorSymEigSvd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
+ blackhole.consume(matrix.symEigSvd(1e-10))
+ }
+
+ @Benchmark
+ fun tensorSymEigJacobi(blackhole: Blackhole) = with(Double.tensorAlgebra) {
+ blackhole.consume(matrix.symEigJacobi(50, 1e-10))
+ }
+}
\ No newline at end of file
diff --git a/build.gradle.kts b/build.gradle.kts
index 1b2d9d7c0..3b48c7328 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -10,7 +10,7 @@ allprojects {
}
group = "space.kscience"
- version = "0.3.0-dev-17"
+ version = "0.3.0-dev-19"
}
subprojects {
diff --git a/kmath-ast/README.md b/kmath-ast/README.md
index 5e3366881..bedf17486 100644
--- a/kmath-ast/README.md
+++ b/kmath-ast/README.md
@@ -1,6 +1,6 @@
# Module kmath-ast
-Performance and visualization extensions to MST API.
+Extensions to MST API: transformations, dynamic compilation and visualization.
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
@@ -35,6 +35,26 @@ dependencies {
}
```
+## Parsing expressions
+
+In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
+
+Supported literals:
+1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
+2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
+
+Supported binary operators (from the highest precedence to the lowest one):
+1. `^`
+2. `*`, `/`
+3. `+`, `-`
+
+Supported unary operator:
+1. `-`, e. g. `-x`
+
+Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
+1. `sin(x)`
+2. `add(x, y)`
+
## Dynamic expression code generation
### On JVM
@@ -42,48 +62,41 @@ dependencies {
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression` with implemented `invoke` function.
-For example, the following builder:
+For example, the following code:
```kotlin
-import space.kscience.kmath.expressions.Symbol.Companion.x
-import space.kscience.kmath.expressions.*
-import space.kscience.kmath.operations.*
-import space.kscience.kmath.asm.*
+import space.kscience.kmath.asm.compileToExpression
+import space.kscience.kmath.complex.ComplexField
-MstField { x + 2 }.compileToExpression(DoubleField)
-```
+"x+2".parseMath().compileToExpression(ComplexField)
+```
-... leads to generation of bytecode, which can be decompiled to the following Java class:
+… leads to generation of bytecode, which can be decompiled to the following Java class:
```java
-package space.kscience.kmath.asm.generated;
-
import java.util.Map;
-
import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics;
+import space.kscience.kmath.complex.Complex;
import space.kscience.kmath.expressions.Expression;
import space.kscience.kmath.expressions.Symbol;
-public final class AsmCompiledExpression_45045_0 implements Expression {
+public final class CompiledExpression_45045_0 implements Expression {
private final Object[] constants;
- public final Double invoke(Map arguments) {
- return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
- }
-
- public AsmCompiledExpression_45045_0(Object[] constants) {
- this.constants = constants;
+ public Complex invoke(Map arguments) {
+ Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
+ return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
}
}
-
```
-#### Known issues
+Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
-- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
- loading overhead.
-- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
+#### Limitations
+
+- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
+- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS
@@ -129,7 +142,7 @@ An example of emitted Wasm IR in the form of WAT:
)
```
-#### Known issues
+#### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
diff --git a/kmath-ast/docs/README-TEMPLATE.md b/kmath-ast/docs/README-TEMPLATE.md
index 9494af63a..e9e22f4d4 100644
--- a/kmath-ast/docs/README-TEMPLATE.md
+++ b/kmath-ast/docs/README-TEMPLATE.md
@@ -1,11 +1,31 @@
# Module kmath-ast
-Performance and visualization extensions to MST API.
+Extensions to MST API: transformations, dynamic compilation and visualization.
${features}
${artifact}
+## Parsing expressions
+
+In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
+
+Supported literals:
+1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
+2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
+
+Supported binary operators (from the highest precedence to the lowest one):
+1. `^`
+2. `*`, `/`
+3. `+`, `-`
+
+Supported unary operator:
+1. `-`, e. g. `-x`
+
+Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
+1. `sin(x)`
+2. `add(x, y)`
+
## Dynamic expression code generation
### On JVM
@@ -13,48 +33,66 @@ ${artifact}
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression` with implemented `invoke` function.
-For example, the following builder:
+For example, the following code:
```kotlin
-import space.kscience.kmath.expressions.Symbol.Companion.x
-import space.kscience.kmath.expressions.*
-import space.kscience.kmath.operations.*
-import space.kscience.kmath.asm.*
-
-MstField { x + 2 }.compileToExpression(DoubleField)
-```
-
-... leads to generation of bytecode, which can be decompiled to the following Java class:
-
-```java
-package space.kscience.kmath.asm.generated;
-
-import java.util.Map;
-
-import kotlin.jvm.functions.Function2;
-import space.kscience.kmath.asm.internal.MapIntrinsics;
-import space.kscience.kmath.expressions.Expression;
-import space.kscience.kmath.expressions.Symbol;
-
-public final class AsmCompiledExpression_45045_0 implements Expression {
- private final Object[] constants;
-
- public final Double invoke(Map arguments) {
- return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
- }
-
- public AsmCompiledExpression_45045_0(Object[] constants) {
- this.constants = constants;
- }
-}
+import space.kscience.kmath.asm.compileToExpression
+import space.kscience.kmath.operations.DoubleField
+"x^3-x+3".parseMath().compileToExpression(DoubleField)
```
-#### Known issues
+… leads to generation of bytecode, which can be decompiled to the following Java class:
-- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
- loading overhead.
-- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
+```java
+import java.util.*;
+import kotlin.jvm.functions.*;
+import space.kscience.kmath.asm.internal.*;
+import space.kscience.kmath.complex.*;
+import space.kscience.kmath.expressions.*;
+
+public final class CompiledExpression_45045_0 implements Expression {
+ private final Object[] constants;
+
+ public Complex invoke(Map arguments) {
+ Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
+ return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
+ }
+}
+```
+
+For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
+
+```java
+import java.util.*;
+import space.kscience.kmath.asm.internal.*;
+import space.kscience.kmath.expressions.*;
+
+public final class CompiledExpression_-386104628_0 implements DoubleExpression {
+ private final SymbolIndexer indexer;
+
+ public SymbolIndexer getIndexer() {
+ return this.indexer;
+ }
+
+ public double invoke(double[] arguments) {
+ double var2 = arguments[0];
+ return Math.pow(var2, 3.0D) - var2 + 3.0D;
+ }
+
+ public final Double invoke(Map arguments) {
+ double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
+ return Math.pow(var2, 3.0D) - var2 + 3.0D;
+ }
+}
+```
+
+Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
+
+#### Limitations
+
+- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
+- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS
@@ -100,7 +138,7 @@ An example of emitted Wasm IR in the form of WAT:
)
```
-#### Known issues
+#### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt
new file mode 100644
index 000000000..8a8b8797d
--- /dev/null
+++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.ast
+
+import space.kscience.kmath.expressions.Expression
+import space.kscience.kmath.expressions.Symbol
+import space.kscience.kmath.misc.UnstableKMathAPI
+import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.operations.NumericAlgebra
+
+/**
+ * MST form where all values belong to the type [T]. It is optimal for constant folding, dynamic compilation, etc.
+ *
+ * @param T the type.
+ */
+@UnstableKMathAPI
+public sealed interface TypedMst {
+ /**
+ * A node containing a unary operation.
+ *
+ * @param T the type.
+ * @property operation The identifier of operation.
+ * @property function The function implementing this operation.
+ * @property value The argument of this operation.
+ */
+ public class Unary(public val operation: String, public val function: (T) -> T, public val value: TypedMst) :
+ TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+ other as Unary<*>
+ if (operation != other.operation) return false
+ if (value != other.value) return false
+ return true
+ }
+
+ override fun hashCode(): Int {
+ var result = operation.hashCode()
+ result = 31 * result + value.hashCode()
+ return result
+ }
+
+ override fun toString(): String = "Unary(operation=$operation, value=$value)"
+ }
+
+ /**
+ * A node containing binary operation.
+ *
+ * @param T the type.
+ * @property operation The identifier of operation.
+ * @property function The binary function implementing this operation.
+ * @property left The left operand.
+ * @property right The right operand.
+ */
+ public class Binary(
+ public val operation: String,
+ public val function: Function,
+ public val left: TypedMst,
+ public val right: TypedMst,
+ ) : TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+
+ other as Binary<*>
+
+ if (operation != other.operation) return false
+ if (left != other.left) return false
+ if (right != other.right) return false
+
+ return true
+ }
+
+ override fun hashCode(): Int {
+ var result = operation.hashCode()
+ result = 31 * result + left.hashCode()
+ result = 31 * result + right.hashCode()
+ return result
+ }
+
+ override fun toString(): String = "Binary(operation=$operation, left=$left, right=$right)"
+ }
+
+ /**
+ * The non-numeric constant value.
+ *
+ * @param T the type.
+ * @property value The held value.
+ * @property number The number this value corresponds.
+ */
+ public class Constant(public val value: T, public val number: Number?) : TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+ other as Constant<*>
+ if (value != other.value) return false
+ if (number != other.number) return false
+ return true
+ }
+
+ override fun hashCode(): Int {
+ var result = value?.hashCode() ?: 0
+ result = 31 * result + (number?.hashCode() ?: 0)
+ return result
+ }
+
+ override fun toString(): String = "Constant(value=$value, number=$number)"
+ }
+
+ /**
+ * The node containing a variable
+ *
+ * @param T the type.
+ * @property symbol The symbol of the variable.
+ */
+ public class Variable(public val symbol: Symbol) : TypedMst {
+ override fun equals(other: Any?): Boolean {
+ if (this === other) return true
+ if (other == null || this::class != other::class) return false
+ other as Variable<*>
+ if (symbol != other.symbol) return false
+ return true
+ }
+
+ override fun hashCode(): Int = symbol.hashCode()
+ override fun toString(): String = "Variable(symbol=$symbol)"
+ }
+}
+
+/**
+ * Interprets the [TypedMst] node with this [Algebra] and [arguments].
+ */
+@UnstableKMathAPI
+public fun TypedMst.interpret(algebra: Algebra, arguments: Map): T = when (this) {
+ is TypedMst.Unary -> algebra.unaryOperation(operation, interpret(algebra, arguments))
+
+ is TypedMst.Binary -> when {
+ algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null ->
+ algebra.leftSideNumberOperation(operation, left.number, right.interpret(algebra, arguments))
+
+ algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null ->
+ algebra.rightSideNumberOperation(operation, left.interpret(algebra, arguments), right.number)
+
+ else -> algebra.binaryOperation(
+ operation,
+ left.interpret(algebra, arguments),
+ right.interpret(algebra, arguments),
+ )
+ }
+
+ is TypedMst.Constant -> value
+ is TypedMst.Variable -> arguments.getValue(symbol)
+}
+
+/**
+ * Interprets the [TypedMst] node with this [Algebra] and optional [arguments].
+ */
+@UnstableKMathAPI
+public fun TypedMst.interpret(algebra: Algebra, vararg arguments: Pair): T = interpret(
+ algebra,
+ when (arguments.size) {
+ 0 -> emptyMap()
+ 1 -> mapOf(arguments[0])
+ else -> hashMapOf(*arguments)
+ },
+)
+
+/**
+ * Interpret this [TypedMst] node as expression.
+ */
+@UnstableKMathAPI
+public fun TypedMst.toExpression(algebra: Algebra): Expression = Expression { arguments ->
+ interpret(algebra, arguments)
+}
diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt
new file mode 100644
index 000000000..71fb154c9
--- /dev/null
+++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.ast
+
+import space.kscience.kmath.expressions.MST
+import space.kscience.kmath.expressions.Symbol
+import space.kscience.kmath.misc.UnstableKMathAPI
+import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.operations.NumericAlgebra
+import space.kscience.kmath.operations.bindSymbolOrNull
+
+/**
+ * Evaluates constants in given [MST] for given [algebra] at the same time with converting to [TypedMst].
+ */
+@UnstableKMathAPI
+public fun MST.evaluateConstants(algebra: Algebra): TypedMst = when (this) {
+ is MST.Numeric -> TypedMst.Constant(
+ (algebra as? NumericAlgebra)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
+ value,
+ )
+
+ is MST.Unary -> when (val arg = value.evaluateConstants(algebra)) {
+ is TypedMst.Constant -> {
+ val value = algebra.unaryOperation(
+ operation,
+ arg.value,
+ )
+
+ TypedMst.Constant(value, if (value is Number) value else null)
+ }
+
+ else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
+ }
+
+ is MST.Binary -> {
+ val left = left.evaluateConstants(algebra)
+ val right = right.evaluateConstants(algebra)
+
+ when {
+ left is TypedMst.Constant && right is TypedMst.Constant -> {
+ val value = when {
+ algebra is NumericAlgebra && left.number != null -> algebra.leftSideNumberOperation(
+ operation,
+ left.number,
+ right.value,
+ )
+
+ algebra is NumericAlgebra && right.number != null -> algebra.rightSideNumberOperation(
+ operation,
+ left.value,
+ right.number,
+ )
+
+ else -> algebra.binaryOperation(
+ operation,
+ left.value,
+ right.value,
+ )
+ }
+
+ TypedMst.Constant(value, if (value is Number) value else null)
+ }
+
+ algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
+ operation,
+ algebra.leftSideNumberOperationFunction(operation),
+ left,
+ right,
+ )
+
+ algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> TypedMst.Binary(
+ operation,
+ algebra.rightSideNumberOperationFunction(operation),
+ left,
+ right,
+ )
+
+ else -> TypedMst.Binary(operation, algebra.binaryOperationFunction(operation), left, right)
+ }
+ }
+
+ is Symbol -> {
+ val boundSymbol = algebra.bindSymbolOrNull(this)
+
+ if (boundSymbol != null)
+ TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null)
+ else
+ TypedMst.Variable(this)
+ }
+}
diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt
new file mode 100644
index 000000000..954a0f330
--- /dev/null
+++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestFolding.kt
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.ast
+
+import space.kscience.kmath.operations.ByteRing
+import space.kscience.kmath.operations.DoubleField
+import space.kscience.kmath.operations.IntRing
+import space.kscience.kmath.operations.pi
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.fail
+
+internal class TestFolding {
+ @Test
+ fun foldUnary() = assertEquals(
+ -1,
+ ("-(1)".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldDeepUnary() = assertEquals(
+ 1,
+ ("-(-(1))".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldBinary() = assertEquals(
+ 2,
+ ("1*2".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldDeepBinary() = assertEquals(
+ 10,
+ ("1*2*5".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldSymbol() = assertEquals(
+ DoubleField.pi,
+ ("pi".parseMath().evaluateConstants(DoubleField) as? TypedMst.Constant ?: fail()).value,
+ )
+
+ @Test
+ fun foldNumeric() = assertEquals(
+ 42.toByte(),
+ ("42".parseMath().evaluateConstants(ByteRing) as? TypedMst.Constant ?: fail()).value,
+ )
+}
diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt
index b838245e1..d0c3a789e 100644
--- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt
+++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt
@@ -7,7 +7,7 @@ package space.kscience.kmath.ast
import space.kscience.kmath.complex.Complex
import space.kscience.kmath.complex.ComplexField
-import space.kscience.kmath.expressions.evaluate
+import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test
@@ -17,14 +17,14 @@ internal class TestParser {
@Test
fun evaluateParsedMst() {
val mst = "2+2*(2+2)".parseMath()
- val res = ComplexField.evaluate(mst)
+ val res = mst.interpret(ComplexField)
assertEquals(Complex(10.0, 0.0), res)
}
@Test
fun evaluateMstSymbol() {
val mst = "i".parseMath()
- val res = ComplexField.evaluate(mst)
+ val res = mst.interpret(ComplexField)
assertEquals(ComplexField.i, res)
}
@@ -32,7 +32,7 @@ internal class TestParser {
@Test
fun evaluateMstUnary() {
val mst = "sin(0)".parseMath()
- val res = DoubleField.evaluate(mst)
+ val res = mst.interpret(DoubleField)
assertEquals(0.0, res)
}
@@ -53,7 +53,7 @@ internal class TestParser {
}
val mst = "magic(a, b)".parseMath()
- val res = magicalAlgebra.evaluate(mst)
+ val res = mst.interpret(magicalAlgebra)
assertEquals("a ★ b", res)
}
}
diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt
index bb6bb3ce1..42cf5ce58 100644
--- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt
+++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt
@@ -5,35 +5,35 @@
package space.kscience.kmath.ast
-import space.kscience.kmath.expressions.evaluate
+import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestParserPrecedence {
@Test
- fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
+ fun test1(): Unit = assertEquals(6.0, "2*2+2".parseMath().interpret(f))
@Test
- fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
+ fun test2(): Unit = assertEquals(6.0, "2+2*2".parseMath().interpret(f))
@Test
- fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
+ fun test3(): Unit = assertEquals(10.0, "2^3+2".parseMath().interpret(f))
@Test
- fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
+ fun test4(): Unit = assertEquals(10.0, "2+2^3".parseMath().interpret(f))
@Test
- fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
+ fun test5(): Unit = assertEquals(16.0, "2^3*2".parseMath().interpret(f))
@Test
- fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
+ fun test6(): Unit = assertEquals(16.0, "2*2^3".parseMath().interpret(f))
@Test
- fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
+ fun test7(): Unit = assertEquals(18.0, "2+2^3*2".parseMath().interpret(f))
@Test
- fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
+ fun test8(): Unit = assertEquals(18.0, "2*2^3+2".parseMath().interpret(f))
private companion object {
private val f = DoubleField
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt
index a6b6e022b..a8b1aa2e1 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt
@@ -5,87 +5,48 @@
package space.kscience.kmath.estree
+import space.kscience.kmath.ast.TypedMst
+import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.estree.internal.ESTreeBuilder
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
-import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.internal.estree.BaseExpression
+import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra
-import space.kscience.kmath.operations.NumericAlgebra
-import space.kscience.kmath.operations.bindSymbolOrNull
-
-@PublishedApi
-internal fun MST.compileWith(algebra: Algebra): Expression {
- fun ESTreeBuilder.visit(node: MST): BaseExpression = when (node) {
- is Symbol -> {
- val symbol = algebra.bindSymbolOrNull(node)
-
- if (symbol != null)
- constant(symbol)
- else
- variable(node.identity)
- }
-
- is Numeric -> constant(
- (algebra as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this")
- )
-
- is Unary -> when {
- algebra is NumericAlgebra && node.value is Numeric -> constant(
- algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))
- )
-
- else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
- }
-
- is Binary -> when {
- algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
- algebra.binaryOperationFunction(node.operation).invoke(
- algebra.number((node.left as Numeric).value),
- algebra.number((node.right as Numeric).value)
- )
- )
-
- algebra is NumericAlgebra && node.left is Numeric -> call(
- algebra.leftSideNumberOperationFunction(node.operation),
- visit(node.left),
- visit(node.right),
- )
-
- algebra is NumericAlgebra && node.right is Numeric -> call(
- algebra.rightSideNumberOperationFunction(node.operation),
- visit(node.left),
- visit(node.right),
- )
-
- else -> call(
- algebra.binaryOperationFunction(node.operation),
- visit(node.left),
- visit(node.right),
- )
- }
- }
-
- return ESTreeBuilder { visit(this@compileWith) }.instance
-}
/**
* Create a compiled expression with given [MST] and given [algebra].
*/
-public fun MST.compileToExpression(algebra: Algebra): Expression = compileWith(algebra)
+@OptIn(UnstableKMathAPI::class)
+public fun MST.compileToExpression(algebra: Algebra): Expression {
+ val typed = evaluateConstants(algebra)
+ if (typed is TypedMst.Constant) return Expression { typed.value }
+ fun ESTreeBuilder.visit(node: TypedMst): BaseExpression = when (node) {
+ is TypedMst.Constant -> constant(node.value)
+ is TypedMst.Variable -> variable(node.symbol)
+ is TypedMst.Unary -> call(node.function, visit(node.value))
+
+ is TypedMst.Binary -> call(
+ node.function,
+ visit(node.left),
+ visit(node.right),
+ )
+ }
+
+ return ESTreeBuilder { visit(typed) }.instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
-public inline fun MST.compile(algebra: Algebra, arguments: Map): T =
- compileToExpression(algebra).invoke(arguments)
-
+public fun MST.compile(algebra: Algebra, arguments: Map): T =
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
-public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T =
- compileToExpression(algebra).invoke(*arguments)
+public fun MST.compile(algebra: Algebra, vararg arguments: Pair): T =
+ compileToExpression(algebra)(*arguments)
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt
index 4907d8225..10a6c4a16 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt
@@ -61,7 +61,7 @@ internal class ESTreeBuilder(val bodyCallback: ESTreeBuilder.() -> BaseExp
}
}
- fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name))
+ fun variable(name: Symbol): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name.identity))
fun call(function: Function, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
optional = false,
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt
index 96090a633..aacb62f36 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt
@@ -5,8 +5,8 @@
package space.kscience.kmath.wasm.internal
+import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.*
-import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.internal.binaryen.*
import space.kscience.kmath.internal.webassembly.Instance
import space.kscience.kmath.misc.UnstableKMathAPI
@@ -16,11 +16,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
private val spreader = eval("(obj, args) => obj(...args)")
+@OptIn(UnstableKMathAPI::class)
@Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder>(
protected val binaryenType: Type,
protected val algebra: Algebra,
- protected val target: MST,
+ protected val target: TypedMst,
) {
protected val keys: MutableList = mutableListOf()
protected lateinit var ctx: BinaryenModule
@@ -51,59 +52,41 @@ internal sealed class WasmBuilder>(
Instance(c, js("{}")).exports.executable
}
- protected open fun visitSymbol(node: Symbol): ExpressionRef {
- algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
+ protected abstract fun visitNumber(number: Number): ExpressionRef
- var idx = keys.indexOf(node)
+ protected open fun visitVariable(node: TypedMst.Variable): ExpressionRef {
+ var idx = keys.indexOf(node.symbol)
if (idx == -1) {
- keys += node
+ keys += node.symbol
idx = keys.lastIndex
}
return ctx.local.get(idx, binaryenType)
}
- protected abstract fun visitNumeric(node: Numeric): ExpressionRef
-
- protected open fun visitUnary(node: Unary): ExpressionRef =
+ protected open fun visitUnary(node: TypedMst.Unary): ExpressionRef =
error("Unary operation ${node.operation} not defined in $this")
- protected open fun visitBinary(mst: Binary): ExpressionRef =
+ protected open fun visitBinary(mst: TypedMst.Binary): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this")
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
- protected fun visit(node: MST): ExpressionRef = when (node) {
- is Symbol -> visitSymbol(node)
- is Numeric -> visitNumeric(node)
+ protected fun visit(node: TypedMst): ExpressionRef = when (node) {
+ is TypedMst.Constant -> visitNumber(
+ node.number ?: error("Object constants are not supported by pritimive ASM builder"),
+ )
- is Unary -> when {
- algebra is NumericAlgebra && node.value is Numeric -> visitNumeric(
- Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
- )
-
- else -> visitUnary(node)
- }
-
- is Binary -> when {
- algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
- Numeric(
- algebra.binaryOperationFunction(node.operation)
- .invoke(
- algebra.number((node.left as Numeric).value),
- algebra.number((node.right as Numeric).value)
- )
- )
- )
-
- else -> visitBinary(node)
- }
+ is TypedMst.Variable -> visitVariable(node)
+ is TypedMst.Unary -> visitUnary(node)
+ is TypedMst.Binary -> visitBinary(node)
}
}
@UnstableKMathAPI
-internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleField, target) {
+internal class DoubleWasmBuilder(target: TypedMst) :
+ WasmBuilder(f64, DoubleField, target) {
override val instance by lazy {
object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(keys)
@@ -114,9 +97,9 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
@@ -137,7 +120,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder super.visitUnary(node)
}
- override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
+ override fun visitBinary(mst: TypedMst.Binary): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
@@ -148,7 +131,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(i32, IntRing, target) {
+internal class IntWasmBuilder(target: TypedMst) : WasmBuilder(i32, IntRing, target) {
override val instance by lazy {
object : IntExpression {
override val indexer = SimpleSymbolIndexer(keys)
@@ -157,15 +140,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder(i32
}
}
- override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value.toInt())
+ override fun visitNumber(number: Number) = ctx.i32.const(number.toInt())
- override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
+ override fun visitUnary(node: TypedMst.Unary): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value)
else -> super.visitUnary(node)
}
- override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
+ override fun visitBinary(mst: TypedMst.Binary): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt
index 12e6b41af..f9540f9db 100644
--- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt
+++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt
@@ -7,7 +7,8 @@
package space.kscience.kmath.wasm
-import space.kscience.kmath.estree.compileWith
+import space.kscience.kmath.ast.TypedMst
+import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
@@ -21,8 +22,16 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance
+public fun MST.compileToExpression(algebra: IntRing): IntExpression {
+ val typed = evaluateConstants(algebra)
+ return if (typed is TypedMst.Constant) object : IntExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: IntArray): Int = typed.value
+ } else
+ IntWasmBuilder(typed).instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -31,7 +40,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBui
*/
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map): Int =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
@@ -49,7 +58,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: DoubleField): Expression = DoubleWasmBuilder(this).instance
+public fun MST.compileToExpression(algebra: DoubleField): Expression {
+ val typed = evaluateConstants(algebra)
+
+ return if (typed is TypedMst.Constant) object : DoubleExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: DoubleArray): Double = typed.value
+ } else
+ DoubleWasmBuilder(typed).instance
+}
/**
@@ -59,7 +77,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression = D
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map): Double =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
@@ -69,4 +87,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map): Do
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double =
- compileToExpression(algebra).invoke(*arguments)
+ compileToExpression(algebra)(*arguments)
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt
index 8e426622d..73b9c97a7 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt
@@ -8,10 +8,14 @@
package space.kscience.kmath.asm
import space.kscience.kmath.asm.internal.*
+import space.kscience.kmath.ast.TypedMst
+import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.*
-import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.misc.UnstableKMathAPI
-import space.kscience.kmath.operations.*
+import space.kscience.kmath.operations.Algebra
+import space.kscience.kmath.operations.DoubleField
+import space.kscience.kmath.operations.IntRing
+import space.kscience.kmath.operations.LongRing
/**
* Compiles given MST to an Expression using AST compiler.
@@ -21,102 +25,64 @@ import space.kscience.kmath.operations.*
* @return the compiled expression.
* @author Alexander Nozik
*/
+@OptIn(UnstableKMathAPI::class)
@PublishedApi
internal fun MST.compileWith(type: Class, algebra: Algebra): Expression {
- fun GenericAsmBuilder.variablesVisitor(node: MST): Unit = when (node) {
- is Symbol -> prepareVariable(node.identity)
- is Unary -> variablesVisitor(node.value)
+ val typed = evaluateConstants(algebra)
+ if (typed is TypedMst.Constant) return Expression { typed.value }
- is Binary -> {
+ fun GenericAsmBuilder.variablesVisitor(node: TypedMst): Unit = when (node) {
+ is TypedMst.Unary -> variablesVisitor(node.value)
+
+ is TypedMst.Binary -> {
variablesVisitor(node.left)
variablesVisitor(node.right)
}
- else -> Unit
+ is TypedMst.Variable -> prepareVariable(node.symbol)
+ is TypedMst.Constant -> Unit
}
- fun GenericAsmBuilder.expressionVisitor(node: MST): Unit = when (node) {
- is Symbol -> {
- val symbol = algebra.bindSymbolOrNull(node)
+ fun GenericAsmBuilder.expressionVisitor(node: TypedMst): Unit = when (node) {
+ is TypedMst.Constant -> if (node.number != null)
+ loadNumberConstant(node.number)
+ else
+ loadObjectConstant(node.value)
- if (symbol != null)
- loadObjectConstant(symbol as Any)
- else
- loadVariable(node.identity)
- }
+ is TypedMst.Variable -> loadVariable(node.symbol)
+ is TypedMst.Unary -> buildCall(node.function) { expressionVisitor(node.value) }
- is Numeric -> if (algebra is NumericAlgebra) {
- if (Number::class.java.isAssignableFrom(type))
- loadNumberConstant(algebra.number(node.value) as Number)
- else
- loadObjectConstant(algebra.number(node.value))
- } else
- error("Numeric nodes are not supported by $this")
-
- is Unary -> when {
- algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
- algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
- )
-
- else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
- }
-
- is Binary -> when {
- algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
- algebra.binaryOperationFunction(node.operation).invoke(
- algebra.number((node.left as Numeric).value),
- algebra.number((node.right as Numeric).value),
- )
- )
-
- algebra is NumericAlgebra && node.left is Numeric -> buildCall(
- algebra.leftSideNumberOperationFunction(node.operation),
- ) {
- expressionVisitor(node.left)
- expressionVisitor(node.right)
- }
-
- algebra is NumericAlgebra && node.right is Numeric -> buildCall(
- algebra.rightSideNumberOperationFunction(node.operation),
- ) {
- expressionVisitor(node.left)
- expressionVisitor(node.right)
- }
-
- else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
- expressionVisitor(node.left)
- expressionVisitor(node.right)
- }
+ is TypedMst.Binary -> buildCall(node.function) {
+ expressionVisitor(node.left)
+ expressionVisitor(node.right)
}
}
return GenericAsmBuilder(
type,
- buildName(this),
- { variablesVisitor(this@compileWith) },
- { expressionVisitor(this@compileWith) },
+ buildName("${typed.hashCode()}_${type.simpleName}"),
+ { variablesVisitor(typed) },
+ { expressionVisitor(typed) },
).instance
}
-
/**
* Create a compiled expression with given [MST] and given [algebra].
*/
public inline fun MST.compileToExpression(algebra: Algebra): Expression =
compileWith(T::class.java, algebra)
-
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun MST.compile(algebra: Algebra, arguments: Map): T =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T =
- compileToExpression(algebra).invoke(*arguments)
+ compileToExpression(algebra)(*arguments)
/**
@@ -125,7 +91,16 @@ public inline fun MST.compile(algebra: Algebra, vararg argu
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance
+public fun MST.compileToExpression(algebra: IntRing): IntExpression {
+ val typed = evaluateConstants(algebra)
+
+ return if (typed is TypedMst.Constant) object : IntExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: IntArray): Int = typed.value
+ } else
+ IntAsmBuilder(typed).instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -134,7 +109,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuil
*/
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map): Int =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -152,8 +127,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance
+public fun MST.compileToExpression(algebra: LongRing): LongExpression {
+ val typed = evaluateConstants(algebra)
+ return if (typed is TypedMst.Constant) object : LongExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: LongArray): Long = typed.value
+ } else
+ LongAsmBuilder(typed).instance
+}
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -162,7 +145,7 @@ public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmB
*/
@UnstableKMathAPI
public fun MST.compile(algebra: LongRing, arguments: Map): Long =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
@@ -181,7 +164,17 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair):
* @author Iaroslav Postovalov
*/
@UnstableKMathAPI
-public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance
+public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression {
+ val typed = evaluateConstants(algebra)
+
+ return if (typed is TypedMst.Constant) object : DoubleExpression {
+ override val indexer = SimpleSymbolIndexer(emptyList())
+
+ override fun invoke(arguments: DoubleArray): Double = typed.value
+ } else
+ DoubleAsmBuilder(typed).instance
+}
+
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -190,7 +183,7 @@ public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = Dou
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map): Double =
- compileToExpression(algebra).invoke(arguments)
+ compileToExpression(algebra)(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments].
@@ -199,4 +192,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map): Do
*/
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double =
- compileToExpression(algebra).invoke(*arguments)
+ compileToExpression(algebra)(*arguments)
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt
index 5eb739956..6cf3d8721 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt
@@ -56,7 +56,7 @@ internal class GenericAsmBuilder(
/**
* Local variables indices are indices of symbols in this list.
*/
- private val argumentsLocals = mutableListOf()
+ private val argumentsLocals = mutableListOf()
/**
* Subclasses, loads and instantiates [Expression] for given parameters.
@@ -253,10 +253,10 @@ internal class GenericAsmBuilder(
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable].
*/
- fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
+ fun prepareVariable(name: Symbol): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run
load(1, MAP_TYPE)
- aconst(name)
+ aconst(name.identity)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
@@ -280,7 +280,7 @@ internal class GenericAsmBuilder(
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first.
*/
- fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
+ fun loadVariable(name: Symbol): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
inline fun buildCall(function: Function, parameters: GenericAsmBuilder.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt
index bf1f42395..01bad83e5 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt
@@ -11,6 +11,7 @@ import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter
+import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
@@ -25,9 +26,9 @@ internal sealed class PrimitiveAsmBuilder>(
classOfT: Class<*>,
protected val classOfTPrimitive: Class<*>,
expressionParent: Class,
- protected val target: MST,
+ protected val target: TypedMst,
) : AsmBuilder() {
- private val className: String = buildName(target)
+ private val className: String = buildName("${target.hashCode()}_${classOfT.simpleName}")
/**
* ASM type for [tType].
@@ -329,63 +330,39 @@ internal sealed class PrimitiveAsmBuilder>(
}
private fun visitVariables(
- node: MST,
+ node: TypedMst,
arrayMode: Boolean,
alreadyLoaded: MutableList = mutableListOf()
): Unit = when (node) {
- is Symbol -> when (node) {
- !in alreadyLoaded -> {
- alreadyLoaded += node
- prepareVariable(node, arrayMode)
- }
- else -> {
- }
- }
+ is TypedMst.Variable -> if (node.symbol !in alreadyLoaded) {
+ alreadyLoaded += node.symbol
+ prepareVariable(node.symbol, arrayMode)
+ } else Unit
- is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
+ is TypedMst.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
- is MST.Binary -> {
+ is TypedMst.Binary -> {
visitVariables(node.left, arrayMode, alreadyLoaded)
visitVariables(node.right, arrayMode, alreadyLoaded)
}
- else -> Unit
+ is TypedMst.Constant -> Unit
}
- private fun visitExpression(node: MST): Unit = when (node) {
- is Symbol -> {
- val symbol = algebra.bindSymbolOrNull(node)
+ private fun visitExpression(node: TypedMst): Unit = when (node) {
+ is TypedMst.Variable -> loadVariable(node.symbol)
- if (symbol != null)
- loadNumberConstant(symbol)
- else
- loadVariable(node)
- }
+ is TypedMst.Constant -> loadNumberConstant(
+ node.number ?: error("Object constants are not supported by pritimive ASM builder"),
+ )
- is MST.Numeric -> loadNumberConstant(algebra.number(node.value))
-
- is MST.Unary -> if (node.value is MST.Numeric)
- loadNumberConstant(
- algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
- )
- else
- visitUnary(node)
-
- is MST.Binary -> when {
- node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
- algebra.binaryOperationFunction(node.operation)(
- algebra.number((node.left as MST.Numeric).value),
- algebra.number((node.right as MST.Numeric).value),
- ),
- )
-
- else -> visitBinary(node)
- }
+ is TypedMst.Unary -> visitUnary(node)
+ is TypedMst.Binary -> visitBinary(node)
}
- protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value)
+ protected open fun visitUnary(node: TypedMst.Unary) = visitExpression(node.value)
- protected open fun visitBinary(node: MST.Binary) {
+ protected open fun visitBinary(node: TypedMst.Binary) {
visitExpression(node.left)
visitExpression(node.right)
}
@@ -404,14 +381,13 @@ internal sealed class PrimitiveAsmBuilder>(
}
@UnstableKMathAPI
-internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder(
+internal class DoubleAsmBuilder(target: TypedMst) : PrimitiveAsmBuilder(
DoubleField,
java.lang.Double::class.java,
java.lang.Double.TYPE,
DoubleExpression::class.java,
target,
) {
-
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
MATH_TYPE.internalName,
name,
@@ -434,7 +410,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) {
super.visitUnary(node)
when (node.operation) {
@@ -459,7 +435,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) {
super.visitBinary(node)
when (node.operation) {
@@ -479,7 +455,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder) :
PrimitiveAsmBuilder(
IntRing,
Integer::class.java,
@@ -487,7 +463,7 @@ internal class IntAsmBuilder(target: MST) :
IntExpression::class.java,
target
) {
- override fun visitUnary(node: MST.Unary) {
+ override fun visitUnary(node: TypedMst.Unary) {
super.visitUnary(node)
when (node.operation) {
@@ -497,7 +473,7 @@ internal class IntAsmBuilder(target: MST) :
}
}
- override fun visitBinary(node: MST.Binary) {
+ override fun visitBinary(node: TypedMst.Binary) {
super.visitBinary(node)
when (node.operation) {
@@ -510,14 +486,14 @@ internal class IntAsmBuilder(target: MST) :
}
@UnstableKMathAPI
-internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder(
+internal class LongAsmBuilder(target: TypedMst) : PrimitiveAsmBuilder(
LongRing,
java.lang.Long::class.java,
java.lang.Long.TYPE,
LongExpression::class.java,
target,
) {
- override fun visitUnary(node: MST.Unary) {
+ override fun visitUnary(node: TypedMst.Unary) {
super.visitUnary(node)
when (node.operation) {
@@ -527,7 +503,7 @@ internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder) {
super.visitBinary(node)
when (node.operation) {
diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt
index 06e040e93..9e880f4fc 100644
--- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt
+++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt
@@ -55,15 +55,15 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
/**
- * Creates a class name for [Expression] subclassed to implement [mst] provided.
+ * Creates a class name for [Expression] based with appending [marker] to reduce collisions.
*
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
*
* @author Iaroslav Postovalov
*/
-internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
- val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision"
+internal tailrec fun buildName(marker: String, collision: Int = 0): String {
+ val name = "space.kscience.kmath.asm.generated.CompiledExpression_${marker}_$collision"
try {
Class.forName(name)
@@ -71,7 +71,7 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
return name
}
- return buildName(mst, collision + 1)
+ return buildName(marker, collision + 1)
}
@Suppress("FunctionName")
diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
index 5f194f2ea..68cc8e791 100644
--- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
+++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
@@ -34,12 +34,12 @@ public abstract class FunctionalExpressionAlgebra>(
override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression =
{ left, right ->
Expression { arguments ->
- algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments))
+ algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
}
}
override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg ->
- Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) }
+ Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
}
}
@@ -164,8 +164,6 @@ public open class FunctionalExpressionExtendedField>
override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression =
super.binaryOperationFunction(operation)
-
- override fun bindSymbol(value: String): Expression = super.bindSymbol(value)
}
public inline fun > A.expressionInGroup(
diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt
index 7533024a1..18226119b 100644
--- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt
+++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt
@@ -7,7 +7,7 @@ package space.kscience.kmath.expressions
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
-import space.kscience.kmath.operations.bindSymbol
+import space.kscience.kmath.operations.bindSymbolOrNull
/**
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
@@ -24,7 +24,7 @@ public sealed interface MST {
public data class Numeric(val value: Number) : MST
/**
- * A node containing an unary operation.
+ * A node containing a unary operation.
*
* @property operation the identifier of operation.
* @property value the argument of this operation.
@@ -34,7 +34,7 @@ public sealed interface MST {
/**
* A node containing binary operation.
*
- * @property operation the identifier operation.
+ * @property operation the identifier of operation.
* @property left the left operand.
* @property right the right operand.
*/
@@ -43,66 +43,50 @@ public sealed interface MST {
// TODO add a function with named arguments
-/**
- * Interprets the [MST] node with this [Algebra].
- *
- * @receiver the algebra that provides operations.
- * @param node the node to evaluate.
- * @return the value of expression.
- * @author Alexander Nozik
- */
-public fun Algebra.evaluate(node: MST): T = when (node) {
- is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value)
- ?: error("Numeric nodes are not supported by $this")
-
- is Symbol -> bindSymbol(node)
-
- is MST.Unary -> when {
- this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
- else -> unaryOperationFunction(node.operation)(evaluate(node.value))
- }
-
- is MST.Binary -> when {
- this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
- binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
-
- this is NumericAlgebra && node.left is MST.Numeric ->
- leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
-
- this is NumericAlgebra && node.right is MST.Numeric ->
- rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
-
- else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
- }
-}
-
-internal class InnerAlgebra(val algebra: Algebra, val arguments: Map) : NumericAlgebra {
- override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)]
-
- override fun unaryOperation(operation: String, arg: T): T =
- algebra.unaryOperation(operation, arg)
-
- override fun binaryOperation(operation: String, left: T, right: T): T =
- algebra.binaryOperation(operation, left, right)
-
- override fun unaryOperationFunction(operation: String): (arg: T) -> T =
- algebra.unaryOperationFunction(operation)
-
- override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
- algebra.binaryOperationFunction(operation)
-
- @Suppress("UNCHECKED_CAST")
- override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
- (algebra as NumericAlgebra).number(value)
- else
- error("Numeric nodes are not supported by $this")
-}
/**
* Interprets the [MST] node with this [Algebra] and optional [arguments]
*/
-public fun MST.interpret(algebra: Algebra, arguments: Map): T =
- InnerAlgebra(algebra, arguments).evaluate(this)
+public fun MST.interpret(algebra: Algebra, arguments: Map): T = when (this) {
+ is MST.Numeric -> (algebra as NumericAlgebra?)?.number(value)
+ ?: error("Numeric nodes are not supported by $algebra")
+
+ is Symbol -> algebra.bindSymbolOrNull(this) ?: arguments.getValue(this)
+
+ is MST.Unary -> when {
+ algebra is NumericAlgebra && this.value is MST.Numeric -> algebra.unaryOperation(
+ this.operation,
+ algebra.number(this.value.value),
+ )
+ else -> algebra.unaryOperationFunction(this.operation)(this.value.interpret(algebra, arguments))
+ }
+
+ is MST.Binary -> when {
+ algebra is NumericAlgebra && this.left is MST.Numeric && this.right is MST.Numeric -> algebra.binaryOperation(
+ this.operation,
+ algebra.number(this.left.value),
+ algebra.number(this.right.value),
+ )
+
+ algebra is NumericAlgebra && this.left is MST.Numeric -> algebra.leftSideNumberOperation(
+ this.operation,
+ this.left.value,
+ this.right.interpret(algebra, arguments),
+ )
+
+ algebra is NumericAlgebra && this.right is MST.Numeric -> algebra.rightSideNumberOperation(
+ this.operation,
+ left.interpret(algebra, arguments),
+ right.value,
+ )
+
+ else -> algebra.binaryOperation(
+ this.operation,
+ this.left.interpret(algebra, arguments),
+ this.right.interpret(algebra, arguments),
+ )
+ }
+}
/**
* Interprets the [MST] node with this [Algebra] and optional [arguments]
@@ -111,12 +95,17 @@ public fun MST.interpret(algebra: Algebra, arguments: Map): T
* @param algebra the algebra that provides operations.
* @return the value of expression.
*/
-public fun MST.interpret(algebra: Algebra, vararg arguments: Pair): T =
- interpret(algebra, mapOf(*arguments))
+public fun MST.interpret(algebra: Algebra, vararg arguments: Pair): T = interpret(
+ algebra,
+ when (arguments.size) {
+ 0 -> emptyMap()
+ 1 -> mapOf(arguments[0])
+ else -> hashMapOf(*arguments)
+ },
+)
/**
* Interpret this [MST] as expression.
*/
-public fun MST.toExpression(algebra: Algebra): Expression = Expression { arguments ->
- interpret(algebra, arguments)
-}
+public fun MST.toExpression(algebra: Algebra): Expression =
+ Expression { arguments -> interpret(algebra, arguments) }
diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt
index 96fc73249..ac8c44446 100644
--- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt
+++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt
@@ -272,7 +272,7 @@ public fun > SimpleAutoDiffField.sqrt(x: Aut
public fun > SimpleAutoDiffField.pow(
x: AutoDiffValue,
y: Double,
-): AutoDiffValue = derive(const { x.value.pow(y)}) { z ->
+): AutoDiffValue = derive(const { x.value.pow(y) }) { z ->
x.d += z.d * y * x.value.pow(y - 1)
}
@@ -343,10 +343,7 @@ public fun > SimpleAutoDiffField.atanh(x: Au
public class SimpleAutoDiffExtendedField>(
context: F,
bindings: Map,
-) : ExtendedField>, ScaleOperations>,
- SimpleAutoDiffField(context, bindings) {
-
- override fun bindSymbol(value: String): AutoDiffValue = super.bindSymbol(value)
+) : ExtendedField>, ScaleOperations>, SimpleAutoDiffField(context, bindings) {
override fun number(value: Number): AutoDiffValue = const { number(value) }
diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt
similarity index 100%
rename from kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealVector.kt
rename to kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt
diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt
index e2541a73e..b40739ee0 100644
--- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt
+++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt
@@ -199,8 +199,9 @@ public abstract class TensorFlowAlgebra> internal c
override fun StructureND.dot(other: StructureND): TensorFlowOutput = operate(other) { l, r ->
ops.linalg.matMul(
- if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
- if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r)
+ if (l.shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
+ if (r.shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r
+ )
}
override fun diagonalEmbedding(
@@ -241,6 +242,16 @@ public abstract class TensorFlowAlgebra> internal c
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
).actualTensor
+// private val symbolCache = HashMap>()
+//
+// override fun bindSymbolOrNull(value: String): TensorFlowOutput? {
+// return symbolCache.getOrPut(value){ops.var}
+// }
+//
+// public fun StructureND.grad(
+//
+// )= operate { ops.gradients() }
+
@OptIn(UnstableKMathAPI::class)
override fun export(arg: StructureND): StructureND =
if (arg is TensorFlowOutput) arg.actualTensor else arg
diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/tfOperations.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/tfOperations.kt
index 257d4d6ea..f67c333ce 100644
--- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/tfOperations.kt
+++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/tfOperations.kt
@@ -20,4 +20,4 @@ public fun TensorFlowAlgebra.sin(
public fun TensorFlowAlgebra.cos(
arg: StructureND,
-): TensorFlowOutput where A : TrigonometricOperations, A : Ring