forked from kscience/kmath
Merge branch 'dev' into feature/noa
This commit is contained in:
commit
f3a411f0e2
@ -247,6 +247,12 @@ One can still use generic algebras though.
|
|||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
|
* ### [kmath-tensorflow](kmath-tensorflow)
|
||||||
|
>
|
||||||
|
>
|
||||||
|
> **Maturity**: PROTOTYPE
|
||||||
|
<hr/>
|
||||||
|
|
||||||
* ### [kmath-tensors](kmath-tensors)
|
* ### [kmath-tensors](kmath-tensors)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
|
@ -52,6 +52,8 @@ kotlin {
|
|||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-jafama"))
|
implementation(project(":kmath-jafama"))
|
||||||
implementation(project(":kmath-multik"))
|
implementation(project(":kmath-multik"))
|
||||||
|
implementation(projects.kmath.kmathTensorflow)
|
||||||
|
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
||||||
// uncomment if your system supports AVX2
|
// uncomment if your system supports AVX2
|
||||||
// val os = System.getProperty("os.name")
|
// val os = System.getProperty("os.name")
|
||||||
@ -122,6 +124,11 @@ benchmark {
|
|||||||
include("JafamaBenchmark")
|
include("JafamaBenchmark")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
configurations.register("tensorAlgebra") {
|
||||||
|
commonConfiguration()
|
||||||
|
include("TensorAlgebraBenchmark")
|
||||||
|
}
|
||||||
|
|
||||||
configurations.register("viktor") {
|
configurations.register("viktor") {
|
||||||
commonConfiguration()
|
commonConfiguration()
|
||||||
include("ViktorBenchmark")
|
include("ViktorBenchmark")
|
||||||
|
@ -16,6 +16,8 @@ import space.kscience.kmath.linear.linearSpace
|
|||||||
import space.kscience.kmath.multik.multikAlgebra
|
import space.kscience.kmath.multik.multikAlgebra
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.tensorflow.produceWithTF
|
||||||
|
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
@ -39,6 +41,16 @@ internal class DotBenchmark {
|
|||||||
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun tfDot(blackhole: Blackhole) {
|
||||||
|
blackhole.consume(
|
||||||
|
DoubleField.produceWithTF {
|
||||||
|
matrix1 dot matrix1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
|
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
@ -59,13 +71,13 @@ internal class DotBenchmark {
|
|||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Benchmark
|
@Benchmark
|
||||||
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
fun tensorDot(blackhole: Blackhole) = with(DoubleField.tensorAlgebra) {
|
||||||
// blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
// }
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
|
fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.benchmarks
|
||||||
|
|
||||||
|
import kotlinx.benchmark.Benchmark
|
||||||
|
import kotlinx.benchmark.Blackhole
|
||||||
|
import kotlinx.benchmark.Scope
|
||||||
|
import kotlinx.benchmark.State
|
||||||
|
import space.kscience.kmath.linear.linearSpace
|
||||||
|
import space.kscience.kmath.linear.matrix
|
||||||
|
import space.kscience.kmath.linear.symmetric
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||||
|
import kotlin.random.Random
|
||||||
|
|
||||||
|
@State(Scope.Benchmark)
|
||||||
|
internal class TensorAlgebraBenchmark {
|
||||||
|
companion object {
|
||||||
|
private val random = Random(12224)
|
||||||
|
private const val dim = 30
|
||||||
|
|
||||||
|
private val matrix = DoubleField.linearSpace.matrix(dim, dim).symmetric { _, _ -> random.nextDouble() }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun tensorSymEigSvd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
|
blackhole.consume(matrix.symEigSvd(1e-10))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun tensorSymEigJacobi(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
|
blackhole.consume(matrix.symEigJacobi(50, 1e-10))
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ allprojects {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group = "space.kscience"
|
group = "space.kscience"
|
||||||
version = "0.3.0-dev-17"
|
version = "0.3.0-dev-19"
|
||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Module kmath-ast
|
# Module kmath-ast
|
||||||
|
|
||||||
Performance and visualization extensions to MST API.
|
Extensions to MST API: transformations, dynamic compilation and visualization.
|
||||||
|
|
||||||
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
||||||
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
|
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
|
||||||
@ -35,6 +35,26 @@ dependencies {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Parsing expressions
|
||||||
|
|
||||||
|
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
|
||||||
|
|
||||||
|
Supported literals:
|
||||||
|
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
|
||||||
|
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
|
||||||
|
|
||||||
|
Supported binary operators (from the highest precedence to the lowest one):
|
||||||
|
1. `^`
|
||||||
|
2. `*`, `/`
|
||||||
|
3. `+`, `-`
|
||||||
|
|
||||||
|
Supported unary operator:
|
||||||
|
1. `-`, e. g. `-x`
|
||||||
|
|
||||||
|
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
|
||||||
|
1. `sin(x)`
|
||||||
|
2. `add(x, y)`
|
||||||
|
|
||||||
## Dynamic expression code generation
|
## Dynamic expression code generation
|
||||||
|
|
||||||
### On JVM
|
### On JVM
|
||||||
@ -42,48 +62,41 @@ dependencies {
|
|||||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
||||||
special implementation of `Expression<T>` with implemented `invoke` function.
|
special implementation of `Expression<T>` with implemented `invoke` function.
|
||||||
|
|
||||||
For example, the following builder:
|
For example, the following code:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.complex.ComplexField
|
||||||
import space.kscience.kmath.operations.*
|
|
||||||
import space.kscience.kmath.asm.*
|
|
||||||
|
|
||||||
MstField { x + 2 }.compileToExpression(DoubleField)
|
"x+2".parseMath().compileToExpression(ComplexField)
|
||||||
```
|
```
|
||||||
|
|
||||||
... leads to generation of bytecode, which can be decompiled to the following Java class:
|
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||||
|
|
||||||
```java
|
```java
|
||||||
package space.kscience.kmath.asm.generated;
|
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import kotlin.jvm.functions.Function2;
|
import kotlin.jvm.functions.Function2;
|
||||||
import space.kscience.kmath.asm.internal.MapIntrinsics;
|
import space.kscience.kmath.asm.internal.MapIntrinsics;
|
||||||
|
import space.kscience.kmath.complex.Complex;
|
||||||
import space.kscience.kmath.expressions.Expression;
|
import space.kscience.kmath.expressions.Expression;
|
||||||
import space.kscience.kmath.expressions.Symbol;
|
import space.kscience.kmath.expressions.Symbol;
|
||||||
|
|
||||||
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
public final class CompiledExpression_45045_0 implements Expression<Complex> {
|
||||||
private final Object[] constants;
|
private final Object[] constants;
|
||||||
|
|
||||||
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
|
||||||
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
|
Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
|
||||||
}
|
return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
|
||||||
|
|
||||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
|
||||||
this.constants = constants;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
|
||||||
|
|
||||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
|
#### Limitations
|
||||||
loading overhead.
|
|
||||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
|
||||||
|
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
|
||||||
|
|
||||||
### On JS
|
### On JS
|
||||||
|
|
||||||
@ -129,7 +142,7 @@ An example of emitted Wasm IR in the form of WAT:
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
#### Limitations
|
||||||
|
|
||||||
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
||||||
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
||||||
|
@ -1,11 +1,31 @@
|
|||||||
# Module kmath-ast
|
# Module kmath-ast
|
||||||
|
|
||||||
Performance and visualization extensions to MST API.
|
Extensions to MST API: transformations, dynamic compilation and visualization.
|
||||||
|
|
||||||
${features}
|
${features}
|
||||||
|
|
||||||
${artifact}
|
${artifact}
|
||||||
|
|
||||||
|
## Parsing expressions
|
||||||
|
|
||||||
|
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
|
||||||
|
|
||||||
|
Supported literals:
|
||||||
|
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
|
||||||
|
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
|
||||||
|
|
||||||
|
Supported binary operators (from the highest precedence to the lowest one):
|
||||||
|
1. `^`
|
||||||
|
2. `*`, `/`
|
||||||
|
3. `+`, `-`
|
||||||
|
|
||||||
|
Supported unary operator:
|
||||||
|
1. `-`, e. g. `-x`
|
||||||
|
|
||||||
|
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
|
||||||
|
1. `sin(x)`
|
||||||
|
2. `add(x, y)`
|
||||||
|
|
||||||
## Dynamic expression code generation
|
## Dynamic expression code generation
|
||||||
|
|
||||||
### On JVM
|
### On JVM
|
||||||
@ -13,48 +33,66 @@ ${artifact}
|
|||||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
||||||
special implementation of `Expression<T>` with implemented `invoke` function.
|
special implementation of `Expression<T>` with implemented `invoke` function.
|
||||||
|
|
||||||
For example, the following builder:
|
For example, the following code:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.*
|
|
||||||
import space.kscience.kmath.asm.*
|
|
||||||
|
|
||||||
MstField { x + 2 }.compileToExpression(DoubleField)
|
"x^3-x+3".parseMath().compileToExpression(DoubleField)
|
||||||
```
|
```
|
||||||
|
|
||||||
... leads to generation of bytecode, which can be decompiled to the following Java class:
|
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||||
|
|
||||||
```java
|
```java
|
||||||
package space.kscience.kmath.asm.generated;
|
import java.util.*;
|
||||||
|
import kotlin.jvm.functions.*;
|
||||||
|
import space.kscience.kmath.asm.internal.*;
|
||||||
|
import space.kscience.kmath.complex.*;
|
||||||
|
import space.kscience.kmath.expressions.*;
|
||||||
|
|
||||||
import java.util.Map;
|
public final class CompiledExpression_45045_0 implements Expression<Complex> {
|
||||||
|
|
||||||
import kotlin.jvm.functions.Function2;
|
|
||||||
import space.kscience.kmath.asm.internal.MapIntrinsics;
|
|
||||||
import space.kscience.kmath.expressions.Expression;
|
|
||||||
import space.kscience.kmath.expressions.Symbol;
|
|
||||||
|
|
||||||
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
|
||||||
private final Object[] constants;
|
private final Object[] constants;
|
||||||
|
|
||||||
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
|
||||||
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
|
Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
|
||||||
}
|
return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
|
||||||
|
|
||||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
|
||||||
this.constants = constants;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
|
||||||
|
|
||||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
|
```java
|
||||||
loading overhead.
|
import java.util.*;
|
||||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
import space.kscience.kmath.asm.internal.*;
|
||||||
|
import space.kscience.kmath.expressions.*;
|
||||||
|
|
||||||
|
public final class CompiledExpression_-386104628_0 implements DoubleExpression {
|
||||||
|
private final SymbolIndexer indexer;
|
||||||
|
|
||||||
|
public SymbolIndexer getIndexer() {
|
||||||
|
return this.indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double invoke(double[] arguments) {
|
||||||
|
double var2 = arguments[0];
|
||||||
|
return Math.pow(var2, 3.0D) - var2 + 3.0D;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
||||||
|
double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
|
||||||
|
return Math.pow(var2, 3.0D) - var2 + 3.0D;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
|
||||||
|
|
||||||
|
#### Limitations
|
||||||
|
|
||||||
|
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
|
||||||
|
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
|
||||||
|
|
||||||
### On JS
|
### On JS
|
||||||
|
|
||||||
@ -100,7 +138,7 @@ An example of emitted Wasm IR in the form of WAT:
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
#### Limitations
|
||||||
|
|
||||||
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
||||||
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
||||||
|
@ -0,0 +1,177 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MST form where all values belong to the type [T]. It is optimal for constant folding, dynamic compilation, etc.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public sealed interface TypedMst<T> {
|
||||||
|
/**
|
||||||
|
* A node containing a unary operation.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property operation The identifier of operation.
|
||||||
|
* @property function The function implementing this operation.
|
||||||
|
* @property value The argument of this operation.
|
||||||
|
*/
|
||||||
|
public class Unary<T>(public val operation: String, public val function: (T) -> T, public val value: TypedMst<T>) :
|
||||||
|
TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
other as Unary<*>
|
||||||
|
if (operation != other.operation) return false
|
||||||
|
if (value != other.value) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = operation.hashCode()
|
||||||
|
result = 31 * result + value.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Unary(operation=$operation, value=$value)"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A node containing binary operation.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property operation The identifier of operation.
|
||||||
|
* @property function The binary function implementing this operation.
|
||||||
|
* @property left The left operand.
|
||||||
|
* @property right The right operand.
|
||||||
|
*/
|
||||||
|
public class Binary<T>(
|
||||||
|
public val operation: String,
|
||||||
|
public val function: Function<T>,
|
||||||
|
public val left: TypedMst<T>,
|
||||||
|
public val right: TypedMst<T>,
|
||||||
|
) : TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
|
other as Binary<*>
|
||||||
|
|
||||||
|
if (operation != other.operation) return false
|
||||||
|
if (left != other.left) return false
|
||||||
|
if (right != other.right) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = operation.hashCode()
|
||||||
|
result = 31 * result + left.hashCode()
|
||||||
|
result = 31 * result + right.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Binary(operation=$operation, left=$left, right=$right)"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The non-numeric constant value.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property value The held value.
|
||||||
|
* @property number The number this value corresponds.
|
||||||
|
*/
|
||||||
|
public class Constant<T>(public val value: T, public val number: Number?) : TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
other as Constant<*>
|
||||||
|
if (value != other.value) return false
|
||||||
|
if (number != other.number) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = value?.hashCode() ?: 0
|
||||||
|
result = 31 * result + (number?.hashCode() ?: 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Constant(value=$value, number=$number)"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The node containing a variable
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property symbol The symbol of the variable.
|
||||||
|
*/
|
||||||
|
public class Variable<T>(public val symbol: Symbol) : TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
other as Variable<*>
|
||||||
|
if (symbol != other.symbol) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int = symbol.hashCode()
|
||||||
|
override fun toString(): String = "Variable(symbol=$symbol)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interprets the [TypedMst] node with this [Algebra] and [arguments].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
|
||||||
|
is TypedMst.Unary -> algebra.unaryOperation(operation, interpret(algebra, arguments))
|
||||||
|
|
||||||
|
is TypedMst.Binary -> when {
|
||||||
|
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null ->
|
||||||
|
algebra.leftSideNumberOperation(operation, left.number, right.interpret(algebra, arguments))
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null ->
|
||||||
|
algebra.rightSideNumberOperation(operation, left.interpret(algebra, arguments), right.number)
|
||||||
|
|
||||||
|
else -> algebra.binaryOperation(
|
||||||
|
operation,
|
||||||
|
left.interpret(algebra, arguments),
|
||||||
|
right.interpret(algebra, arguments),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
is TypedMst.Constant -> value
|
||||||
|
is TypedMst.Variable -> arguments.getValue(symbol)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interprets the [TypedMst] node with this [Algebra] and optional [arguments].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
|
||||||
|
algebra,
|
||||||
|
when (arguments.size) {
|
||||||
|
0 -> emptyMap()
|
||||||
|
1 -> mapOf(arguments[0])
|
||||||
|
else -> hashMapOf(*arguments)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interpret this [TypedMst] node as expression.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
|
||||||
|
interpret(algebra, arguments)
|
||||||
|
}
|
@ -0,0 +1,93 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.expressions.MST
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
import space.kscience.kmath.operations.bindSymbolOrNull
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluates constants in given [MST] for given [algebra] at the same time with converting to [TypedMst].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) {
|
||||||
|
is MST.Numeric -> TypedMst.Constant(
|
||||||
|
(algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
is MST.Unary -> when (val arg = value.evaluateConstants(algebra)) {
|
||||||
|
is TypedMst.Constant<T> -> {
|
||||||
|
val value = algebra.unaryOperation(
|
||||||
|
operation,
|
||||||
|
arg.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
TypedMst.Constant(value, if (value is Number) value else null)
|
||||||
|
}
|
||||||
|
|
||||||
|
else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Binary -> {
|
||||||
|
val left = left.evaluateConstants(algebra)
|
||||||
|
val right = right.evaluateConstants(algebra)
|
||||||
|
|
||||||
|
when {
|
||||||
|
left is TypedMst.Constant<T> && right is TypedMst.Constant<T> -> {
|
||||||
|
val value = when {
|
||||||
|
algebra is NumericAlgebra && left.number != null -> algebra.leftSideNumberOperation(
|
||||||
|
operation,
|
||||||
|
left.number,
|
||||||
|
right.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && right.number != null -> algebra.rightSideNumberOperation(
|
||||||
|
operation,
|
||||||
|
left.value,
|
||||||
|
right.number,
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> algebra.binaryOperation(
|
||||||
|
operation,
|
||||||
|
left.value,
|
||||||
|
right.value,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
TypedMst.Constant(value, if (value is Number) value else null)
|
||||||
|
}
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
|
||||||
|
operation,
|
||||||
|
algebra.leftSideNumberOperationFunction(operation),
|
||||||
|
left,
|
||||||
|
right,
|
||||||
|
)
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> TypedMst.Binary(
|
||||||
|
operation,
|
||||||
|
algebra.rightSideNumberOperationFunction(operation),
|
||||||
|
left,
|
||||||
|
right,
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> TypedMst.Binary(operation, algebra.binaryOperationFunction(operation), left, right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
is Symbol -> {
|
||||||
|
val boundSymbol = algebra.bindSymbolOrNull(this)
|
||||||
|
|
||||||
|
if (boundSymbol != null)
|
||||||
|
TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null)
|
||||||
|
else
|
||||||
|
TypedMst.Variable(this)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.operations.ByteRing
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
|
import space.kscience.kmath.operations.pi
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
internal class TestFolding {
|
||||||
|
@Test
|
||||||
|
fun foldUnary() = assertEquals(
|
||||||
|
-1,
|
||||||
|
("-(1)".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldDeepUnary() = assertEquals(
|
||||||
|
1,
|
||||||
|
("-(-(1))".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldBinary() = assertEquals(
|
||||||
|
2,
|
||||||
|
("1*2".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldDeepBinary() = assertEquals(
|
||||||
|
10,
|
||||||
|
("1*2*5".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldSymbol() = assertEquals(
|
||||||
|
DoubleField.pi,
|
||||||
|
("pi".parseMath().evaluateConstants(DoubleField) as? TypedMst.Constant<Double> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldNumeric() = assertEquals(
|
||||||
|
42.toByte(),
|
||||||
|
("42".parseMath().evaluateConstants(ByteRing) as? TypedMst.Constant<Byte> ?: fail()).value,
|
||||||
|
)
|
||||||
|
}
|
@ -7,7 +7,7 @@ package space.kscience.kmath.ast
|
|||||||
|
|
||||||
import space.kscience.kmath.complex.Complex
|
import space.kscience.kmath.complex.Complex
|
||||||
import space.kscience.kmath.complex.ComplexField
|
import space.kscience.kmath.complex.ComplexField
|
||||||
import space.kscience.kmath.expressions.evaluate
|
import space.kscience.kmath.expressions.interpret
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -17,14 +17,14 @@ internal class TestParser {
|
|||||||
@Test
|
@Test
|
||||||
fun evaluateParsedMst() {
|
fun evaluateParsedMst() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
val res = ComplexField.evaluate(mst)
|
val res = mst.interpret(ComplexField)
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun evaluateMstSymbol() {
|
fun evaluateMstSymbol() {
|
||||||
val mst = "i".parseMath()
|
val mst = "i".parseMath()
|
||||||
val res = ComplexField.evaluate(mst)
|
val res = mst.interpret(ComplexField)
|
||||||
assertEquals(ComplexField.i, res)
|
assertEquals(ComplexField.i, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ internal class TestParser {
|
|||||||
@Test
|
@Test
|
||||||
fun evaluateMstUnary() {
|
fun evaluateMstUnary() {
|
||||||
val mst = "sin(0)".parseMath()
|
val mst = "sin(0)".parseMath()
|
||||||
val res = DoubleField.evaluate(mst)
|
val res = mst.interpret(DoubleField)
|
||||||
assertEquals(0.0, res)
|
assertEquals(0.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ internal class TestParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
val mst = "magic(a, b)".parseMath()
|
val mst = "magic(a, b)".parseMath()
|
||||||
val res = magicalAlgebra.evaluate(mst)
|
val res = mst.interpret(magicalAlgebra)
|
||||||
assertEquals("a ★ b", res)
|
assertEquals("a ★ b", res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,35 +5,35 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.ast
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.evaluate
|
import space.kscience.kmath.expressions.interpret
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestParserPrecedence {
|
internal class TestParserPrecedence {
|
||||||
@Test
|
@Test
|
||||||
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
|
fun test1(): Unit = assertEquals(6.0, "2*2+2".parseMath().interpret(f))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
|
fun test2(): Unit = assertEquals(6.0, "2+2*2".parseMath().interpret(f))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
|
fun test3(): Unit = assertEquals(10.0, "2^3+2".parseMath().interpret(f))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
|
fun test4(): Unit = assertEquals(10.0, "2+2^3".parseMath().interpret(f))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
|
fun test5(): Unit = assertEquals(16.0, "2^3*2".parseMath().interpret(f))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
|
fun test6(): Unit = assertEquals(16.0, "2*2^3".parseMath().interpret(f))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
|
fun test7(): Unit = assertEquals(18.0, "2+2^3*2".parseMath().interpret(f))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
|
fun test8(): Unit = assertEquals(18.0, "2*2^3+2".parseMath().interpret(f))
|
||||||
|
|
||||||
private companion object {
|
private companion object {
|
||||||
private val f = DoubleField
|
private val f = DoubleField
|
||||||
|
@ -5,87 +5,48 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.estree
|
package space.kscience.kmath.estree
|
||||||
|
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
|
import space.kscience.kmath.ast.evaluateConstants
|
||||||
import space.kscience.kmath.estree.internal.ESTreeBuilder
|
import space.kscience.kmath.estree.internal.ESTreeBuilder
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.expressions.MST
|
import space.kscience.kmath.expressions.MST
|
||||||
import space.kscience.kmath.expressions.MST.*
|
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.internal.estree.BaseExpression
|
import space.kscience.kmath.internal.estree.BaseExpression
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
|
||||||
import space.kscience.kmath.operations.bindSymbolOrNull
|
|
||||||
|
|
||||||
@PublishedApi
|
|
||||||
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
|
||||||
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
|
||||||
is Symbol -> {
|
|
||||||
val symbol = algebra.bindSymbolOrNull(node)
|
|
||||||
|
|
||||||
if (symbol != null)
|
|
||||||
constant(symbol)
|
|
||||||
else
|
|
||||||
variable(node.identity)
|
|
||||||
}
|
|
||||||
|
|
||||||
is Numeric -> constant(
|
|
||||||
(algebra as? NumericAlgebra<T>)?.number(node.value) ?: error("Numeric nodes are not supported by $this")
|
|
||||||
)
|
|
||||||
|
|
||||||
is Unary -> when {
|
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> constant(
|
|
||||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
|
||||||
}
|
|
||||||
|
|
||||||
is Binary -> when {
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
|
|
||||||
algebra.binaryOperationFunction(node.operation).invoke(
|
|
||||||
algebra.number((node.left as Numeric).value),
|
|
||||||
algebra.number((node.right as Numeric).value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric -> call(
|
|
||||||
algebra.leftSideNumberOperationFunction(node.operation),
|
|
||||||
visit(node.left),
|
|
||||||
visit(node.right),
|
|
||||||
)
|
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.right is Numeric -> call(
|
|
||||||
algebra.rightSideNumberOperationFunction(node.operation),
|
|
||||||
visit(node.left),
|
|
||||||
visit(node.right),
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> call(
|
|
||||||
algebra.binaryOperationFunction(node.operation),
|
|
||||||
visit(node.left),
|
|
||||||
visit(node.right),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a compiled expression with given [MST] and given [algebra].
|
* Create a compiled expression with given [MST] and given [algebra].
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
|
||||||
|
|
||||||
|
fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) {
|
||||||
|
is TypedMst.Constant -> constant(node.value)
|
||||||
|
is TypedMst.Variable -> variable(node.symbol)
|
||||||
|
is TypedMst.Unary -> call(node.function, visit(node.value))
|
||||||
|
|
||||||
|
is TypedMst.Binary -> call(
|
||||||
|
node.function,
|
||||||
|
visit(node.left),
|
||||||
|
visit(node.right),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ESTreeBuilder<T> { visit(typed) }.instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
public fun <T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
public fun <T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
@ -61,7 +61,7 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name))
|
fun variable(name: Symbol): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name.identity))
|
||||||
|
|
||||||
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
|
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
|
||||||
optional = false,
|
optional = false,
|
||||||
|
@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.wasm.internal
|
package space.kscience.kmath.wasm.internal
|
||||||
|
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST.*
|
|
||||||
import space.kscience.kmath.internal.binaryen.*
|
import space.kscience.kmath.internal.binaryen.*
|
||||||
import space.kscience.kmath.internal.webassembly.Instance
|
import space.kscience.kmath.internal.webassembly.Instance
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
@ -16,11 +16,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
|
|||||||
|
|
||||||
private val spreader = eval("(obj, args) => obj(...args)")
|
private val spreader = eval("(obj, args) => obj(...args)")
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
@Suppress("UnsafeCastFromDynamic")
|
@Suppress("UnsafeCastFromDynamic")
|
||||||
internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
||||||
protected val binaryenType: Type,
|
protected val binaryenType: Type,
|
||||||
protected val algebra: Algebra<T>,
|
protected val algebra: Algebra<T>,
|
||||||
protected val target: MST,
|
protected val target: TypedMst<T>,
|
||||||
) {
|
) {
|
||||||
protected val keys: MutableList<Symbol> = mutableListOf()
|
protected val keys: MutableList<Symbol> = mutableListOf()
|
||||||
protected lateinit var ctx: BinaryenModule
|
protected lateinit var ctx: BinaryenModule
|
||||||
@ -51,59 +52,41 @@ internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
Instance(c, js("{}")).exports.executable
|
Instance(c, js("{}")).exports.executable
|
||||||
}
|
}
|
||||||
|
|
||||||
protected open fun visitSymbol(node: Symbol): ExpressionRef {
|
protected abstract fun visitNumber(number: Number): ExpressionRef
|
||||||
algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
|
|
||||||
|
|
||||||
var idx = keys.indexOf(node)
|
protected open fun visitVariable(node: TypedMst.Variable<T>): ExpressionRef {
|
||||||
|
var idx = keys.indexOf(node.symbol)
|
||||||
|
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
keys += node
|
keys += node.symbol
|
||||||
idx = keys.lastIndex
|
idx = keys.lastIndex
|
||||||
}
|
}
|
||||||
|
|
||||||
return ctx.local.get(idx, binaryenType)
|
return ctx.local.get(idx, binaryenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract fun visitNumeric(node: Numeric): ExpressionRef
|
protected open fun visitUnary(node: TypedMst.Unary<T>): ExpressionRef =
|
||||||
|
|
||||||
protected open fun visitUnary(node: Unary): ExpressionRef =
|
|
||||||
error("Unary operation ${node.operation} not defined in $this")
|
error("Unary operation ${node.operation} not defined in $this")
|
||||||
|
|
||||||
protected open fun visitBinary(mst: Binary): ExpressionRef =
|
protected open fun visitBinary(mst: TypedMst.Binary<T>): ExpressionRef =
|
||||||
error("Binary operation ${mst.operation} not defined in $this")
|
error("Binary operation ${mst.operation} not defined in $this")
|
||||||
|
|
||||||
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
|
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
|
||||||
|
|
||||||
protected fun visit(node: MST): ExpressionRef = when (node) {
|
protected fun visit(node: TypedMst<T>): ExpressionRef = when (node) {
|
||||||
is Symbol -> visitSymbol(node)
|
is TypedMst.Constant -> visitNumber(
|
||||||
is Numeric -> visitNumeric(node)
|
node.number ?: error("Object constants are not supported by pritimive ASM builder"),
|
||||||
|
)
|
||||||
|
|
||||||
is Unary -> when {
|
is TypedMst.Variable -> visitVariable(node)
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> visitNumeric(
|
is TypedMst.Unary -> visitUnary(node)
|
||||||
Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
|
is TypedMst.Binary -> visitBinary(node)
|
||||||
)
|
|
||||||
|
|
||||||
else -> visitUnary(node)
|
|
||||||
}
|
|
||||||
|
|
||||||
is Binary -> when {
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
|
|
||||||
Numeric(
|
|
||||||
algebra.binaryOperationFunction(node.operation)
|
|
||||||
.invoke(
|
|
||||||
algebra.number((node.left as Numeric).value),
|
|
||||||
algebra.number((node.right as Numeric).value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> visitBinary(node)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
|
internal class DoubleWasmBuilder(target: TypedMst<Double>) :
|
||||||
|
WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
|
||||||
override val instance by lazy {
|
override val instance by lazy {
|
||||||
object : DoubleExpression {
|
object : DoubleExpression {
|
||||||
override val indexer = SimpleSymbolIndexer(keys)
|
override val indexer = SimpleSymbolIndexer(keys)
|
||||||
@ -114,9 +97,9 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
|
|||||||
|
|
||||||
override fun createModule() = readBinary(f64StandardFunctions)
|
override fun createModule() = readBinary(f64StandardFunctions)
|
||||||
|
|
||||||
override fun visitNumeric(node: Numeric) = ctx.f64.const(node.value.toDouble())
|
override fun visitNumber(number: Number) = ctx.f64.const(number.toDouble())
|
||||||
|
|
||||||
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
|
override fun visitUnary(node: TypedMst.Unary<Double>): ExpressionRef = when (node.operation) {
|
||||||
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
|
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
|
||||||
GroupOps.PLUS_OPERATION -> visit(node.value)
|
GroupOps.PLUS_OPERATION -> visit(node.value)
|
||||||
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
|
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
|
||||||
@ -137,7 +120,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
|
|||||||
else -> super.visitUnary(node)
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: TypedMst.Binary<Double>): ExpressionRef = when (mst.operation) {
|
||||||
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
||||||
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
||||||
@ -148,7 +131,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
|
internal class IntWasmBuilder(target: TypedMst<Int>) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
|
||||||
override val instance by lazy {
|
override val instance by lazy {
|
||||||
object : IntExpression {
|
object : IntExpression {
|
||||||
override val indexer = SimpleSymbolIndexer(keys)
|
override val indexer = SimpleSymbolIndexer(keys)
|
||||||
@ -157,15 +140,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value.toInt())
|
override fun visitNumber(number: Number) = ctx.i32.const(number.toInt())
|
||||||
|
|
||||||
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
|
override fun visitUnary(node: TypedMst.Unary<Int>): ExpressionRef = when (node.operation) {
|
||||||
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
|
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
|
||||||
GroupOps.PLUS_OPERATION -> visit(node.value)
|
GroupOps.PLUS_OPERATION -> visit(node.value)
|
||||||
else -> super.visitUnary(node)
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: TypedMst.Binary<Int>): ExpressionRef = when (mst.operation) {
|
||||||
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
||||||
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
||||||
|
@ -7,7 +7,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.wasm
|
package space.kscience.kmath.wasm
|
||||||
|
|
||||||
import space.kscience.kmath.estree.compileWith
|
import space.kscience.kmath.ast.TypedMst
|
||||||
|
import space.kscience.kmath.ast.evaluateConstants
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
@ -21,8 +22,16 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: IntRing): IntExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : IntExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: IntArray): Int = typed.value
|
||||||
|
} else
|
||||||
|
IntWasmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -31,7 +40,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBui
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -49,7 +58,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleWasmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : DoubleExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: DoubleArray): Double = typed.value
|
||||||
|
} else
|
||||||
|
DoubleWasmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -59,7 +77,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = D
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -69,4 +87,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
@ -8,10 +8,14 @@
|
|||||||
package space.kscience.kmath.asm
|
package space.kscience.kmath.asm
|
||||||
|
|
||||||
import space.kscience.kmath.asm.internal.*
|
import space.kscience.kmath.asm.internal.*
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
|
import space.kscience.kmath.ast.evaluateConstants
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST.*
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
|
import space.kscience.kmath.operations.LongRing
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles given MST to an Expression using AST compiler.
|
* Compiles given MST to an Expression using AST compiler.
|
||||||
@ -21,102 +25,64 @@ import space.kscience.kmath.operations.*
|
|||||||
* @return the compiled expression.
|
* @return the compiled expression.
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun GenericAsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
|
val typed = evaluateConstants(algebra)
|
||||||
is Symbol -> prepareVariable(node.identity)
|
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
|
||||||
is Unary -> variablesVisitor(node.value)
|
|
||||||
|
|
||||||
is Binary -> {
|
fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) {
|
||||||
|
is TypedMst.Unary -> variablesVisitor(node.value)
|
||||||
|
|
||||||
|
is TypedMst.Binary -> {
|
||||||
variablesVisitor(node.left)
|
variablesVisitor(node.left)
|
||||||
variablesVisitor(node.right)
|
variablesVisitor(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> Unit
|
is TypedMst.Variable -> prepareVariable(node.symbol)
|
||||||
|
is TypedMst.Constant -> Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
fun GenericAsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
|
fun GenericAsmBuilder<T>.expressionVisitor(node: TypedMst<T>): Unit = when (node) {
|
||||||
is Symbol -> {
|
is TypedMst.Constant -> if (node.number != null)
|
||||||
val symbol = algebra.bindSymbolOrNull(node)
|
loadNumberConstant(node.number)
|
||||||
|
else
|
||||||
|
loadObjectConstant(node.value)
|
||||||
|
|
||||||
if (symbol != null)
|
is TypedMst.Variable -> loadVariable(node.symbol)
|
||||||
loadObjectConstant(symbol as Any)
|
is TypedMst.Unary -> buildCall(node.function) { expressionVisitor(node.value) }
|
||||||
else
|
|
||||||
loadVariable(node.identity)
|
|
||||||
}
|
|
||||||
|
|
||||||
is Numeric -> if (algebra is NumericAlgebra) {
|
is TypedMst.Binary -> buildCall(node.function) {
|
||||||
if (Number::class.java.isAssignableFrom(type))
|
expressionVisitor(node.left)
|
||||||
loadNumberConstant(algebra.number(node.value) as Number)
|
expressionVisitor(node.right)
|
||||||
else
|
|
||||||
loadObjectConstant(algebra.number(node.value))
|
|
||||||
} else
|
|
||||||
error("Numeric nodes are not supported by $this")
|
|
||||||
|
|
||||||
is Unary -> when {
|
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
|
||||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
is Binary -> when {
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
|
||||||
algebra.binaryOperationFunction(node.operation).invoke(
|
|
||||||
algebra.number((node.left as Numeric).value),
|
|
||||||
algebra.number((node.right as Numeric).value),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
|
||||||
algebra.leftSideNumberOperationFunction(node.operation),
|
|
||||||
) {
|
|
||||||
expressionVisitor(node.left)
|
|
||||||
expressionVisitor(node.right)
|
|
||||||
}
|
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
|
||||||
algebra.rightSideNumberOperationFunction(node.operation),
|
|
||||||
) {
|
|
||||||
expressionVisitor(node.left)
|
|
||||||
expressionVisitor(node.right)
|
|
||||||
}
|
|
||||||
|
|
||||||
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
|
|
||||||
expressionVisitor(node.left)
|
|
||||||
expressionVisitor(node.right)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return GenericAsmBuilder<T>(
|
return GenericAsmBuilder<T>(
|
||||||
type,
|
type,
|
||||||
buildName(this),
|
buildName("${typed.hashCode()}_${type.simpleName}"),
|
||||||
{ variablesVisitor(this@compileWith) },
|
{ variablesVisitor(typed) },
|
||||||
{ expressionVisitor(this@compileWith) },
|
{ expressionVisitor(typed) },
|
||||||
).instance
|
).instance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a compiled expression with given [MST] and given [algebra].
|
* Create a compiled expression with given [MST] and given [algebra].
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
|
public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
|
||||||
compileWith(T::class.java, algebra)
|
compileWith(T::class.java, algebra)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -125,7 +91,16 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg argu
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: IntRing): IntExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : IntExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: IntArray): Int = typed.value
|
||||||
|
} else
|
||||||
|
IntAsmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -134,7 +109,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuil
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -152,8 +127,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: LongRing): LongExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant<Long>) object : LongExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: LongArray): Long = typed.value
|
||||||
|
} else
|
||||||
|
LongAsmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -162,7 +145,7 @@ public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmB
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
|
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -181,7 +164,17 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>):
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : DoubleExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: DoubleArray): Double = typed.value
|
||||||
|
} else
|
||||||
|
DoubleAsmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -190,7 +183,7 @@ public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = Dou
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -199,4 +192,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
@ -56,7 +56,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
/**
|
/**
|
||||||
* Local variables indices are indices of symbols in this list.
|
* Local variables indices are indices of symbols in this list.
|
||||||
*/
|
*/
|
||||||
private val argumentsLocals = mutableListOf<String>()
|
private val argumentsLocals = mutableListOf<Symbol>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||||
@ -253,10 +253,10 @@ internal class GenericAsmBuilder<T>(
|
|||||||
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
|
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
|
||||||
* [loadVariable].
|
* [loadVariable].
|
||||||
*/
|
*/
|
||||||
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
|
fun prepareVariable(name: Symbol): Unit = invokeMethodVisitor.run {
|
||||||
if (name in argumentsLocals) return@run
|
if (name in argumentsLocals) return@run
|
||||||
load(1, MAP_TYPE)
|
load(1, MAP_TYPE)
|
||||||
aconst(name)
|
aconst(name.identity)
|
||||||
|
|
||||||
invokestatic(
|
invokestatic(
|
||||||
MAP_INTRINSICS_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
@ -280,7 +280,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
||||||
* with [prepareVariable] first.
|
* with [prepareVariable] first.
|
||||||
*/
|
*/
|
||||||
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
|
fun loadVariable(name: Symbol): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
|
||||||
|
|
||||||
inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) {
|
inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) {
|
||||||
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
@ -11,6 +11,7 @@ import org.objectweb.asm.Opcodes.*
|
|||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
import org.objectweb.asm.Type.*
|
import org.objectweb.asm.Type.*
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
@ -25,9 +26,9 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
classOfT: Class<*>,
|
classOfT: Class<*>,
|
||||||
protected val classOfTPrimitive: Class<*>,
|
protected val classOfTPrimitive: Class<*>,
|
||||||
expressionParent: Class<E>,
|
expressionParent: Class<E>,
|
||||||
protected val target: MST,
|
protected val target: TypedMst<T>,
|
||||||
) : AsmBuilder() {
|
) : AsmBuilder() {
|
||||||
private val className: String = buildName(target)
|
private val className: String = buildName("${target.hashCode()}_${classOfT.simpleName}")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [tType].
|
* ASM type for [tType].
|
||||||
@ -329,63 +330,39 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun visitVariables(
|
private fun visitVariables(
|
||||||
node: MST,
|
node: TypedMst<T>,
|
||||||
arrayMode: Boolean,
|
arrayMode: Boolean,
|
||||||
alreadyLoaded: MutableList<Symbol> = mutableListOf()
|
alreadyLoaded: MutableList<Symbol> = mutableListOf()
|
||||||
): Unit = when (node) {
|
): Unit = when (node) {
|
||||||
is Symbol -> when (node) {
|
is TypedMst.Variable -> if (node.symbol !in alreadyLoaded) {
|
||||||
!in alreadyLoaded -> {
|
alreadyLoaded += node.symbol
|
||||||
alreadyLoaded += node
|
prepareVariable(node.symbol, arrayMode)
|
||||||
prepareVariable(node, arrayMode)
|
} else Unit
|
||||||
}
|
|
||||||
else -> {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
|
is TypedMst.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
|
||||||
|
|
||||||
is MST.Binary -> {
|
is TypedMst.Binary -> {
|
||||||
visitVariables(node.left, arrayMode, alreadyLoaded)
|
visitVariables(node.left, arrayMode, alreadyLoaded)
|
||||||
visitVariables(node.right, arrayMode, alreadyLoaded)
|
visitVariables(node.right, arrayMode, alreadyLoaded)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> Unit
|
is TypedMst.Constant -> Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun visitExpression(node: MST): Unit = when (node) {
|
private fun visitExpression(node: TypedMst<T>): Unit = when (node) {
|
||||||
is Symbol -> {
|
is TypedMst.Variable -> loadVariable(node.symbol)
|
||||||
val symbol = algebra.bindSymbolOrNull(node)
|
|
||||||
|
|
||||||
if (symbol != null)
|
is TypedMst.Constant -> loadNumberConstant(
|
||||||
loadNumberConstant(symbol)
|
node.number ?: error("Object constants are not supported by pritimive ASM builder"),
|
||||||
else
|
)
|
||||||
loadVariable(node)
|
|
||||||
}
|
|
||||||
|
|
||||||
is MST.Numeric -> loadNumberConstant(algebra.number(node.value))
|
is TypedMst.Unary -> visitUnary(node)
|
||||||
|
is TypedMst.Binary -> visitBinary(node)
|
||||||
is MST.Unary -> if (node.value is MST.Numeric)
|
|
||||||
loadNumberConstant(
|
|
||||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
|
|
||||||
)
|
|
||||||
else
|
|
||||||
visitUnary(node)
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
|
||||||
node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
|
|
||||||
algebra.binaryOperationFunction(node.operation)(
|
|
||||||
algebra.number((node.left as MST.Numeric).value),
|
|
||||||
algebra.number((node.right as MST.Numeric).value),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> visitBinary(node)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value)
|
protected open fun visitUnary(node: TypedMst.Unary<T>) = visitExpression(node.value)
|
||||||
|
|
||||||
protected open fun visitBinary(node: MST.Binary) {
|
protected open fun visitBinary(node: TypedMst.Binary<T>) {
|
||||||
visitExpression(node.left)
|
visitExpression(node.left)
|
||||||
visitExpression(node.right)
|
visitExpression(node.right)
|
||||||
}
|
}
|
||||||
@ -404,14 +381,13 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, DoubleExpression>(
|
internal class DoubleAsmBuilder(target: TypedMst<Double>) : PrimitiveAsmBuilder<Double, DoubleExpression>(
|
||||||
DoubleField,
|
DoubleField,
|
||||||
java.lang.Double::class.java,
|
java.lang.Double::class.java,
|
||||||
java.lang.Double.TYPE,
|
java.lang.Double.TYPE,
|
||||||
DoubleExpression::class.java,
|
DoubleExpression::class.java,
|
||||||
target,
|
target,
|
||||||
) {
|
) {
|
||||||
|
|
||||||
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
|
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
|
||||||
MATH_TYPE.internalName,
|
MATH_TYPE.internalName,
|
||||||
name,
|
name,
|
||||||
@ -434,7 +410,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
|
|||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun visitUnary(node: MST.Unary) {
|
override fun visitUnary(node: TypedMst.Unary<Double>) {
|
||||||
super.visitUnary(node)
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -459,7 +435,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(node: MST.Binary) {
|
override fun visitBinary(node: TypedMst.Binary<Double>) {
|
||||||
super.visitBinary(node)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -479,7 +455,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class IntAsmBuilder(target: MST) :
|
internal class IntAsmBuilder(target: TypedMst<Int>) :
|
||||||
PrimitiveAsmBuilder<Int, IntExpression>(
|
PrimitiveAsmBuilder<Int, IntExpression>(
|
||||||
IntRing,
|
IntRing,
|
||||||
Integer::class.java,
|
Integer::class.java,
|
||||||
@ -487,7 +463,7 @@ internal class IntAsmBuilder(target: MST) :
|
|||||||
IntExpression::class.java,
|
IntExpression::class.java,
|
||||||
target
|
target
|
||||||
) {
|
) {
|
||||||
override fun visitUnary(node: MST.Unary) {
|
override fun visitUnary(node: TypedMst.Unary<Int>) {
|
||||||
super.visitUnary(node)
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -497,7 +473,7 @@ internal class IntAsmBuilder(target: MST) :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(node: MST.Binary) {
|
override fun visitBinary(node: TypedMst.Binary<Int>) {
|
||||||
super.visitBinary(node)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -510,14 +486,14 @@ internal class IntAsmBuilder(target: MST) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpression>(
|
internal class LongAsmBuilder(target: TypedMst<Long>) : PrimitiveAsmBuilder<Long, LongExpression>(
|
||||||
LongRing,
|
LongRing,
|
||||||
java.lang.Long::class.java,
|
java.lang.Long::class.java,
|
||||||
java.lang.Long.TYPE,
|
java.lang.Long.TYPE,
|
||||||
LongExpression::class.java,
|
LongExpression::class.java,
|
||||||
target,
|
target,
|
||||||
) {
|
) {
|
||||||
override fun visitUnary(node: MST.Unary) {
|
override fun visitUnary(node: TypedMst.Unary<Long>) {
|
||||||
super.visitUnary(node)
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -527,7 +503,7 @@ internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpre
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(node: MST.Binary) {
|
override fun visitBinary(node: TypedMst.Binary<Long>) {
|
||||||
super.visitBinary(node)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
|
@ -55,15 +55,15 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
|
|||||||
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
|
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
* Creates a class name for [Expression] based with appending [marker] to reduce collisions.
|
||||||
*
|
*
|
||||||
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
|
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
|
||||||
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
internal tailrec fun buildName(marker: String, collision: Int = 0): String {
|
||||||
val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision"
|
val name = "space.kscience.kmath.asm.generated.CompiledExpression_${marker}_$collision"
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Class.forName(name)
|
Class.forName(name)
|
||||||
@ -71,7 +71,7 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
|||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
return buildName(mst, collision + 1)
|
return buildName(marker, collision + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
|
@ -34,12 +34,12 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
|
|||||||
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
{ left, right ->
|
{ left, right ->
|
||||||
Expression { arguments ->
|
Expression { arguments ->
|
||||||
algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments))
|
algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
||||||
Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) }
|
Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,8 +164,6 @@ public open class FunctionalExpressionExtendedField<T, out A : ExtendedField<T>>
|
|||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionField>.binaryOperationFunction(operation)
|
super<FunctionalExpressionField>.binaryOperationFunction(operation)
|
||||||
|
|
||||||
override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <T, A : Group<T>> A.expressionInGroup(
|
public inline fun <T, A : Group<T>> A.expressionInGroup(
|
||||||
|
@ -7,7 +7,7 @@ package space.kscience.kmath.expressions
|
|||||||
|
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
import space.kscience.kmath.operations.bindSymbol
|
import space.kscience.kmath.operations.bindSymbolOrNull
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
||||||
@ -24,7 +24,7 @@ public sealed interface MST {
|
|||||||
public data class Numeric(val value: Number) : MST
|
public data class Numeric(val value: Number) : MST
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A node containing an unary operation.
|
* A node containing a unary operation.
|
||||||
*
|
*
|
||||||
* @property operation the identifier of operation.
|
* @property operation the identifier of operation.
|
||||||
* @property value the argument of this operation.
|
* @property value the argument of this operation.
|
||||||
@ -34,7 +34,7 @@ public sealed interface MST {
|
|||||||
/**
|
/**
|
||||||
* A node containing binary operation.
|
* A node containing binary operation.
|
||||||
*
|
*
|
||||||
* @property operation the identifier operation.
|
* @property operation the identifier of operation.
|
||||||
* @property left the left operand.
|
* @property left the left operand.
|
||||||
* @property right the right operand.
|
* @property right the right operand.
|
||||||
*/
|
*/
|
||||||
@ -43,66 +43,50 @@ public sealed interface MST {
|
|||||||
|
|
||||||
// TODO add a function with named arguments
|
// TODO add a function with named arguments
|
||||||
|
|
||||||
/**
|
|
||||||
* Interprets the [MST] node with this [Algebra].
|
|
||||||
*
|
|
||||||
* @receiver the algebra that provides operations.
|
|
||||||
* @param node the node to evaluate.
|
|
||||||
* @return the value of expression.
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
|
||||||
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
|
||||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
|
||||||
?: error("Numeric nodes are not supported by $this")
|
|
||||||
|
|
||||||
is Symbol -> bindSymbol(node)
|
|
||||||
|
|
||||||
is MST.Unary -> when {
|
|
||||||
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
|
|
||||||
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
|
||||||
}
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
|
||||||
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
|
|
||||||
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
|
|
||||||
|
|
||||||
this is NumericAlgebra && node.left is MST.Numeric ->
|
|
||||||
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
|
||||||
|
|
||||||
this is NumericAlgebra && node.right is MST.Numeric ->
|
|
||||||
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
|
||||||
|
|
||||||
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class InnerAlgebra<T>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
|
||||||
override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)]
|
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: T): T =
|
|
||||||
algebra.unaryOperation(operation, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
|
||||||
algebra.binaryOperation(operation, left, right)
|
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
|
|
||||||
algebra.unaryOperationFunction(operation)
|
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
|
|
||||||
algebra.binaryOperationFunction(operation)
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
|
||||||
(algebra as NumericAlgebra<T>).number(value)
|
|
||||||
else
|
|
||||||
error("Numeric nodes are not supported by $this")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||||
*/
|
*/
|
||||||
public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
|
||||||
InnerAlgebra(algebra, arguments).evaluate(this)
|
is MST.Numeric -> (algebra as NumericAlgebra<T>?)?.number(value)
|
||||||
|
?: error("Numeric nodes are not supported by $algebra")
|
||||||
|
|
||||||
|
is Symbol -> algebra.bindSymbolOrNull(this) ?: arguments.getValue(this)
|
||||||
|
|
||||||
|
is MST.Unary -> when {
|
||||||
|
algebra is NumericAlgebra && this.value is MST.Numeric -> algebra.unaryOperation(
|
||||||
|
this.operation,
|
||||||
|
algebra.number(this.value.value),
|
||||||
|
)
|
||||||
|
else -> algebra.unaryOperationFunction(this.operation)(this.value.interpret(algebra, arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Binary -> when {
|
||||||
|
algebra is NumericAlgebra && this.left is MST.Numeric && this.right is MST.Numeric -> algebra.binaryOperation(
|
||||||
|
this.operation,
|
||||||
|
algebra.number(this.left.value),
|
||||||
|
algebra.number(this.right.value),
|
||||||
|
)
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && this.left is MST.Numeric -> algebra.leftSideNumberOperation(
|
||||||
|
this.operation,
|
||||||
|
this.left.value,
|
||||||
|
this.right.interpret(algebra, arguments),
|
||||||
|
)
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && this.right is MST.Numeric -> algebra.rightSideNumberOperation(
|
||||||
|
this.operation,
|
||||||
|
left.interpret(algebra, arguments),
|
||||||
|
right.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> algebra.binaryOperation(
|
||||||
|
this.operation,
|
||||||
|
this.left.interpret(algebra, arguments),
|
||||||
|
this.right.interpret(algebra, arguments),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
* Interprets the [MST] node with this [Algebra] and optional [arguments]
|
||||||
@ -111,12 +95,17 @@ public fun <T> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T
|
|||||||
* @param algebra the algebra that provides operations.
|
* @param algebra the algebra that provides operations.
|
||||||
* @return the value of expression.
|
* @return the value of expression.
|
||||||
*/
|
*/
|
||||||
public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
|
||||||
interpret(algebra, mapOf(*arguments))
|
algebra,
|
||||||
|
when (arguments.size) {
|
||||||
|
0 -> emptyMap()
|
||||||
|
1 -> mapOf(arguments[0])
|
||||||
|
else -> hashMapOf(*arguments)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Interpret this [MST] as expression.
|
* Interpret this [MST] as expression.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
|
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> =
|
||||||
interpret(algebra, arguments)
|
Expression { arguments -> interpret(algebra, arguments) }
|
||||||
}
|
|
||||||
|
@ -272,7 +272,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: Aut
|
|||||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: Double,
|
y: Double,
|
||||||
): AutoDiffValue<T> = derive(const { x.value.pow(y)}) { z ->
|
): AutoDiffValue<T> = derive(const { x.value.pow(y) }) { z ->
|
||||||
x.d += z.d * y * x.value.pow(y - 1)
|
x.d += z.d * y * x.value.pow(y - 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,10 +343,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: Au
|
|||||||
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
||||||
context: F,
|
context: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>,
|
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
|
||||||
SimpleAutoDiffField<T, F>(context, bindings) {
|
|
||||||
|
|
||||||
override fun bindSymbol(value: String): AutoDiffValue<T> = super<SimpleAutoDiffField>.bindSymbol(value)
|
|
||||||
|
|
||||||
override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
|
override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
|
||||||
|
|
||||||
|
@ -199,8 +199,9 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
|||||||
|
|
||||||
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = operate(other) { l, r ->
|
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = operate(other) { l, r ->
|
||||||
ops.linalg.matMul(
|
ops.linalg.matMul(
|
||||||
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
if (l.shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
||||||
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r)
|
if (r.shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
@ -241,6 +242,16 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
|||||||
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
|
ops.math.argMax(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
|
||||||
).actualTensor
|
).actualTensor
|
||||||
|
|
||||||
|
// private val symbolCache = HashMap<String, TensorFlowOutput<T, TT>>()
|
||||||
|
//
|
||||||
|
// override fun bindSymbolOrNull(value: String): TensorFlowOutput<T, TT>? {
|
||||||
|
// return symbolCache.getOrPut(value){ops.var}
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// public fun StructureND<T>.grad(
|
||||||
|
//
|
||||||
|
// )= operate { ops.gradients() }
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
override fun export(arg: StructureND<T>): StructureND<T> =
|
override fun export(arg: StructureND<T>): StructureND<T> =
|
||||||
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
||||||
|
@ -4,6 +4,9 @@ import org.junit.jupiter.api.Test
|
|||||||
import space.kscience.kmath.nd.get
|
import space.kscience.kmath.nd.get
|
||||||
import space.kscience.kmath.nd.structureND
|
import space.kscience.kmath.nd.structureND
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.sum
|
||||||
|
import kotlin.random.Random
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class DoubleTensorFlowOps {
|
class DoubleTensorFlowOps {
|
||||||
@ -18,6 +21,19 @@ class DoubleTensorFlowOps {
|
|||||||
assertEquals(3.0, res[0, 0])
|
assertEquals(3.0, res[0, 0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun dot(){
|
||||||
|
val random = Random(12224)
|
||||||
|
val dim = 1000
|
||||||
|
|
||||||
|
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
|
||||||
|
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
|
||||||
|
|
||||||
|
DoubleField.produceWithTF {
|
||||||
|
tensor1 dot tensor2
|
||||||
|
}.sum()
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun extensionOps(){
|
fun extensionOps(){
|
||||||
val res = DoubleField.produceWithTF {
|
val res = DoubleField.produceWithTF {
|
||||||
|
@ -9,10 +9,7 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.MutableStructure2D
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.StructureND
|
|
||||||
import space.kscience.kmath.nd.as1D
|
|
||||||
import space.kscience.kmath.nd.as2D
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.structures.indices
|
import space.kscience.kmath.structures.indices
|
||||||
@ -421,7 +418,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
|
|
||||||
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
||||||
val (a, b) = ab
|
val (a, b) = ab
|
||||||
dotTo(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
dotTo(a, b, res, l, m1, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
return if (penultimateDim) {
|
return if (penultimateDim) {
|
||||||
@ -885,7 +882,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
|
return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEig(epsilon = 1e-15)
|
override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEigJacobi(maxIteration = 50, epsilon = 1e-15)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
|
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
|
||||||
@ -895,7 +892,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
* and when the cosine approaches 1 in the SVD algorithm.
|
* and when the cosine approaches 1 in the SVD algorithm.
|
||||||
* @return a pair `eigenvalues to eigenvectors`.
|
* @return a pair `eigenvalues to eigenvectors`.
|
||||||
*/
|
*/
|
||||||
public fun StructureND<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
|
public fun StructureND<Double>.symEigSvd(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
|
||||||
checkSymmetric(tensor, epsilon)
|
checkSymmetric(tensor, epsilon)
|
||||||
|
|
||||||
fun MutableStructure2D<Double>.cleanSym(n: Int) {
|
fun MutableStructure2D<Double>.cleanSym(n: Int) {
|
||||||
@ -922,6 +919,151 @@ public open class DoubleTensorAlgebra :
|
|||||||
return eig to v
|
return eig to v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun StructureND<Double>.symEigJacobi(maxIteration: Int, epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
|
||||||
|
checkSymmetric(tensor, epsilon)
|
||||||
|
|
||||||
|
val size = this.dimension
|
||||||
|
val eigenvectors = zeros(this.shape)
|
||||||
|
val eigenvalues = zeros(this.shape.sliceArray(0 until size - 1))
|
||||||
|
|
||||||
|
var eigenvalueStart = 0
|
||||||
|
var eigenvectorStart = 0
|
||||||
|
for (matrix in tensor.matrixSequence()) {
|
||||||
|
val matrix2D = matrix.as2D()
|
||||||
|
val (d, v) = matrix2D.jacobiHelper(maxIteration, epsilon)
|
||||||
|
|
||||||
|
for (i in 0 until matrix2D.rowNum) {
|
||||||
|
for (j in 0 until matrix2D.colNum) {
|
||||||
|
eigenvectors.mutableBuffer.array()[eigenvectorStart + i * matrix2D.rowNum + j] = v[i, j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i in 0 until matrix2D.rowNum) {
|
||||||
|
eigenvalues.mutableBuffer.array()[eigenvalueStart + i] = d[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
eigenvalueStart += this.shape.last()
|
||||||
|
eigenvectorStart += this.shape.last() * this.shape.last()
|
||||||
|
}
|
||||||
|
|
||||||
|
return eigenvalues to eigenvectors
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun MutableStructure2D<Double>.jacobiHelper(
|
||||||
|
maxIteration: Int,
|
||||||
|
epsilon: Double
|
||||||
|
): Pair<Structure1D<Double>, Structure2D<Double>> {
|
||||||
|
val n = this.shape[0]
|
||||||
|
val A_ = this.copy()
|
||||||
|
val V = eye(n)
|
||||||
|
val D = DoubleTensor(intArrayOf(n), (0 until this.rowNum).map { this[it, it] }.toDoubleArray()).as1D()
|
||||||
|
val B = DoubleTensor(intArrayOf(n), (0 until this.rowNum).map { this[it, it] }.toDoubleArray()).as1D()
|
||||||
|
val Z = zeros(intArrayOf(n)).as1D()
|
||||||
|
|
||||||
|
// assume that buffered tensor is square matrix
|
||||||
|
operator fun BufferedTensor<Double>.get(i: Int, j: Int): Double {
|
||||||
|
return this.mutableBuffer.array()[bufferStart + i * this.shape[0] + j]
|
||||||
|
}
|
||||||
|
|
||||||
|
operator fun BufferedTensor<Double>.set(i: Int, j: Int, value: Double) {
|
||||||
|
this.mutableBuffer.array()[bufferStart + i * this.shape[0] + j] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
fun maxOffDiagonal(matrix: BufferedTensor<Double>): Double {
|
||||||
|
var maxOffDiagonalElement = 0.0
|
||||||
|
for (i in 0 until n - 1) {
|
||||||
|
for (j in i + 1 until n) {
|
||||||
|
maxOffDiagonalElement = max(maxOffDiagonalElement, abs(matrix[i, j]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxOffDiagonalElement
|
||||||
|
}
|
||||||
|
|
||||||
|
fun rotate(a: BufferedTensor<Double>, s: Double, tau: Double, i: Int, j: Int, k: Int, l: Int) {
|
||||||
|
val g = a[i, j]
|
||||||
|
val h = a[k, l]
|
||||||
|
a[i, j] = g - s * (h + g * tau)
|
||||||
|
a[k, l] = h + s * (g - h * tau)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun jacobiIteration(
|
||||||
|
a: BufferedTensor<Double>,
|
||||||
|
v: BufferedTensor<Double>,
|
||||||
|
d: MutableStructure1D<Double>,
|
||||||
|
z: MutableStructure1D<Double>,
|
||||||
|
) {
|
||||||
|
for (ip in 0 until n - 1) {
|
||||||
|
for (iq in ip + 1 until n) {
|
||||||
|
val g = 100.0 * abs(a[ip, iq])
|
||||||
|
|
||||||
|
if (g <= epsilon * abs(d[ip]) && g <= epsilon * abs(d[iq])) {
|
||||||
|
a[ip, iq] = 0.0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var h = d[iq] - d[ip]
|
||||||
|
val t = when {
|
||||||
|
g <= epsilon * abs(h) -> (a[ip, iq]) / h
|
||||||
|
else -> {
|
||||||
|
val theta = 0.5 * h / (a[ip, iq])
|
||||||
|
val denominator = abs(theta) + sqrt(1.0 + theta * theta)
|
||||||
|
if (theta < 0.0) -1.0 / denominator else 1.0 / denominator
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val c = 1.0 / sqrt(1 + t * t)
|
||||||
|
val s = t * c
|
||||||
|
val tau = s / (1.0 + c)
|
||||||
|
h = t * a[ip, iq]
|
||||||
|
z[ip] -= h
|
||||||
|
z[iq] += h
|
||||||
|
d[ip] -= h
|
||||||
|
d[iq] += h
|
||||||
|
a[ip, iq] = 0.0
|
||||||
|
|
||||||
|
for (j in 0 until ip) {
|
||||||
|
rotate(a, s, tau, j, ip, j, iq)
|
||||||
|
}
|
||||||
|
for (j in (ip + 1) until iq) {
|
||||||
|
rotate(a, s, tau, ip, j, j, iq)
|
||||||
|
}
|
||||||
|
for (j in (iq + 1) until n) {
|
||||||
|
rotate(a, s, tau, ip, j, iq, j)
|
||||||
|
}
|
||||||
|
for (j in 0 until n) {
|
||||||
|
rotate(v, s, tau, j, ip, j, iq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun updateDiagonal(
|
||||||
|
d: MutableStructure1D<Double>,
|
||||||
|
z: MutableStructure1D<Double>,
|
||||||
|
b: MutableStructure1D<Double>,
|
||||||
|
) {
|
||||||
|
for (ip in 0 until d.size) {
|
||||||
|
b[ip] += z[ip]
|
||||||
|
d[ip] = b[ip]
|
||||||
|
z[ip] = 0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sm = maxOffDiagonal(A_)
|
||||||
|
for (iteration in 0 until maxIteration) {
|
||||||
|
if (sm < epsilon) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
jacobiIteration(A_, V, D, Z)
|
||||||
|
updateDiagonal(D, Z, B)
|
||||||
|
sm = maxOffDiagonal(A_)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO sort eigenvalues
|
||||||
|
return D to V.as2D()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the determinant of a square matrix input, or of each square matrix in a batched input
|
* Computes the determinant of a square matrix input, or of each square matrix in a batched input
|
||||||
* using LU factorization algorithm.
|
* using LU factorization algorithm.
|
||||||
@ -997,5 +1139,6 @@ public open class DoubleTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
||||||
|
public val DoubleField.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,18 +54,26 @@ internal val <T> BufferedTensor<T>.matrices: VirtualBuffer<BufferedTensor<T>>
|
|||||||
internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>> = matrices.asSequence()
|
internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>> = matrices.asSequence()
|
||||||
|
|
||||||
internal fun dotTo(
|
internal fun dotTo(
|
||||||
a: MutableStructure2D<Double>,
|
a: BufferedTensor<Double>,
|
||||||
b: MutableStructure2D<Double>,
|
b: BufferedTensor<Double>,
|
||||||
res: MutableStructure2D<Double>,
|
res: BufferedTensor<Double>,
|
||||||
l: Int, m: Int, n: Int,
|
l: Int, m: Int, n: Int,
|
||||||
) {
|
) {
|
||||||
|
val aStart = a.bufferStart
|
||||||
|
val bStart = b.bufferStart
|
||||||
|
val resStart = res.bufferStart
|
||||||
|
|
||||||
|
val aBuffer = a.mutableBuffer
|
||||||
|
val bBuffer = b.mutableBuffer
|
||||||
|
val resBuffer = res.mutableBuffer
|
||||||
|
|
||||||
for (i in 0 until l) {
|
for (i in 0 until l) {
|
||||||
for (j in 0 until n) {
|
for (j in 0 until n) {
|
||||||
var curr = 0.0
|
var curr = 0.0
|
||||||
for (k in 0 until m) {
|
for (k in 0 until m) {
|
||||||
curr += a[i, k] * b[k, j]
|
curr += aBuffer[aStart + i * m + k] * bBuffer[bStart + k * n + j]
|
||||||
}
|
}
|
||||||
res[i, j] = curr
|
resBuffer[resStart + i * n + j] = curr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,6 +107,8 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val tensor11 = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor11 = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
|
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
|
||||||
|
val tensor4 = fromArray(intArrayOf(2, 3, 3), (1..18).map { it.toDouble() }.toDoubleArray())
|
||||||
|
val tensor5 = fromArray(intArrayOf(2, 3, 3), (1..18).map { 1 + it.toDouble() }.toDoubleArray())
|
||||||
|
|
||||||
val res12 = tensor1.dot(tensor2)
|
val res12 = tensor1.dot(tensor2)
|
||||||
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
|
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
|
||||||
@ -123,6 +125,13 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
val res11 = tensor1.dot(tensor11)
|
val res11 = tensor1.dot(tensor11)
|
||||||
assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
||||||
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
|
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
|
||||||
|
|
||||||
|
val res45 = tensor4.dot(tensor5)
|
||||||
|
assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf(
|
||||||
|
36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0,
|
||||||
|
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||||
|
))
|
||||||
|
assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
Reference in New Issue
Block a user