Merge branch 'dev' into feature/mp-samplers
# Conflicts: # examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt # kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt
This commit is contained in:
commit
624460c52d
23
CHANGELOG.md
23
CHANGELOG.md
@ -4,27 +4,28 @@
|
||||
### Added
|
||||
- `fun` annotation for SAM interfaces in library
|
||||
- Explicit `public` visibility for all public APIs
|
||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140)
|
||||
- Automatic README generation for features (#139)
|
||||
- Native support for `memory`, `core` and `dimensions`
|
||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136).
|
||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136)
|
||||
- A separate `Symbol` entity, which is used for global unbound symbol.
|
||||
- A `Symbol` indexing scope.
|
||||
- Basic optimization API for Commons-math.
|
||||
- Chi squared optimization for array-like data in CM
|
||||
- `Fitting` utility object in prob/stat
|
||||
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`.
|
||||
- Coroutine-deterministic Monte-Carlo scope with a random number generator.
|
||||
- Some minor utilities to `kmath-for-real`.
|
||||
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`
|
||||
- Coroutine-deterministic Monte-Carlo scope with a random number generator
|
||||
- Some minor utilities to `kmath-for-real`
|
||||
- Generic operation result parameter to `MatrixContext`
|
||||
- New `MatrixFeature` interfaces for matrix decompositions
|
||||
|
||||
### Changed
|
||||
- Package changed from `scientifik` to `kscience.kmath`.
|
||||
- Gradle version: 6.6 -> 6.7.1
|
||||
- Package changed from `scientifik` to `kscience.kmath`
|
||||
- Gradle version: 6.6 -> 6.8
|
||||
- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`)
|
||||
- `Polynomial` secondary constructor made function.
|
||||
- Kotlin version: 1.3.72 -> 1.4.20
|
||||
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
||||
- `Polynomial` secondary constructor made function
|
||||
- Kotlin version: 1.3.72 -> 1.4.21
|
||||
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library
|
||||
- Full autodiff refactoring based on `Symbol`
|
||||
- `kmath-prob` renamed to `kmath-stat`
|
||||
- Grid generators moved to `kmath-for-real`
|
||||
@ -32,6 +33,8 @@
|
||||
- Optimized dot product for buffer matrices moved to `kmath-for-real`
|
||||
- EjmlMatrix context is an object
|
||||
- Matrix LUP `inverse` renamed to `inverseWithLUP`
|
||||
- `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it).
|
||||
- Features moved to NDStructure and became transparent.
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
13
README.md
13
README.md
@ -89,7 +89,16 @@ submit a feature request if you want something to be implemented first.
|
||||
* ### [kmath-ast](kmath-ast)
|
||||
>
|
||||
>
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
> **Maturity**: PROTOTYPE
|
||||
>
|
||||
> **Features:**
|
||||
> - [expression-language](kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
||||
> - [mst](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation
|
||||
> - [mst-building](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt) : MST building algebraic structure
|
||||
> - [mst-interpreter](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST interpreter
|
||||
> - [mst-jvm-codegen](kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
|
||||
> - [mst-js-codegen](kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler
|
||||
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-commons](kmath-commons)
|
||||
@ -122,7 +131,7 @@ submit a feature request if you want something to be implemented first.
|
||||
* ### [kmath-dimensions](kmath-dimensions)
|
||||
>
|
||||
>
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
> **Maturity**: PROTOTYPE
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-ejml](kmath-ejml)
|
||||
|
@ -4,7 +4,7 @@ plugins {
|
||||
id("ru.mipt.npm.project")
|
||||
}
|
||||
|
||||
internal val kmathVersion: String by extra("0.2.0-dev-4")
|
||||
internal val kmathVersion: String by extra("0.2.0-dev-5")
|
||||
internal val bintrayRepo: String by extra("kscience")
|
||||
internal val githubProject: String by extra("kmath")
|
||||
|
||||
@ -38,3 +38,7 @@ readme {
|
||||
apiValidation {
|
||||
validationDisabled = true
|
||||
}
|
||||
|
||||
ksciencePublish {
|
||||
spaceRepo = "https://maven.pkg.jetbrains.space/mipt-npm/p/sci/maven"
|
||||
}
|
||||
|
2
docs/templates/ARTIFACT-TEMPLATE.md
vendored
2
docs/templates/ARTIFACT-TEMPLATE.md
vendored
@ -14,7 +14,7 @@
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
|
||||
>
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
|
@ -4,13 +4,19 @@ import kscience.kmath.asm.compile
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.expressionInField
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.expressions.symbol
|
||||
import kscience.kmath.operations.Field
|
||||
import kscience.kmath.operations.RealField
|
||||
import org.openjdk.jmh.annotations.Benchmark
|
||||
import org.openjdk.jmh.annotations.Scope
|
||||
import org.openjdk.jmh.annotations.State
|
||||
import kotlin.random.Random
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
internal class ExpressionsInterpretersBenchmark {
|
||||
private val algebra: Field<Double> = RealField
|
||||
|
||||
@Benchmark
|
||||
fun functionalExpression() {
|
||||
val expr = algebra.expressionInField {
|
||||
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
|
||||
@ -19,22 +25,31 @@ internal class ExpressionsInterpretersBenchmark {
|
||||
invokeAndSum(expr)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun mstExpression() {
|
||||
val expr = algebra.mstInField {
|
||||
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||
symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0
|
||||
}
|
||||
|
||||
invokeAndSum(expr)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun asmExpression() {
|
||||
val expr = algebra.mstInField {
|
||||
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||
symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0
|
||||
}.compile()
|
||||
|
||||
invokeAndSum(expr)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun rawExpression() {
|
||||
val x by symbol
|
||||
val expr = Expression<Double> { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 }
|
||||
invokeAndSum(expr)
|
||||
}
|
||||
|
||||
private fun invokeAndSum(expr: Expression<Double>) {
|
||||
val random = Random(0)
|
||||
var sum = 0.0
|
||||
@ -46,35 +61,3 @@ internal class ExpressionsInterpretersBenchmark {
|
||||
println(sum)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
|
||||
* core FunctionalExpressions API.
|
||||
*
|
||||
* The expected rating is:
|
||||
*
|
||||
* 1. ASM.
|
||||
* 2. MST.
|
||||
* 3. FE.
|
||||
*/
|
||||
fun main() {
|
||||
val benchmark = ExpressionsInterpretersBenchmark()
|
||||
|
||||
val fe = measureTimeMillis {
|
||||
benchmark.functionalExpression()
|
||||
}
|
||||
|
||||
println("fe=$fe")
|
||||
|
||||
val mst = measureTimeMillis {
|
||||
benchmark.mstExpression()
|
||||
}
|
||||
|
||||
println("mst=$mst")
|
||||
|
||||
val asm = measureTimeMillis {
|
||||
benchmark.asmExpression()
|
||||
}
|
||||
|
||||
println("asm=$asm")
|
||||
}
|
@ -16,17 +16,13 @@ internal class ArrayBenchmark {
|
||||
@Benchmark
|
||||
fun benchmarkBufferRead() {
|
||||
var res = 0
|
||||
for (i in 1..size) res += arrayBuffer.get(
|
||||
size - i
|
||||
)
|
||||
for (i in 1..size) res += arrayBuffer[size - i]
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun nativeBufferRead() {
|
||||
var res = 0
|
||||
for (i in 1..size) res += nativeBuffer.get(
|
||||
size - i
|
||||
)
|
||||
for (i in 1..size) res += nativeBuffer[size - i]
|
||||
}
|
||||
|
||||
companion object {
|
||||
|
@ -2,19 +2,21 @@ package kscience.kmath.benchmarks
|
||||
|
||||
import kotlinx.benchmark.Benchmark
|
||||
import kscience.kmath.commons.linear.CMMatrixContext
|
||||
import kscience.kmath.commons.linear.CMMatrixContext.dot
|
||||
import kscience.kmath.commons.linear.toCM
|
||||
import kscience.kmath.ejml.EjmlMatrixContext
|
||||
import kscience.kmath.ejml.toEjml
|
||||
|
||||
import kscience.kmath.linear.BufferMatrixContext
|
||||
import kscience.kmath.linear.RealMatrixContext
|
||||
import kscience.kmath.linear.real
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.Matrix
|
||||
import org.openjdk.jmh.annotations.Scope
|
||||
import org.openjdk.jmh.annotations.State
|
||||
import kotlin.random.Random
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
class MultiplicationBenchmark {
|
||||
class DotBenchmark {
|
||||
companion object {
|
||||
val random = Random(12224)
|
||||
val dim = 1000
|
||||
@ -23,38 +25,48 @@ class MultiplicationBenchmark {
|
||||
val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||
val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||
|
||||
val cmMatrix1 = matrix1.toCM()
|
||||
val cmMatrix2 = matrix2.toCM()
|
||||
val cmMatrix1 = CMMatrixContext { matrix1.toCM() }
|
||||
val cmMatrix2 = CMMatrixContext { matrix2.toCM() }
|
||||
|
||||
val ejmlMatrix1 = matrix1.toEjml()
|
||||
val ejmlMatrix2 = matrix2.toEjml()
|
||||
val ejmlMatrix1 = EjmlMatrixContext { matrix1.toEjml() }
|
||||
val ejmlMatrix2 = EjmlMatrixContext { matrix2.toEjml() }
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun commonsMathMultiplication() {
|
||||
CMMatrixContext.invoke {
|
||||
CMMatrixContext {
|
||||
cmMatrix1 dot cmMatrix2
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun ejmlMultiplication() {
|
||||
EjmlMatrixContext.invoke {
|
||||
EjmlMatrixContext {
|
||||
ejmlMatrix1 dot ejmlMatrix2
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun ejmlMultiplicationwithConversion() {
|
||||
EjmlMatrixContext {
|
||||
val ejmlMatrix1 = matrix1.toEjml()
|
||||
val ejmlMatrix2 = matrix2.toEjml()
|
||||
EjmlMatrixContext.invoke {
|
||||
|
||||
ejmlMatrix1 dot ejmlMatrix2
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun bufferedMultiplication() {
|
||||
BufferMatrixContext(RealField, Buffer.Companion::real).invoke {
|
||||
matrix1 dot matrix2
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun realMultiplication() {
|
||||
RealMatrixContext {
|
||||
matrix1 dot matrix2
|
||||
}
|
||||
}
|
||||
}
|
@ -5,10 +5,8 @@ import kotlinx.benchmark.Benchmark
|
||||
import kscience.kmath.commons.linear.CMMatrixContext
|
||||
import kscience.kmath.commons.linear.CMMatrixContext.dot
|
||||
import kscience.kmath.commons.linear.inverse
|
||||
import kscience.kmath.commons.linear.toCM
|
||||
import kscience.kmath.ejml.EjmlMatrixContext
|
||||
import kscience.kmath.ejml.inverse
|
||||
import kscience.kmath.ejml.toEjml
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.Matrix
|
||||
import org.openjdk.jmh.annotations.Scope
|
||||
@ -35,16 +33,14 @@ class LinearAlgebraBenchmark {
|
||||
@Benchmark
|
||||
fun cmLUPInversion() {
|
||||
CMMatrixContext {
|
||||
val cm = matrix.toCM() //avoid overhead on conversion
|
||||
inverse(cm)
|
||||
inverse(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun ejmlInverse() {
|
||||
EjmlMatrixContext {
|
||||
val km = matrix.toEjml() //avoid overhead on conversion
|
||||
inverse(km)
|
||||
inverse(matrix)
|
||||
}
|
||||
}
|
||||
}
|
@ -63,4 +63,6 @@ fun main(): Unit = runBlocking(Dispatchers.Default) {
|
||||
val directJob = async { runApacheDirect() }
|
||||
println("KMath Chained: ${chainJob.await()}")
|
||||
println("Apache Direct: ${directJob.await()}")
|
||||
val normal = GaussianSampler.of(7.0, 2.0)
|
||||
val chain = normal.sample(generator).blocking()
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ fun main() {
|
||||
val n = 1000
|
||||
|
||||
val realField = NDField.real(dim, dim)
|
||||
val complexField = NDField.complex(dim, dim)
|
||||
val complexField: ComplexNDField = NDField.complex(dim, dim)
|
||||
|
||||
val realTime = measureTimeMillis {
|
||||
realField {
|
||||
|
@ -33,7 +33,7 @@ fun main() {
|
||||
measureAndPrint("Automatic field addition") {
|
||||
autoField {
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) { res += number(1.0) }
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ fun main() {
|
||||
measureAndPrint("Nd4j specialized addition") {
|
||||
nd4jField {
|
||||
var res = one
|
||||
repeat(n) { res += 1.0 as Number }
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ fun main() {
|
||||
genericField {
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) {
|
||||
res += one // couldn't avoid using `one` due to resolution ambiguity }
|
||||
res += 1.0 // couldn't avoid using `one` due to resolution ambiguity }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,9 +4,8 @@ import kscience.kmath.dimensions.D2
|
||||
import kscience.kmath.dimensions.D3
|
||||
import kscience.kmath.dimensions.DMatrixContext
|
||||
import kscience.kmath.dimensions.Dimension
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
private fun DMatrixContext<Double, RealField>.simple() {
|
||||
private fun DMatrixContext<Double>.simple() {
|
||||
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||
|
||||
@ -18,7 +17,7 @@ private object D5 : Dimension {
|
||||
override val dim: UInt = 5u
|
||||
}
|
||||
|
||||
private fun DMatrixContext<Double, RealField>.custom() {
|
||||
private fun DMatrixContext<Double>.custom() {
|
||||
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
||||
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
||||
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
||||
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
@ -2,43 +2,55 @@
|
||||
|
||||
This subproject implements the following features:
|
||||
|
||||
- Expression Language and its parser.
|
||||
- MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation.
|
||||
- Type-safe builder for MST.
|
||||
- Evaluating expressions by traversing MST.
|
||||
- [expression-language](src/jvmMain/kotlin/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
||||
- [mst](src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation
|
||||
- [mst-building](src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt) : MST building algebraic structure
|
||||
- [mst-interpreter](src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST interpreter
|
||||
- [mst-jvm-codegen](src/jvmMain/kotlin/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
|
||||
- [mst-js-codegen](src/jsMain/kotlin/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler
|
||||
|
||||
|
||||
> #### Artifact:
|
||||
> This module is distributed in the artifact `kscience.kmath:kmath-ast:0.1.4-dev-8`.
|
||||
>
|
||||
> This module artifact: `kscience.kmath:kmath-ast:0.2.0-dev-4`.
|
||||
>
|
||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-ast/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-ast/_latestVersion)
|
||||
>
|
||||
> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-ast/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-ast/_latestVersion)
|
||||
>
|
||||
> **Gradle:**
|
||||
>
|
||||
> ```gradle
|
||||
> repositories {
|
||||
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url https://dl.bintray.com/hotkeytlt/maven' }
|
||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
>
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation 'kscience.kmath:kmath-ast:0.1.4-dev-8'
|
||||
> implementation 'kscience.kmath:kmath-ast:0.2.0-dev-4'
|
||||
> }
|
||||
> ```
|
||||
> **Gradle Kotlin DSL:**
|
||||
>
|
||||
> ```kotlin
|
||||
> repositories {
|
||||
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation("kscience.kmath:kmath-ast:0.1.4-dev-8")
|
||||
> implementation("kscience.kmath:kmath-ast:0.2.0-dev-4")
|
||||
> }
|
||||
> ```
|
||||
>
|
||||
|
||||
## Dynamic Expression Code Generation with ObjectWeb ASM
|
||||
## Dynamic expression code generation
|
||||
|
||||
### On JVM
|
||||
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
@ -55,19 +67,20 @@ RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
package kscience.kmath.asm.generated;
|
||||
|
||||
import java.util.Map;
|
||||
import kotlin.jvm.functions.Function2;
|
||||
import kscience.kmath.asm.internal.MapIntrinsics;
|
||||
import kscience.kmath.expressions.Expression;
|
||||
import kscience.kmath.operations.RealField;
|
||||
import kscience.kmath.expressions.Symbol;
|
||||
|
||||
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
|
||||
private final RealField algebra;
|
||||
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
private final Object[] constants;
|
||||
|
||||
public final Double invoke(Map<String, ? extends Double> arguments) {
|
||||
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D);
|
||||
public final Double invoke(Map<Symbol, Double> arguments) {
|
||||
return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2);
|
||||
}
|
||||
|
||||
public AsmCompiledExpression_1073786867_0(RealField algebra) {
|
||||
this.algebra = algebra;
|
||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
||||
this.constants = constants;
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,10 +95,28 @@ RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
RealField.expression("x+2".parseMath())
|
||||
```
|
||||
|
||||
### Known issues
|
||||
#### Known issues
|
||||
|
||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
||||
class loading overhead.
|
||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
||||
|
||||
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).
|
||||
### On JS
|
||||
|
||||
A similar feature is also available on JS.
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
```
|
||||
|
||||
The code above returns expression implemented with such a JS function:
|
||||
|
||||
```js
|
||||
var executable = function (constants, arguments) {
|
||||
return constants[1](constants[0](arguments, "x"), 2);
|
||||
};
|
||||
```
|
||||
|
||||
#### Known issues
|
||||
|
||||
- This feature uses `eval` which can be unavailable in several environments.
|
||||
|
@ -1,7 +1,23 @@
|
||||
import ru.mipt.npm.gradle.Maturity
|
||||
|
||||
plugins {
|
||||
id("ru.mipt.npm.mpp")
|
||||
}
|
||||
|
||||
kotlin.js {
|
||||
nodejs {
|
||||
testTask {
|
||||
useMocha().timeout = "0"
|
||||
}
|
||||
}
|
||||
|
||||
browser {
|
||||
testTask {
|
||||
useMocha().timeout = "0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
@ -9,15 +25,58 @@ kotlin.sourceSets {
|
||||
}
|
||||
}
|
||||
|
||||
jsMain {
|
||||
dependencies {
|
||||
implementation(npm("astring", "1.4.3"))
|
||||
}
|
||||
}
|
||||
|
||||
jvmMain {
|
||||
dependencies {
|
||||
api("com.github.h0tk3y.betterParse:better-parse:0.4.0")
|
||||
implementation("org.ow2.asm:asm:8.0.1")
|
||||
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||
implementation("org.ow2.asm:asm:9.0")
|
||||
implementation("org.ow2.asm:asm-commons:9.0")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||
maturity = Maturity.PROTOTYPE
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
|
||||
feature(
|
||||
id = "expression-language",
|
||||
description = "Expression language and its parser",
|
||||
ref = "src/jvmMain/kotlin/kscience/kmath/ast/parser.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "mst",
|
||||
description = "MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/ast/MST.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "mst-building",
|
||||
description = "MST building algebraic structure",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "mst-interpreter",
|
||||
description = "MST interpreter",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/ast/MST.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "mst-jvm-codegen",
|
||||
description = "Dynamic MST to JVM bytecode compiler",
|
||||
ref = "src/jvmMain/kotlin/kscience/kmath/asm/asm.kt"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "mst-js-codegen",
|
||||
description = "Dynamic MST to JS compiler",
|
||||
ref = "src/jsMain/kotlin/kscience/kmath/estree/estree.kt"
|
||||
)
|
||||
}
|
80
kmath-ast/docs/README-TEMPLATE.md
Normal file
80
kmath-ast/docs/README-TEMPLATE.md
Normal file
@ -0,0 +1,80 @@
|
||||
# Abstract Syntax Tree Expression Representation and Operations (`kmath-ast`)
|
||||
|
||||
This subproject implements the following features:
|
||||
|
||||
${features}
|
||||
|
||||
${artifact}
|
||||
|
||||
## Dynamic expression code generation
|
||||
|
||||
### On JVM
|
||||
|
||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds
|
||||
a special implementation of `Expression<T>` with implemented `invoke` function.
|
||||
|
||||
For example, the following builder:
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
```
|
||||
|
||||
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||
|
||||
```java
|
||||
package kscience.kmath.asm.generated;
|
||||
|
||||
import java.util.Map;
|
||||
import kotlin.jvm.functions.Function2;
|
||||
import kscience.kmath.asm.internal.MapIntrinsics;
|
||||
import kscience.kmath.expressions.Expression;
|
||||
import kscience.kmath.expressions.Symbol;
|
||||
|
||||
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
||||
private final Object[] constants;
|
||||
|
||||
public final Double invoke(Map<Symbol, Double> arguments) {
|
||||
return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2);
|
||||
}
|
||||
|
||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
||||
this.constants = constants;
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### Example Usage
|
||||
|
||||
This API extends MST and MstExpression, so you may optimize as both of them:
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
RealField.expression("x+2".parseMath())
|
||||
```
|
||||
|
||||
#### Known issues
|
||||
|
||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid
|
||||
class loading overhead.
|
||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
||||
|
||||
### On JS
|
||||
|
||||
A similar feature is also available on JS.
|
||||
|
||||
```kotlin
|
||||
RealField.mstInField { symbol("x") + 2 }.compile()
|
||||
```
|
||||
|
||||
The code above returns expression implemented with such a JS function:
|
||||
|
||||
```js
|
||||
var executable = function (constants, arguments) {
|
||||
return constants[1](constants[0](arguments, "x"), 2);
|
||||
};
|
||||
```
|
||||
|
||||
#### Known issues
|
||||
|
||||
- This feature uses `eval` which can be unavailable in several environments.
|
@ -2,10 +2,9 @@ package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* A Mathematical Syntax Tree node for mathematical expressions.
|
||||
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
@ -55,24 +54,25 @@ public sealed class MST {
|
||||
public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField.binaryOperation(
|
||||
node.operation,
|
||||
node.left.value.toDouble(),
|
||||
node.right.value.toDouble()
|
||||
)
|
||||
|
||||
number(number)
|
||||
is MST.Unary -> when {
|
||||
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
|
||||
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||
}
|
||||
|
||||
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
|
||||
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
|
||||
is MST.Binary -> when {
|
||||
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
|
||||
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
|
||||
|
||||
this is NumericAlgebra && node.left is MST.Numeric ->
|
||||
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||
|
||||
this is NumericAlgebra && node.right is MST.Numeric ->
|
||||
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||
|
||||
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,64 +1,83 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
|
||||
/**
|
||||
* [Algebra] over [MST] nodes.
|
||||
*/
|
||||
public object MstAlgebra : NumericAlgebra<MST> {
|
||||
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||
public override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||
public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||
|
||||
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
|
||||
{ arg -> MST.Unary(operation, arg) }
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
||||
MST.Unary(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MST.Binary(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
{ left, right -> MST.Binary(operation, left, right) }
|
||||
}
|
||||
|
||||
/**
|
||||
* [Space] over [MST] nodes.
|
||||
*/
|
||||
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST.Numeric by lazy { number(0.0) }
|
||||
public override val zero: MST.Numeric by lazy { number(0.0) }
|
||||
|
||||
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||
public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b)
|
||||
public override operator fun MST.unaryPlus(): MST.Unary =
|
||||
unaryOperationFunction(SpaceOperations.PLUS_OPERATION)(this)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
public override operator fun MST.unaryMinus(): MST.Unary =
|
||||
unaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
|
||||
public override operator fun MST.minus(b: MST): MST.Binary =
|
||||
binaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this, b)
|
||||
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary =
|
||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(k))
|
||||
|
||||
public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstAlgebra.binaryOperationFunction(operation)
|
||||
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
|
||||
MstAlgebra.unaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Ring] over [MST] nodes.
|
||||
*/
|
||||
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST.Numeric
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object MstRing : Ring<MST>, RingWithNumbers<MST> {
|
||||
public override val zero: MST.Numeric
|
||||
get() = MstSpace.zero
|
||||
|
||||
override val one: MST.Numeric by lazy { number(1.0) }
|
||||
public override val one: MST.Numeric by lazy { number(1.0) }
|
||||
|
||||
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
public override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||
public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary =
|
||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstSpace.binaryOperation(operation, left, right)
|
||||
public override operator fun MST.unaryPlus(): MST.Unary = MstSpace { +this@unaryPlus }
|
||||
public override operator fun MST.unaryMinus(): MST.Unary = MstSpace { -this@unaryMinus }
|
||||
public override operator fun MST.minus(b: MST): MST.Binary = MstSpace { this@minus - b }
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
|
||||
public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstSpace.binaryOperationFunction(operation)
|
||||
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
|
||||
MstAlgebra.unaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Field] over [MST] nodes.
|
||||
*/
|
||||
public object MstField : Field<MST> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object MstField : Field<MST>, RingWithNumbers<MST> {
|
||||
public override val zero: MST.Numeric
|
||||
get() = MstRing.zero
|
||||
|
||||
@ -70,51 +89,61 @@ public object MstField : Field<MST> {
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
public override fun divide(a: MST, b: MST): MST.Binary =
|
||||
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstRing.binaryOperation(operation, left, right)
|
||||
public override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
|
||||
public override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
|
||||
public override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b }
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
|
||||
public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstRing.binaryOperationFunction(operation)
|
||||
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
|
||||
MstRing.unaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* [ExtendedField] over [MST] nodes.
|
||||
*/
|
||||
public object MstExtendedField : ExtendedField<MST> {
|
||||
override val zero: MST.Numeric
|
||||
public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
||||
public override val zero: MST.Numeric
|
||||
get() = MstField.zero
|
||||
|
||||
override val one: MST.Numeric
|
||||
public override val one: MST.Numeric
|
||||
get() = MstField.one
|
||||
|
||||
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||
override fun number(value: Number): MST.Numeric = MstField.number(value)
|
||||
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||
public override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||
public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||
public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg)
|
||||
public override fun tan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.TAN_OPERATION)(arg)
|
||||
public override fun asin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg)
|
||||
public override fun acos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg)
|
||||
public override fun atan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg)
|
||||
public override fun sinh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.SINH_OPERATION)(arg)
|
||||
public override fun cosh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.COSH_OPERATION)(arg)
|
||||
public override fun tanh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.TANH_OPERATION)(arg)
|
||||
public override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ASINH_OPERATION)(arg)
|
||||
public override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ACOSH_OPERATION)(arg)
|
||||
public override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ATANH_OPERATION)(arg)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||
public override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus }
|
||||
public override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus }
|
||||
public override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b }
|
||||
|
||||
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
public override fun power(arg: MST, pow: Number): MST.Binary =
|
||||
binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||
|
||||
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
public override fun exp(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg)
|
||||
public override fun ln(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstField.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstField.binaryOperationFunction(operation)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
|
||||
MstField.unaryOperationFunction(operation)
|
||||
}
|
||||
|
@ -15,11 +15,14 @@ import kotlin.contracts.contract
|
||||
*/
|
||||
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||
override fun symbol(value: String): T = try {
|
||||
algebra.symbol(value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
null
|
||||
} ?: arguments.getValue(StringSymbol(value))
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T = algebra.unaryOperationFunction(operation)
|
||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = algebra.binaryOperationFunction(operation)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||
|
82
kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt
Normal file
82
kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt
Normal file
@ -0,0 +1,82 @@
|
||||
package kscience.kmath.estree
|
||||
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MST.*
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.estree.internal.ESTreeBuilder
|
||||
import kscience.kmath.estree.internal.estree.BaseExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
@PublishedApi
|
||||
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
||||
is Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
null
|
||||
}
|
||||
|
||||
if (symbol != null)
|
||||
constant(symbol)
|
||||
else
|
||||
variable(node.value)
|
||||
}
|
||||
|
||||
is Numeric -> constant(node.value)
|
||||
|
||||
is Unary -> when {
|
||||
algebra is NumericAlgebra && node.value is Numeric -> constant(
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||
|
||||
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
||||
}
|
||||
|
||||
is Binary -> when {
|
||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
|
||||
algebra
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra && node.left is Numeric -> call(
|
||||
algebra.leftSideNumberOperationFunction(node.operation),
|
||||
visit(node.left),
|
||||
visit(node.right),
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra && node.right is Numeric -> call(
|
||||
algebra.rightSideNumberOperationFunction(node.operation),
|
||||
visit(node.left),
|
||||
visit(node.right),
|
||||
)
|
||||
|
||||
else -> call(
|
||||
algebra.binaryOperationFunction(node.operation),
|
||||
visit(node.left),
|
||||
visit(node.right),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compiles an [MST] to ESTree generated expression using given algebra.
|
||||
*
|
||||
* @author Alexander Nozik.
|
||||
*/
|
||||
public fun <T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
|
||||
mst.compileWith(this)
|
||||
|
||||
/**
|
||||
* Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression.
|
||||
*
|
||||
* @author Alexander Nozik.
|
||||
*/
|
||||
public fun <T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
||||
mst.compileWith(algebra)
|
@ -0,0 +1,79 @@
|
||||
package kscience.kmath.estree.internal
|
||||
|
||||
import kscience.kmath.estree.internal.astring.generate
|
||||
import kscience.kmath.estree.internal.estree.*
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
|
||||
internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExpression) {
|
||||
private class GeneratedExpression<T>(val executable: dynamic, val constants: Array<dynamic>) : Expression<T> {
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
override fun invoke(arguments: Map<Symbol, T>): T {
|
||||
val e = executable
|
||||
val c = constants
|
||||
val a = js("{}")
|
||||
arguments.forEach { (key, value) -> a[key.identity] = value }
|
||||
return js("e(c, a)").unsafeCast<T>()
|
||||
}
|
||||
}
|
||||
|
||||
val instance: Expression<T> by lazy {
|
||||
val node = Program(
|
||||
sourceType = "script",
|
||||
VariableDeclaration(
|
||||
kind = "var",
|
||||
VariableDeclarator(
|
||||
id = Identifier("executable"),
|
||||
init = FunctionExpression(
|
||||
params = arrayOf(Identifier("constants"), Identifier("arguments")),
|
||||
body = BlockStatement(ReturnStatement(bodyCallback())),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
eval(generate(node))
|
||||
GeneratedExpression(js("executable"), constants.toTypedArray())
|
||||
}
|
||||
|
||||
private val constants = mutableListOf<Any>()
|
||||
|
||||
fun constant(value: Any?) = when {
|
||||
value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" ->
|
||||
SimpleLiteral(value)
|
||||
|
||||
jsTypeOf(value) == "undefined" -> Identifier("undefined")
|
||||
|
||||
else -> {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
|
||||
|
||||
MemberExpression(
|
||||
computed = true,
|
||||
optional = false,
|
||||
`object` = Identifier("constants"),
|
||||
property = SimpleLiteral(idx),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name))
|
||||
|
||||
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
|
||||
optional = false,
|
||||
callee = constant(function),
|
||||
*args,
|
||||
)
|
||||
|
||||
private companion object {
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
val getOrFail: (`object`: dynamic, key: String) -> dynamic = { `object`, key ->
|
||||
val k = key
|
||||
val o = `object`
|
||||
|
||||
if (!(js("k in o") as Boolean))
|
||||
throw NoSuchElementException("Key $key is missing in the map.")
|
||||
|
||||
js("o[k]")
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
@file:JsModule("astring")
|
||||
@file:JsNonModule
|
||||
|
||||
package kscience.kmath.estree.internal.astring
|
||||
|
||||
import kscience.kmath.estree.internal.estree.BaseNode
|
||||
|
||||
internal external interface Options {
|
||||
var indent: String?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var lineEnd: String?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var startingIndentLevel: Number?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var comments: Boolean?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var generator: Any?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var sourceMap: Any?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external fun generate(node: BaseNode, options: Options /* Options & `T$0` */ = definedExternally): String
|
||||
|
||||
internal external fun generate(node: BaseNode): String
|
||||
|
||||
internal external var baseGenerator: Generator
|
@ -0,0 +1,3 @@
|
||||
package kscience.kmath.estree.internal.astring
|
||||
|
||||
internal typealias Generator = Any
|
@ -0,0 +1,13 @@
|
||||
package kscience.kmath.estree.internal.emitter
|
||||
|
||||
internal open external class Emitter {
|
||||
constructor(obj: Any)
|
||||
constructor()
|
||||
|
||||
open fun on(event: String, fn: () -> Unit)
|
||||
open fun off(event: String, fn: () -> Unit)
|
||||
open fun once(event: String, fn: () -> Unit)
|
||||
open fun emit(event: String, vararg any: Any)
|
||||
open fun listeners(event: String): Array<() -> Unit>
|
||||
open fun hasListeners(event: String): Boolean
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package kscience.kmath.estree.internal.estree
|
||||
|
||||
internal fun Program(sourceType: String, vararg body: dynamic) = object : Program {
|
||||
override var type = "Program"
|
||||
override var sourceType = sourceType
|
||||
override var body = body
|
||||
}
|
||||
|
||||
internal fun VariableDeclaration(kind: String, vararg declarations: VariableDeclarator) = object : VariableDeclaration {
|
||||
override var type = "VariableDeclaration"
|
||||
override var declarations = declarations.toList().toTypedArray()
|
||||
override var kind = kind
|
||||
}
|
||||
|
||||
internal fun VariableDeclarator(id: dynamic, init: dynamic) = object : VariableDeclarator {
|
||||
override var type = "VariableDeclarator"
|
||||
override var id = id
|
||||
override var init = init
|
||||
}
|
||||
|
||||
internal fun Identifier(name: String) = object : Identifier {
|
||||
override var type = "Identifier"
|
||||
override var name = name
|
||||
}
|
||||
|
||||
internal fun FunctionExpression(params: Array<dynamic>, body: BlockStatement) = object : FunctionExpression {
|
||||
override var params = params
|
||||
override var type = "FunctionExpression"
|
||||
override var body = body
|
||||
}
|
||||
|
||||
internal fun BlockStatement(vararg body: dynamic) = object : BlockStatement {
|
||||
override var type = "BlockStatement"
|
||||
override var body = body
|
||||
}
|
||||
|
||||
internal fun ReturnStatement(argument: dynamic) = object : ReturnStatement {
|
||||
override var type = "ReturnStatement"
|
||||
override var argument = argument
|
||||
}
|
||||
|
||||
internal fun SimpleLiteral(value: dynamic) = object : SimpleLiteral {
|
||||
override var type = "Literal"
|
||||
override var value = value
|
||||
}
|
||||
|
||||
internal fun MemberExpression(computed: Boolean, optional: Boolean, `object`: dynamic, property: dynamic) =
|
||||
object : MemberExpression {
|
||||
override var type = "MemberExpression"
|
||||
override var computed = computed
|
||||
override var optional = optional
|
||||
override var `object` = `object`
|
||||
override var property = property
|
||||
}
|
||||
|
||||
internal fun SimpleCallExpression(optional: Boolean, callee: dynamic, vararg arguments: dynamic) =
|
||||
object : SimpleCallExpression {
|
||||
override var type = "CallExpression"
|
||||
override var optional = optional
|
||||
override var callee = callee
|
||||
override var arguments = arguments
|
||||
}
|
@ -0,0 +1,644 @@
|
||||
package kscience.kmath.estree.internal.estree
|
||||
|
||||
import kotlin.js.RegExp
|
||||
|
||||
internal external interface BaseNodeWithoutComments {
|
||||
var type: String
|
||||
var loc: SourceLocation?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var range: dynamic /* JsTuple<Number, Number> */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BaseNode : BaseNodeWithoutComments {
|
||||
var leadingComments: Array<Comment>?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var trailingComments: Array<Comment>?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface Comment : BaseNodeWithoutComments {
|
||||
override var type: String /* "Line" | "Block" */
|
||||
var value: String
|
||||
}
|
||||
|
||||
internal external interface SourceLocation {
|
||||
var source: String?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var start: Position
|
||||
var end: Position
|
||||
}
|
||||
|
||||
internal external interface Position {
|
||||
var line: Number
|
||||
var column: Number
|
||||
}
|
||||
|
||||
internal external interface Program : BaseNode {
|
||||
override var type: String /* "Program" */
|
||||
var sourceType: String /* "script" | "module" */
|
||||
var body: Array<dynamic /* Directive | ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration | ImportDeclaration | ExportNamedDeclaration | ExportDefaultDeclaration | ExportAllDeclaration */>
|
||||
var comments: Array<Comment>?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface Directive : BaseNode {
|
||||
override var type: String /* "ExpressionStatement" */
|
||||
var expression: dynamic /* SimpleLiteral | RegExpLiteral */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var directive: String
|
||||
}
|
||||
|
||||
internal external interface BaseFunction : BaseNode {
|
||||
var params: Array<dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */>
|
||||
var generator: Boolean?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var async: Boolean?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var body: dynamic /* BlockStatement | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BaseStatement : BaseNode
|
||||
|
||||
internal external interface EmptyStatement : BaseStatement {
|
||||
override var type: String /* "EmptyStatement" */
|
||||
}
|
||||
|
||||
internal external interface BlockStatement : BaseStatement {
|
||||
override var type: String /* "BlockStatement" */
|
||||
var body: Array<dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */>
|
||||
var innerComments: Array<Comment>?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ExpressionStatement : BaseStatement {
|
||||
override var type: String /* "ExpressionStatement" */
|
||||
var expression: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface IfStatement : BaseStatement {
|
||||
override var type: String /* "IfStatement" */
|
||||
var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var consequent: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var alternate: dynamic /* ExpressionStatement? | BlockStatement? | EmptyStatement? | DebuggerStatement? | WithStatement? | ReturnStatement? | LabeledStatement? | BreakStatement? | ContinueStatement? | IfStatement? | SwitchStatement? | ThrowStatement? | TryStatement? | WhileStatement? | DoWhileStatement? | ForStatement? | ForInStatement? | ForOfStatement? | FunctionDeclaration? | VariableDeclaration? | ClassDeclaration? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface LabeledStatement : BaseStatement {
|
||||
override var type: String /* "LabeledStatement" */
|
||||
var label: Identifier
|
||||
var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BreakStatement : BaseStatement {
|
||||
override var type: String /* "BreakStatement" */
|
||||
var label: Identifier?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ContinueStatement : BaseStatement {
|
||||
override var type: String /* "ContinueStatement" */
|
||||
var label: Identifier?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface WithStatement : BaseStatement {
|
||||
override var type: String /* "WithStatement" */
|
||||
var `object`: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface SwitchStatement : BaseStatement {
|
||||
override var type: String /* "SwitchStatement" */
|
||||
var discriminant: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var cases: Array<SwitchCase>
|
||||
}
|
||||
|
||||
internal external interface ReturnStatement : BaseStatement {
|
||||
override var type: String /* "ReturnStatement" */
|
||||
var argument: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ThrowStatement : BaseStatement {
|
||||
override var type: String /* "ThrowStatement" */
|
||||
var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface TryStatement : BaseStatement {
|
||||
override var type: String /* "TryStatement" */
|
||||
var block: BlockStatement
|
||||
var handler: CatchClause?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var finalizer: BlockStatement?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface WhileStatement : BaseStatement {
|
||||
override var type: String /* "WhileStatement" */
|
||||
var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface DoWhileStatement : BaseStatement {
|
||||
override var type: String /* "DoWhileStatement" */
|
||||
var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ForStatement : BaseStatement {
|
||||
override var type: String /* "ForStatement" */
|
||||
var init: dynamic /* VariableDeclaration? | ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var test: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var update: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BaseForXStatement : BaseStatement {
|
||||
var left: dynamic /* VariableDeclaration | Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ForInStatement : BaseForXStatement {
|
||||
override var type: String /* "ForInStatement" */
|
||||
}
|
||||
|
||||
internal external interface DebuggerStatement : BaseStatement {
|
||||
override var type: String /* "DebuggerStatement" */
|
||||
}
|
||||
|
||||
internal external interface BaseDeclaration : BaseStatement
|
||||
|
||||
internal external interface FunctionDeclaration : BaseFunction, BaseDeclaration {
|
||||
override var type: String /* "FunctionDeclaration" */
|
||||
var id: Identifier?
|
||||
override var body: BlockStatement
|
||||
}
|
||||
|
||||
internal external interface VariableDeclaration : BaseDeclaration {
|
||||
override var type: String /* "VariableDeclaration" */
|
||||
var declarations: Array<VariableDeclarator>
|
||||
var kind: String /* "var" | "let" | "const" */
|
||||
}
|
||||
|
||||
internal external interface VariableDeclarator : BaseNode {
|
||||
override var type: String /* "VariableDeclarator" */
|
||||
var id: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var init: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BaseExpression : BaseNode
|
||||
|
||||
internal external interface ChainExpression : BaseExpression {
|
||||
override var type: String /* "ChainExpression" */
|
||||
var expression: dynamic /* SimpleCallExpression | MemberExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ThisExpression : BaseExpression {
|
||||
override var type: String /* "ThisExpression" */
|
||||
}
|
||||
|
||||
internal external interface ArrayExpression : BaseExpression {
|
||||
override var type: String /* "ArrayExpression" */
|
||||
var elements: Array<dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | SpreadElement */>
|
||||
}
|
||||
|
||||
internal external interface ObjectExpression : BaseExpression {
|
||||
override var type: String /* "ObjectExpression" */
|
||||
var properties: Array<dynamic /* Property | SpreadElement */>
|
||||
}
|
||||
|
||||
internal external interface Property : BaseNode {
|
||||
override var type: String /* "Property" */
|
||||
var key: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var value: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var kind: String /* "init" | "get" | "set" */
|
||||
var method: Boolean
|
||||
var shorthand: Boolean
|
||||
var computed: Boolean
|
||||
}
|
||||
|
||||
internal external interface FunctionExpression : BaseFunction, BaseExpression {
|
||||
var id: Identifier?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
override var type: String /* "FunctionExpression" */
|
||||
override var body: BlockStatement
|
||||
}
|
||||
|
||||
internal external interface SequenceExpression : BaseExpression {
|
||||
override var type: String /* "SequenceExpression" */
|
||||
var expressions: Array<dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */>
|
||||
}
|
||||
|
||||
internal external interface UnaryExpression : BaseExpression {
|
||||
override var type: String /* "UnaryExpression" */
|
||||
var operator: String /* "-" | "+" | "!" | "~" | "typeof" | "void" | "delete" */
|
||||
var prefix: Boolean
|
||||
var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BinaryExpression : BaseExpression {
|
||||
override var type: String /* "BinaryExpression" */
|
||||
var operator: String /* "==" | "!=" | "===" | "!==" | "<" | "<=" | ">" | ">=" | "<<" | ">>" | ">>>" | "+" | "-" | "*" | "/" | "%" | "**" | "|" | "^" | "&" | "in" | "instanceof" */
|
||||
var left: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface AssignmentExpression : BaseExpression {
|
||||
override var type: String /* "AssignmentExpression" */
|
||||
var operator: String /* "=" | "+=" | "-=" | "*=" | "/=" | "%=" | "**=" | "<<=" | ">>=" | ">>>=" | "|=" | "^=" | "&=" */
|
||||
var left: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface UpdateExpression : BaseExpression {
|
||||
override var type: String /* "UpdateExpression" */
|
||||
var operator: String /* "++" | "--" */
|
||||
var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var prefix: Boolean
|
||||
}
|
||||
|
||||
internal external interface LogicalExpression : BaseExpression {
|
||||
override var type: String /* "LogicalExpression" */
|
||||
var operator: String /* "||" | "&&" | "??" */
|
||||
var left: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ConditionalExpression : BaseExpression {
|
||||
override var type: String /* "ConditionalExpression" */
|
||||
var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var alternate: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var consequent: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BaseCallExpression : BaseExpression {
|
||||
var callee: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | Super */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var arguments: Array<dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | SpreadElement */>
|
||||
}
|
||||
|
||||
internal external interface SimpleCallExpression : BaseCallExpression {
|
||||
override var type: String /* "CallExpression" */
|
||||
var optional: Boolean
|
||||
}
|
||||
|
||||
internal external interface NewExpression : BaseCallExpression {
|
||||
override var type: String /* "NewExpression" */
|
||||
}
|
||||
|
||||
internal external interface MemberExpression : BaseExpression, BasePattern {
|
||||
override var type: String /* "MemberExpression" */
|
||||
var `object`: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | Super */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var property: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var computed: Boolean
|
||||
var optional: Boolean
|
||||
}
|
||||
|
||||
internal external interface BasePattern : BaseNode
|
||||
|
||||
internal external interface SwitchCase : BaseNode {
|
||||
override var type: String /* "SwitchCase" */
|
||||
var test: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var consequent: Array<dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */>
|
||||
}
|
||||
|
||||
internal external interface CatchClause : BaseNode {
|
||||
override var type: String /* "CatchClause" */
|
||||
var param: dynamic /* Identifier? | ObjectPattern? | ArrayPattern? | RestElement? | AssignmentPattern? | MemberExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var body: BlockStatement
|
||||
}
|
||||
|
||||
internal external interface Identifier : BaseNode, BaseExpression, BasePattern {
|
||||
override var type: String /* "Identifier" */
|
||||
var name: String
|
||||
}
|
||||
|
||||
internal external interface SimpleLiteral : BaseNode, BaseExpression {
|
||||
override var type: String /* "Literal" */
|
||||
var value: dynamic /* String? | Boolean? | Number? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var raw: String?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface `T$1` {
|
||||
var pattern: String
|
||||
var flags: String
|
||||
}
|
||||
|
||||
internal external interface RegExpLiteral : BaseNode, BaseExpression {
|
||||
override var type: String /* "Literal" */
|
||||
var value: RegExp?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var regex: `T$1`
|
||||
var raw: String?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ForOfStatement : BaseForXStatement {
|
||||
override var type: String /* "ForOfStatement" */
|
||||
var await: Boolean
|
||||
}
|
||||
|
||||
internal external interface Super : BaseNode {
|
||||
override var type: String /* "Super" */
|
||||
}
|
||||
|
||||
internal external interface SpreadElement : BaseNode {
|
||||
override var type: String /* "SpreadElement" */
|
||||
var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ArrowFunctionExpression : BaseExpression, BaseFunction {
|
||||
override var type: String /* "ArrowFunctionExpression" */
|
||||
var expression: Boolean
|
||||
override var body: dynamic /* BlockStatement | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface YieldExpression : BaseExpression {
|
||||
override var type: String /* "YieldExpression" */
|
||||
var argument: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var delegate: Boolean
|
||||
}
|
||||
|
||||
internal external interface TemplateLiteral : BaseExpression {
|
||||
override var type: String /* "TemplateLiteral" */
|
||||
var quasis: Array<TemplateElement>
|
||||
var expressions: Array<dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */>
|
||||
}
|
||||
|
||||
internal external interface TaggedTemplateExpression : BaseExpression {
|
||||
override var type: String /* "TaggedTemplateExpression" */
|
||||
var tag: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var quasi: TemplateLiteral
|
||||
}
|
||||
|
||||
internal external interface `T$2` {
|
||||
var cooked: String
|
||||
var raw: String
|
||||
}
|
||||
|
||||
internal external interface TemplateElement : BaseNode {
|
||||
override var type: String /* "TemplateElement" */
|
||||
var tail: Boolean
|
||||
var value: `T$2`
|
||||
}
|
||||
|
||||
internal external interface AssignmentProperty : Property {
|
||||
override var value: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
override var kind: String /* "init" */
|
||||
override var method: Boolean
|
||||
}
|
||||
|
||||
internal external interface ObjectPattern : BasePattern {
|
||||
override var type: String /* "ObjectPattern" */
|
||||
var properties: Array<dynamic /* AssignmentProperty | RestElement */>
|
||||
}
|
||||
|
||||
internal external interface ArrayPattern : BasePattern {
|
||||
override var type: String /* "ArrayPattern" */
|
||||
var elements: Array<dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */>
|
||||
}
|
||||
|
||||
internal external interface RestElement : BasePattern {
|
||||
override var type: String /* "RestElement" */
|
||||
var argument: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface AssignmentPattern : BasePattern {
|
||||
override var type: String /* "AssignmentPattern" */
|
||||
var left: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface BaseClass : BaseNode {
|
||||
var superClass: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var body: ClassBody
|
||||
}
|
||||
|
||||
internal external interface ClassBody : BaseNode {
|
||||
override var type: String /* "ClassBody" */
|
||||
var body: Array<MethodDefinition>
|
||||
}
|
||||
|
||||
internal external interface MethodDefinition : BaseNode {
|
||||
override var type: String /* "MethodDefinition" */
|
||||
var key: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var value: FunctionExpression
|
||||
var kind: String /* "constructor" | "method" | "get" | "set" */
|
||||
var computed: Boolean
|
||||
var static: Boolean
|
||||
}
|
||||
|
||||
internal external interface ClassDeclaration : BaseClass, BaseDeclaration {
|
||||
override var type: String /* "ClassDeclaration" */
|
||||
var id: Identifier?
|
||||
}
|
||||
|
||||
internal external interface ClassExpression : BaseClass, BaseExpression {
|
||||
override var type: String /* "ClassExpression" */
|
||||
var id: Identifier?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface MetaProperty : BaseExpression {
|
||||
override var type: String /* "MetaProperty" */
|
||||
var meta: Identifier
|
||||
var property: Identifier
|
||||
}
|
||||
|
||||
internal external interface BaseModuleDeclaration : BaseNode
|
||||
|
||||
internal external interface BaseModuleSpecifier : BaseNode {
|
||||
var local: Identifier
|
||||
}
|
||||
|
||||
internal external interface ImportDeclaration : BaseModuleDeclaration {
|
||||
override var type: String /* "ImportDeclaration" */
|
||||
var specifiers: Array<dynamic /* ImportSpecifier | ImportDefaultSpecifier | ImportNamespaceSpecifier */>
|
||||
var source: dynamic /* SimpleLiteral | RegExpLiteral */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ImportSpecifier : BaseModuleSpecifier {
|
||||
override var type: String /* "ImportSpecifier" */
|
||||
var imported: Identifier
|
||||
}
|
||||
|
||||
internal external interface ImportExpression : BaseExpression {
|
||||
override var type: String /* "ImportExpression" */
|
||||
var source: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ImportDefaultSpecifier : BaseModuleSpecifier {
|
||||
override var type: String /* "ImportDefaultSpecifier" */
|
||||
}
|
||||
|
||||
internal external interface ImportNamespaceSpecifier : BaseModuleSpecifier {
|
||||
override var type: String /* "ImportNamespaceSpecifier" */
|
||||
}
|
||||
|
||||
internal external interface ExportNamedDeclaration : BaseModuleDeclaration {
|
||||
override var type: String /* "ExportNamedDeclaration" */
|
||||
var declaration: dynamic /* FunctionDeclaration? | VariableDeclaration? | ClassDeclaration? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var specifiers: Array<ExportSpecifier>
|
||||
var source: dynamic /* SimpleLiteral? | RegExpLiteral? */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ExportSpecifier : BaseModuleSpecifier {
|
||||
override var type: String /* "ExportSpecifier" */
|
||||
var exported: Identifier
|
||||
}
|
||||
|
||||
internal external interface ExportDefaultDeclaration : BaseModuleDeclaration {
|
||||
override var type: String /* "ExportDefaultDeclaration" */
|
||||
var declaration: dynamic /* FunctionDeclaration | VariableDeclaration | ClassDeclaration | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface ExportAllDeclaration : BaseModuleDeclaration {
|
||||
override var type: String /* "ExportAllDeclaration" */
|
||||
var source: dynamic /* SimpleLiteral | RegExpLiteral */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
||||
|
||||
internal external interface AwaitExpression : BaseExpression {
|
||||
override var type: String /* "AwaitExpression" */
|
||||
var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package kscience.kmath.estree.internal.stream
|
||||
|
||||
import kscience.kmath.estree.internal.emitter.Emitter
|
||||
|
||||
internal open external class Stream : Emitter {
|
||||
open fun pipe(dest: Any, options: Any): Any
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package kscience.kmath.estree.internal.tsstdlib
|
||||
|
||||
internal external interface IteratorYieldResult<TYield> {
|
||||
var done: Boolean?
|
||||
get() = definedExternally
|
||||
set(value) = definedExternally
|
||||
var value: TYield
|
||||
}
|
||||
|
||||
internal external interface IteratorReturnResult<TReturn> {
|
||||
var done: Boolean
|
||||
var value: TReturn
|
||||
}
|
||||
|
||||
internal external interface Iterator<T, TReturn, TNext> {
|
||||
fun next(vararg args: Any /* JsTuple<> | JsTuple<TNext> */): dynamic /* IteratorYieldResult<T> | IteratorReturnResult<TReturn> */
|
||||
val `return`: ((value: TReturn) -> dynamic)?
|
||||
val `throw`: ((e: Any) -> dynamic)?
|
||||
}
|
||||
|
||||
internal typealias Iterator__1<T> = Iterator<T, Any, Nothing?>
|
||||
|
||||
internal external interface Iterable<T>
|
||||
|
||||
internal external interface IterableIterator<T> : Iterator__1<T>
|
@ -0,0 +1,82 @@
|
||||
@file:Suppress("UNUSED_TYPEALIAS_PARAMETER", "DEPRECATION")
|
||||
|
||||
package kscience.kmath.estree.internal.tsstdlib
|
||||
|
||||
import kotlin.js.RegExp
|
||||
|
||||
internal typealias RegExpMatchArray = Array<String>
|
||||
|
||||
internal typealias RegExpExecArray = Array<String>
|
||||
|
||||
internal external interface RegExpConstructor {
|
||||
@nativeInvoke
|
||||
operator fun invoke(pattern: RegExp, flags: String = definedExternally): RegExp
|
||||
|
||||
@nativeInvoke
|
||||
operator fun invoke(pattern: RegExp): RegExp
|
||||
|
||||
@nativeInvoke
|
||||
operator fun invoke(pattern: String, flags: String = definedExternally): RegExp
|
||||
|
||||
@nativeInvoke
|
||||
operator fun invoke(pattern: String): RegExp
|
||||
var prototype: RegExp
|
||||
var `$1`: String
|
||||
var `$2`: String
|
||||
var `$3`: String
|
||||
var `$4`: String
|
||||
var `$5`: String
|
||||
var `$6`: String
|
||||
var `$7`: String
|
||||
var `$8`: String
|
||||
var `$9`: String
|
||||
var lastMatch: String
|
||||
}
|
||||
|
||||
internal external interface ConcatArray<T> {
|
||||
var length: Number
|
||||
|
||||
@nativeGetter
|
||||
operator fun get(n: Number): T?
|
||||
|
||||
@nativeSetter
|
||||
operator fun set(n: Number, value: T)
|
||||
fun join(separator: String = definedExternally): String
|
||||
fun slice(start: Number = definedExternally, end: Number = definedExternally): Array<T>
|
||||
}
|
||||
|
||||
internal external interface ArrayConstructor {
|
||||
fun <T> from(iterable: Iterable<T>): Array<T>
|
||||
fun <T> from(iterable: ArrayLike<T>): Array<T>
|
||||
fun <T, U> from(iterable: Iterable<T>, mapfn: (v: T, k: Number) -> U, thisArg: Any = definedExternally): Array<U>
|
||||
fun <T, U> from(iterable: Iterable<T>, mapfn: (v: T, k: Number) -> U): Array<U>
|
||||
fun <T, U> from(iterable: ArrayLike<T>, mapfn: (v: T, k: Number) -> U, thisArg: Any = definedExternally): Array<U>
|
||||
fun <T, U> from(iterable: ArrayLike<T>, mapfn: (v: T, k: Number) -> U): Array<U>
|
||||
fun <T> of(vararg items: T): Array<T>
|
||||
|
||||
@nativeInvoke
|
||||
operator fun invoke(arrayLength: Number = definedExternally): Array<Any>
|
||||
|
||||
@nativeInvoke
|
||||
operator fun invoke(): Array<Any>
|
||||
|
||||
@nativeInvoke
|
||||
operator fun <T> invoke(arrayLength: Number): Array<T>
|
||||
|
||||
@nativeInvoke
|
||||
operator fun <T> invoke(vararg items: T): Array<T>
|
||||
fun isArray(arg: Any): Boolean
|
||||
var prototype: Array<Any>
|
||||
}
|
||||
|
||||
internal external interface ArrayLike<T> {
|
||||
var length: Number
|
||||
|
||||
@nativeGetter
|
||||
operator fun get(n: Number): T?
|
||||
|
||||
@nativeSetter
|
||||
operator fun set(n: Number, value: T)
|
||||
}
|
||||
|
||||
internal typealias Extract<T, U> = Any
|
@ -0,0 +1,115 @@
|
||||
package kscience.kmath.estree
|
||||
|
||||
import kscience.kmath.ast.*
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.ByteRing
|
||||
import kscience.kmath.operations.ComplexField
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.toComplex
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeConsistencyWithInterpreter {
|
||||
@Test
|
||||
fun mstSpace() {
|
||||
val res1 = MstSpace.mstInSpace {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + symbol("x") + zero
|
||||
}("x" to MST.Numeric(2))
|
||||
|
||||
val res2 = MstSpace.mstInSpace {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + symbol("x") + zero
|
||||
}.compile()("x" to MST.Numeric(2))
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun byteRing() {
|
||||
val res1 = ByteRing.mstInRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(symbol("x") - (2.toByte() + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}("x" to 3.toByte())
|
||||
|
||||
val res2 = ByteRing.mstInRing {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(symbol("x") - (2.toByte() + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
number(1)
|
||||
) * number(2)
|
||||
}.compile()("x" to 3.toByte())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun realField() {
|
||||
val res1 = RealField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0)
|
||||
|
||||
val res2 = RealField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0)
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun complexField() {
|
||||
val res1 = ComplexField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0.toComplex())
|
||||
|
||||
val res2 = ComplexField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0.toComplex())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
package kscience.kmath.estree
|
||||
|
||||
import kscience.kmath.ast.mstInExtendedField
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.ast.mstInSpace
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.RealField
|
||||
import kotlin.random.Random
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeOperationsSupport {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-2.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testBinaryOperationInvocation() {
|
||||
val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile()
|
||||
val res = expression("x" to 2.0)
|
||||
assertEquals(-1.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testConstProductInvocation() {
|
||||
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMultipleCalls() {
|
||||
val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile()
|
||||
val r = Random(0)
|
||||
var s = 0.0
|
||||
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
||||
println(s)
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
package kscience.kmath.estree
|
||||
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestESTreeSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = RealField.mstInField { unaryOperationFunction("+")(symbol("x")) }.compile()
|
||||
assertEquals(2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = RealField.mstInField { unaryOperationFunction("-")(symbol("x")) }.compile()
|
||||
assertEquals(-2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = RealField.mstInField { binaryOperationFunction("+")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = RealField.mstInField { unaryOperationFunction("sin")(symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
val expr = RealField.mstInField { binaryOperationFunction("-")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = RealField.mstInField { binaryOperationFunction("/")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(1.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = RealField
|
||||
.mstInField { binaryOperationFunction("pow")(symbol("x"), number(2)) }
|
||||
.compile()
|
||||
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
package kscience.kmath.estree
|
||||
|
||||
import kscience.kmath.ast.mstInRing
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.ByteRing
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
internal class TestESTreeVariables {
|
||||
@Test
|
||||
fun testVariable() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }.compile()
|
||||
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUndefinedVariableFails() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }.compile()
|
||||
assertFailsWith<NoSuchElementException> { expr() }
|
||||
}
|
||||
}
|
@ -1,13 +1,13 @@
|
||||
package kscience.kmath.asm
|
||||
|
||||
import kscience.kmath.asm.internal.AsmBuilder
|
||||
import kscience.kmath.asm.internal.MstType
|
||||
import kscience.kmath.asm.internal.buildAlgebraOperationCall
|
||||
import kscience.kmath.asm.internal.buildName
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MST.*
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
/**
|
||||
* Compiles given MST to an Expression using AST compiler.
|
||||
@ -20,40 +20,54 @@ import kscience.kmath.operations.Algebra
|
||||
@PublishedApi
|
||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||
is MST.Symbolic -> {
|
||||
is Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: Throwable) {
|
||||
} catch (ignored: IllegalStateException) {
|
||||
null
|
||||
}
|
||||
|
||||
if (symbol != null)
|
||||
loadTConstant(symbol)
|
||||
loadObjectConstant(symbol as Any)
|
||||
else
|
||||
loadVariable(node.value)
|
||||
}
|
||||
|
||||
is MST.Numeric -> loadNumeric(node.value)
|
||||
is Numeric -> loadNumberConstant(node.value)
|
||||
|
||||
is MST.Unary -> buildAlgebraOperationCall(
|
||||
context = algebra,
|
||||
name = node.operation,
|
||||
fallbackMethodName = "unaryOperation",
|
||||
parameterTypes = arrayOf(MstType.fromMst(node.value))
|
||||
) { visit(node.value) }
|
||||
is Unary -> when {
|
||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||
|
||||
is MST.Binary -> buildAlgebraOperationCall(
|
||||
context = algebra,
|
||||
name = node.operation,
|
||||
fallbackMethodName = "binaryOperation",
|
||||
parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right))
|
||||
) {
|
||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||
}
|
||||
|
||||
is Binary -> when {
|
||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
||||
algebra.binaryOperationFunction(node.operation)
|
||||
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
||||
algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
|
||||
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
||||
algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
|
||||
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3,29 +3,30 @@ package kscience.kmath.asm.internal
|
||||
import kscience.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import org.objectweb.asm.*
|
||||
import org.objectweb.asm.Opcodes.*
|
||||
import org.objectweb.asm.Type.*
|
||||
import org.objectweb.asm.commons.InstructionAdapter
|
||||
import java.util.*
|
||||
import java.util.stream.Collectors
|
||||
import java.lang.invoke.MethodHandles
|
||||
import java.lang.invoke.MethodType
|
||||
import java.lang.reflect.Modifier
|
||||
import java.util.stream.Collectors.toMap
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
/**
|
||||
* ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
|
||||
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
||||
*
|
||||
* @property T the type of AsmExpression to unwrap.
|
||||
* @property algebra the algebra the applied AsmExpressions use.
|
||||
* @property className the unique class name of new loaded class.
|
||||
* @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0.
|
||||
* @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
internal class AsmBuilder<T> internal constructor(
|
||||
private val classOfT: Class<*>,
|
||||
private val algebra: Algebra<T>,
|
||||
internal class AsmBuilder<T>(
|
||||
classOfT: Class<*>,
|
||||
private val className: String,
|
||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit,
|
||||
private val callbackAtInvokeL0: AsmBuilder<T>.() -> Unit,
|
||||
) {
|
||||
/**
|
||||
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||
@ -39,20 +40,15 @@ internal class AsmBuilder<T> internal constructor(
|
||||
*/
|
||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||
|
||||
/**
|
||||
* ASM Type for [algebra].
|
||||
*/
|
||||
private val tAlgebraType: Type = algebra.javaClass.asm
|
||||
|
||||
/**
|
||||
* ASM type for [T].
|
||||
*/
|
||||
internal val tType: Type = classOfT.asm
|
||||
private val tType: Type = classOfT.asm
|
||||
|
||||
/**
|
||||
* ASM type for new class.
|
||||
*/
|
||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||
private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/'))
|
||||
|
||||
/**
|
||||
* List of constants to provide to the subclass.
|
||||
@ -64,55 +60,14 @@ internal class AsmBuilder<T> internal constructor(
|
||||
*/
|
||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||
|
||||
/**
|
||||
* States whether this [AsmBuilder] needs to generate constants field.
|
||||
*/
|
||||
private var hasConstants: Boolean = true
|
||||
|
||||
/**
|
||||
* States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||
*/
|
||||
internal var primitiveMode: Boolean = false
|
||||
|
||||
/**
|
||||
* Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||
*/
|
||||
internal var primitiveMask: Type = OBJECT_TYPE
|
||||
|
||||
/**
|
||||
* Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode].
|
||||
*/
|
||||
internal var primitiveMaskBoxed: Type = OBJECT_TYPE
|
||||
|
||||
/**
|
||||
* Stack of useful objects types on stack to verify types.
|
||||
*/
|
||||
private val typeStack: ArrayDeque<Type> = ArrayDeque()
|
||||
|
||||
/**
|
||||
* Stack of useful objects types on stack expected by algebra calls.
|
||||
*/
|
||||
internal val expectationStack: ArrayDeque<Type> = ArrayDeque<Type>(1).also { it.push(tType) }
|
||||
|
||||
/**
|
||||
* The cache for instance built by this builder.
|
||||
*/
|
||||
private var generatedInstance: Expression<T>? = null
|
||||
|
||||
/**
|
||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||
*
|
||||
* The built instance is cached.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
internal fun getInstance(): Expression<T> {
|
||||
generatedInstance?.let { return it }
|
||||
|
||||
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
||||
primitiveMode = true
|
||||
primitiveMask = SIGNATURE_LETTERS.getValue(classOfT)
|
||||
primitiveMaskBoxed = tType
|
||||
}
|
||||
val instance: Expression<T> by lazy {
|
||||
val hasConstants: Boolean
|
||||
|
||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||
visit(
|
||||
@ -121,20 +76,20 @@ internal class AsmBuilder<T> internal constructor(
|
||||
classType.internalName,
|
||||
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
||||
OBJECT_TYPE.internalName,
|
||||
arrayOf(EXPRESSION_TYPE.internalName)
|
||||
arrayOf(EXPRESSION_TYPE.internalName),
|
||||
)
|
||||
|
||||
visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL,
|
||||
"invoke",
|
||||
Type.getMethodDescriptor(tType, MAP_TYPE),
|
||||
"(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
||||
null
|
||||
getMethodDescriptor(tType, MAP_TYPE),
|
||||
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}${if (Modifier.isFinal(classOfT.modifiers)) "" else "+"}${tType.descriptor}>;)${tType.descriptor}",
|
||||
null,
|
||||
).instructionAdapter {
|
||||
invokeMethodVisitor = this
|
||||
visitCode()
|
||||
val l0 = label()
|
||||
invokeLabel0Visitor()
|
||||
callbackAtInvokeL0()
|
||||
areturn(tType)
|
||||
val l1 = label()
|
||||
|
||||
@ -144,7 +99,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
null,
|
||||
l0,
|
||||
l1,
|
||||
invokeThisVar
|
||||
0,
|
||||
)
|
||||
|
||||
visitLocalVariable(
|
||||
@ -153,7 +108,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
|
||||
l0,
|
||||
l1,
|
||||
invokeArgumentsVar
|
||||
1,
|
||||
)
|
||||
|
||||
visitMaxs(0, 2)
|
||||
@ -163,17 +118,15 @@ internal class AsmBuilder<T> internal constructor(
|
||||
visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||
"invoke",
|
||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||
null,
|
||||
null,
|
||||
null
|
||||
).instructionAdapter {
|
||||
val thisVar = 0
|
||||
val argumentsVar = 1
|
||||
visitCode()
|
||||
val l0 = label()
|
||||
load(thisVar, OBJECT_TYPE)
|
||||
load(argumentsVar, MAP_TYPE)
|
||||
invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false)
|
||||
load(0, OBJECT_TYPE)
|
||||
load(1, MAP_TYPE)
|
||||
invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false)
|
||||
areturn(tType)
|
||||
val l1 = label()
|
||||
|
||||
@ -183,7 +136,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
null,
|
||||
l0,
|
||||
l1,
|
||||
thisVar
|
||||
0,
|
||||
)
|
||||
|
||||
visitMaxs(0, 2)
|
||||
@ -192,15 +145,6 @@ internal class AsmBuilder<T> internal constructor(
|
||||
|
||||
hasConstants = constants.isNotEmpty()
|
||||
|
||||
visitField(
|
||||
access = ACC_PRIVATE or ACC_FINAL,
|
||||
name = "algebra",
|
||||
descriptor = tAlgebraType.descriptor,
|
||||
signature = null,
|
||||
value = null,
|
||||
block = FieldVisitor::visitEnd
|
||||
)
|
||||
|
||||
if (hasConstants)
|
||||
visitField(
|
||||
access = ACC_PRIVATE or ACC_FINAL,
|
||||
@ -208,55 +152,36 @@ internal class AsmBuilder<T> internal constructor(
|
||||
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
||||
signature = null,
|
||||
value = null,
|
||||
block = FieldVisitor::visitEnd
|
||||
block = FieldVisitor::visitEnd,
|
||||
)
|
||||
|
||||
visitMethod(
|
||||
ACC_PUBLIC,
|
||||
"<init>",
|
||||
|
||||
Type.getMethodDescriptor(
|
||||
Type.VOID_TYPE,
|
||||
tAlgebraType,
|
||||
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
||||
|
||||
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
||||
null,
|
||||
null,
|
||||
null
|
||||
).instructionAdapter {
|
||||
val thisVar = 0
|
||||
val algebraVar = 1
|
||||
val constantsVar = 2
|
||||
val l0 = label()
|
||||
load(thisVar, classType)
|
||||
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
||||
load(0, classType)
|
||||
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
|
||||
label()
|
||||
load(thisVar, classType)
|
||||
load(algebraVar, tAlgebraType)
|
||||
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||
load(0, classType)
|
||||
|
||||
if (hasConstants) {
|
||||
label()
|
||||
load(thisVar, classType)
|
||||
load(constantsVar, OBJECT_ARRAY_TYPE)
|
||||
load(0, classType)
|
||||
load(1, OBJECT_ARRAY_TYPE)
|
||||
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||
}
|
||||
|
||||
label()
|
||||
visitInsn(RETURN)
|
||||
val l4 = label()
|
||||
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
||||
|
||||
visitLocalVariable(
|
||||
"algebra",
|
||||
tAlgebraType.descriptor,
|
||||
null,
|
||||
l0,
|
||||
l4,
|
||||
algebraVar
|
||||
)
|
||||
visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
|
||||
|
||||
if (hasConstants)
|
||||
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
||||
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1)
|
||||
|
||||
visitMaxs(0, 3)
|
||||
visitEnd()
|
||||
@ -265,296 +190,156 @@ internal class AsmBuilder<T> internal constructor(
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
val new = classLoader
|
||||
.defineClass(className, classWriter.toByteArray())
|
||||
.constructors
|
||||
.first()
|
||||
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
|
||||
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
||||
// java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
||||
val l = MethodHandles.publicLookup()
|
||||
|
||||
generatedInstance = new
|
||||
return new
|
||||
if (hasConstants)
|
||||
l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array<Any>::class.java))
|
||||
.invoke(constants.toTypedArray()) as Expression<T>
|
||||
else
|
||||
l.findConstructor(cls, MethodType.methodType(Void.TYPE)).invoke() as Expression<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a [T] constant from [constants].
|
||||
*/
|
||||
internal fun loadTConstant(value: T) {
|
||||
if (classOfT in INLINABLE_NUMBERS) {
|
||||
val expectedType = expectationStack.pop()
|
||||
val mustBeBoxed = expectedType.sort == Type.OBJECT
|
||||
loadNumberConstant(value as Number, mustBeBoxed)
|
||||
|
||||
if (mustBeBoxed)
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
|
||||
if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask)
|
||||
return
|
||||
}
|
||||
|
||||
loadObjectConstant(value as Any, tType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Boxes the current value and pushes it.
|
||||
*/
|
||||
private fun box(primitive: Type) {
|
||||
val r = PRIMITIVES_TO_BOXED.getValue(primitive)
|
||||
|
||||
invokeMethodVisitor.invokestatic(
|
||||
r.internalName,
|
||||
"valueOf",
|
||||
Type.getMethodDescriptor(r, primitive),
|
||||
false
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Unboxes the current boxed value and pushes it.
|
||||
*/
|
||||
private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual(
|
||||
NUMBER_TYPE.internalName,
|
||||
NUMBER_CONVERTER_METHODS.getValue(primitive),
|
||||
Type.getMethodDescriptor(primitive),
|
||||
false
|
||||
)
|
||||
|
||||
/**
|
||||
* Loads [java.lang.Object] constant from constants.
|
||||
*/
|
||||
private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
loadThis()
|
||||
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
|
||||
invokeMethodVisitor.load(0, classType)
|
||||
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||
iconst(idx)
|
||||
visitInsn(AALOAD)
|
||||
checkcast(type)
|
||||
if (type != OBJECT_TYPE) checkcast(type)
|
||||
}
|
||||
|
||||
internal fun loadNumeric(value: Number) {
|
||||
if (expectationStack.peek() == NUMBER_TYPE) {
|
||||
loadNumberConstant(value, true)
|
||||
expectationStack.pop()
|
||||
typeStack.push(NUMBER_TYPE)
|
||||
} else (algebra as? NumericAlgebra<T>)?.number(value)?.let { loadTConstant(it) }
|
||||
?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.")
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads this variable.
|
||||
*/
|
||||
private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType)
|
||||
|
||||
/**
|
||||
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
|
||||
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
|
||||
* from it).
|
||||
* constant from the constant pool.
|
||||
*/
|
||||
private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
|
||||
fun loadNumberConstant(value: Number) {
|
||||
val boxed = value.javaClass.asm
|
||||
val primitive = BOXED_TO_PRIMITIVES[boxed]
|
||||
|
||||
if (primitive != null) {
|
||||
when (primitive) {
|
||||
Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
||||
Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
|
||||
Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
|
||||
Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
||||
FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
|
||||
LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
|
||||
INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||
}
|
||||
|
||||
if (mustBeBoxed)
|
||||
box(primitive)
|
||||
val r = PRIMITIVES_TO_BOXED.getValue(primitive)
|
||||
|
||||
invokeMethodVisitor.invokestatic(
|
||||
r.internalName,
|
||||
"valueOf",
|
||||
getMethodDescriptor(r, primitive),
|
||||
false,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
loadObjectConstant(value, boxed)
|
||||
|
||||
if (!mustBeBoxed)
|
||||
unboxTo(primitiveMask)
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
|
||||
* provided.
|
||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke].
|
||||
*/
|
||||
internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
|
||||
load(invokeArgumentsVar, MAP_TYPE)
|
||||
fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
|
||||
load(1, MAP_TYPE)
|
||||
aconst(name)
|
||||
|
||||
invokestatic(
|
||||
MAP_INTRINSICS_TYPE.internalName,
|
||||
"getOrFail",
|
||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
|
||||
false
|
||||
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
|
||||
false,
|
||||
)
|
||||
|
||||
checkcast(tType)
|
||||
val expectedType = expectationStack.pop()
|
||||
|
||||
if (expectedType.sort == Type.OBJECT)
|
||||
typeStack.push(tType)
|
||||
else {
|
||||
unboxTo(primitiveMask)
|
||||
typeStack.push(primitiveMask)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads algebra from according field of the class and casts it to class of [algebra] provided.
|
||||
*/
|
||||
internal fun loadAlgebra() {
|
||||
loadThis()
|
||||
invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||
}
|
||||
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
|
||||
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||
val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces }
|
||||
|
||||
/**
|
||||
* Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is
|
||||
* [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be
|
||||
* called before the arguments and this operation.
|
||||
*
|
||||
* The result is casted to [T] automatically.
|
||||
*/
|
||||
internal fun invokeAlgebraOperation(
|
||||
owner: String,
|
||||
method: String,
|
||||
descriptor: String,
|
||||
expectedArity: Int,
|
||||
opcode: Int = INVOKEINTERFACE,
|
||||
) {
|
||||
run loop@{
|
||||
repeat(expectedArity) {
|
||||
if (typeStack.isEmpty()) return@loop
|
||||
typeStack.pop()
|
||||
}
|
||||
}
|
||||
val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount
|
||||
?: error("Provided function object doesn't contain invoke method")
|
||||
|
||||
invokeMethodVisitor.visitMethodInsn(
|
||||
opcode,
|
||||
owner,
|
||||
method,
|
||||
descriptor,
|
||||
opcode == INVOKEINTERFACE
|
||||
val type = getType(`interface`)
|
||||
loadObjectConstant(function, type)
|
||||
parameters(this)
|
||||
|
||||
invokeMethodVisitor.invokeinterface(
|
||||
type.internalName,
|
||||
"invoke",
|
||||
getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }),
|
||||
)
|
||||
|
||||
invokeMethodVisitor.checkcast(tType)
|
||||
val isLastExpr = expectationStack.size == 1
|
||||
val expectedType = expectationStack.pop()
|
||||
|
||||
if (expectedType.sort == Type.OBJECT || isLastExpr)
|
||||
typeStack.push(tType)
|
||||
else {
|
||||
unboxTo(primitiveMask)
|
||||
typeStack.push(primitiveMask)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes a LDC Instruction with string constant provided.
|
||||
*/
|
||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
|
||||
|
||||
internal companion object {
|
||||
/**
|
||||
* Index of `this` variable in invoke method of the built subclass.
|
||||
*/
|
||||
private const val invokeThisVar: Int = 0
|
||||
|
||||
/**
|
||||
* Index of `arguments` variable in invoke method of the built subclass.
|
||||
*/
|
||||
private const val invokeArgumentsVar: Int = 1
|
||||
|
||||
/**
|
||||
* Maps JVM primitive numbers boxed types to their primitive ASM types.
|
||||
*/
|
||||
private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
|
||||
hashMapOf(
|
||||
java.lang.Byte::class.java to Type.BYTE_TYPE,
|
||||
java.lang.Short::class.java to Type.SHORT_TYPE,
|
||||
java.lang.Integer::class.java to Type.INT_TYPE,
|
||||
java.lang.Long::class.java to Type.LONG_TYPE,
|
||||
java.lang.Float::class.java to Type.FLOAT_TYPE,
|
||||
java.lang.Double::class.java to Type.DOUBLE_TYPE
|
||||
)
|
||||
}
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||
*/
|
||||
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } }
|
||||
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
|
||||
hashMapOf(
|
||||
Byte::class.java.asm to BYTE_TYPE,
|
||||
Short::class.java.asm to SHORT_TYPE,
|
||||
Integer::class.java.asm to INT_TYPE,
|
||||
Long::class.java.asm to LONG_TYPE,
|
||||
Float::class.java.asm to FLOAT_TYPE,
|
||||
Double::class.java.asm to DOUBLE_TYPE,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps JVM primitive numbers boxed ASM types to their primitive ASM types.
|
||||
*/
|
||||
private val PRIMITIVES_TO_BOXED: Map<Type, Type> by lazy {
|
||||
BOXED_TO_PRIMITIVES.entries.stream().collect(
|
||||
Collectors.toMap(
|
||||
Map.Entry<Type, Type>::value,
|
||||
Map.Entry<Type, Type>::key
|
||||
)
|
||||
toMap(Map.Entry<Type, Type>::value, Map.Entry<Type, Type>::key),
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps primitive ASM types to [Number] functions unboxing them.
|
||||
*/
|
||||
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
|
||||
hashMapOf(
|
||||
Type.BYTE_TYPE to "byteValue",
|
||||
Type.SHORT_TYPE to "shortValue",
|
||||
Type.INT_TYPE to "intValue",
|
||||
Type.LONG_TYPE to "longValue",
|
||||
Type.FLOAT_TYPE to "floatValue",
|
||||
Type.DOUBLE_TYPE to "doubleValue"
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Provides boxed number types values of which can be stored in JVM bytecode constant pool.
|
||||
*/
|
||||
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||
|
||||
/**
|
||||
* ASM type for [Expression].
|
||||
*/
|
||||
internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") }
|
||||
|
||||
/**
|
||||
* ASM type for [java.lang.Number].
|
||||
*/
|
||||
internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") }
|
||||
val EXPRESSION_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Expression") }
|
||||
|
||||
/**
|
||||
* ASM type for [java.util.Map].
|
||||
*/
|
||||
internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") }
|
||||
val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") }
|
||||
|
||||
/**
|
||||
* ASM type for [java.lang.Object].
|
||||
*/
|
||||
internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") }
|
||||
val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") }
|
||||
|
||||
/**
|
||||
* ASM type for array of [java.lang.Object].
|
||||
*/
|
||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
|
||||
internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") }
|
||||
|
||||
/**
|
||||
* ASM type for [Algebra].
|
||||
*/
|
||||
internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") }
|
||||
val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") }
|
||||
|
||||
/**
|
||||
* ASM type for [java.lang.String].
|
||||
*/
|
||||
internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") }
|
||||
val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") }
|
||||
|
||||
/**
|
||||
* ASM type for MapIntrinsics.
|
||||
*/
|
||||
internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/asm/internal/MapIntrinsics") }
|
||||
val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("kscience/kmath/asm/internal/MapIntrinsics") }
|
||||
|
||||
/**
|
||||
* ASM Type for [kscience.kmath.expressions.Symbol].
|
||||
*/
|
||||
val SYMBOL_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Symbol") }
|
||||
}
|
||||
}
|
||||
|
@ -1,20 +0,0 @@
|
||||
package kscience.kmath.asm.internal
|
||||
|
||||
import kscience.kmath.ast.MST
|
||||
|
||||
/**
|
||||
* Represents types known in [MST], numbers and general values.
|
||||
*/
|
||||
internal enum class MstType {
|
||||
GENERAL,
|
||||
NUMBER;
|
||||
|
||||
companion object {
|
||||
fun fromMst(mst: MST): MstType {
|
||||
if (mst is MST.Numeric)
|
||||
return NUMBER
|
||||
|
||||
return GENERAL
|
||||
}
|
||||
}
|
||||
}
|
@ -2,29 +2,11 @@ package kscience.kmath.asm.internal
|
||||
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.FieldOperations
|
||||
import kscience.kmath.operations.RingOperations
|
||||
import kscience.kmath.operations.SpaceOperations
|
||||
import org.objectweb.asm.*
|
||||
import org.objectweb.asm.Opcodes.INVOKEVIRTUAL
|
||||
import org.objectweb.asm.commons.InstructionAdapter
|
||||
import java.lang.reflect.Method
|
||||
import java.util.*
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
||||
hashMapOf(
|
||||
SpaceOperations.PLUS_OPERATION to 2 to "add",
|
||||
RingOperations.TIMES_OPERATION to 2 to "multiply",
|
||||
FieldOperations.DIV_OPERATION to 2 to "divide",
|
||||
SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus",
|
||||
SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus",
|
||||
SpaceOperations.MINUS_OPERATION to 2 to "minus"
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns ASM [Type] for given [Class].
|
||||
*
|
||||
@ -109,107 +91,3 @@ internal inline fun ClassWriter.visitField(
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return visitField(access, name, descriptor, signature, value).apply(block)
|
||||
}
|
||||
|
||||
private fun <T> AsmBuilder<T>.findSpecific(context: Algebra<T>, name: String, parameterTypes: Array<MstType>): Method? =
|
||||
context.javaClass.methods.find { method ->
|
||||
val nameValid = method.name == name
|
||||
val arityValid = method.parameters.size == parameterTypes.size
|
||||
val notBridgeInPrimitive = !(primitiveMode && method.isBridge)
|
||||
|
||||
val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) ->
|
||||
!(mstType != MstType.NUMBER && type == java.lang.Number::class.java)
|
||||
}
|
||||
|
||||
nameValid && arityValid && notBridgeInPrimitive && paramsValid
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds
|
||||
* type expectation stack for needed arity.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
private fun <T> AsmBuilder<T>.buildExpectationStack(
|
||||
context: Algebra<T>,
|
||||
name: String,
|
||||
parameterTypes: Array<MstType>
|
||||
): Boolean {
|
||||
val arity = parameterTypes.size
|
||||
val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes)
|
||||
|
||||
if (specific != null)
|
||||
mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) }
|
||||
else
|
||||
expectationStack.addAll(Collections.nCopies(arity, tType))
|
||||
|
||||
return specific != null
|
||||
}
|
||||
|
||||
private fun <T> AsmBuilder<T>.mapTypes(method: Method, parameterTypes: Array<MstType>): List<Type> = method
|
||||
.parameterTypes
|
||||
.zip(parameterTypes)
|
||||
.map { (type, mstType) ->
|
||||
when {
|
||||
type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE
|
||||
else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts
|
||||
* [AsmBuilder.invokeAlgebraOperation] of this method.
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
private fun <T> AsmBuilder<T>.tryInvokeSpecific(
|
||||
context: Algebra<T>,
|
||||
name: String,
|
||||
parameterTypes: Array<MstType>
|
||||
): Boolean {
|
||||
val arity = parameterTypes.size
|
||||
val theName = methodNameAdapters[name to arity] ?: name
|
||||
val spec = findSpecific(context, theName, parameterTypes) ?: return false
|
||||
val owner = context.javaClass.asm
|
||||
|
||||
invokeAlgebraOperation(
|
||||
owner = owner.internalName,
|
||||
method = theName,
|
||||
descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()),
|
||||
expectedArity = arity,
|
||||
opcode = INVOKEVIRTUAL
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds specialized [context] call with option to fallback to generic algebra operation accepting [String].
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||
context: Algebra<T>,
|
||||
name: String,
|
||||
fallbackMethodName: String,
|
||||
parameterTypes: Array<MstType>,
|
||||
parameters: AsmBuilder<T>.() -> Unit
|
||||
) {
|
||||
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||
val arity = parameterTypes.size
|
||||
loadAlgebra()
|
||||
if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name)
|
||||
parameters()
|
||||
|
||||
if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
|
||||
method = fallbackMethodName,
|
||||
|
||||
descriptor = Type.getMethodDescriptor(
|
||||
AsmBuilder.OBJECT_TYPE,
|
||||
AsmBuilder.STRING_TYPE,
|
||||
*Array(arity) { AsmBuilder.OBJECT_TYPE }
|
||||
),
|
||||
|
||||
expectedArity = arity
|
||||
)
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
// TODO move to common when https://github.com/h0tk3y/better-parse/pull/33 is merged
|
||||
|
||||
package kscience.kmath.ast
|
||||
|
||||
import com.github.h0tk3y.betterParse.combinators.*
|
||||
@ -17,7 +19,8 @@ import kscience.kmath.operations.RingOperations
|
||||
import kscience.kmath.operations.SpaceOperations
|
||||
|
||||
/**
|
||||
* TODO move to common after IR version is released
|
||||
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
|
||||
*
|
||||
* @author Alexander Nozik and Iaroslav Postovalov
|
||||
*/
|
||||
public object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
@ -83,7 +86,7 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Tries to parse the string into [MST]. Returns [ParseResult] representing expression or error.
|
||||
* Tries to parse the string into [MST] using [ArithmeticsEvaluator]. Returns [ParseResult] representing expression or error.
|
||||
*
|
||||
* @receiver the string to parse.
|
||||
* @return the [MST] node.
|
||||
@ -91,7 +94,7 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
public fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
|
||||
|
||||
/**
|
||||
* Parses the string into [MST].
|
||||
* Parses the string into [MST] using [ArithmeticsEvaluator].
|
||||
*
|
||||
* @receiver the string to parse.
|
||||
* @return the [MST] node.
|
||||
|
@ -1,24 +1,20 @@
|
||||
package kscience.kmath.asm
|
||||
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.ast.mstInRing
|
||||
import kscience.kmath.ast.mstInSpace
|
||||
import kscience.kmath.ast.*
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.ByteRing
|
||||
import kscience.kmath.operations.ComplexField
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.toComplex
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmAlgebras {
|
||||
|
||||
internal class TestAsmConsistencyWithInterpreter {
|
||||
@Test
|
||||
fun space() {
|
||||
val res1 = ByteRing.mstInSpace {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
fun mstSpace() {
|
||||
val res1 = MstSpace.mstInSpace {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
@ -27,14 +23,11 @@ internal class TestAsmAlgebras {
|
||||
|
||||
number(1)
|
||||
) + symbol("x") + zero
|
||||
}("x" to 2.toByte())
|
||||
}("x" to MST.Numeric(2))
|
||||
|
||||
val res2 = ByteRing.mstInSpace {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
val res2 = MstSpace.mstInSpace {
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
@ -43,19 +36,16 @@ internal class TestAsmAlgebras {
|
||||
|
||||
number(1)
|
||||
) + symbol("x") + zero
|
||||
}.compile()("x" to 2.toByte())
|
||||
}.compile()("x" to MST.Numeric(2))
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun ring() {
|
||||
fun byteRing() {
|
||||
val res1 = ByteRing.mstInRing {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(symbol("x") - (2.toByte() + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
@ -67,17 +57,13 @@ internal class TestAsmAlgebras {
|
||||
}("x" to 3.toByte())
|
||||
|
||||
val res2 = ByteRing.mstInRing {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
binaryOperationFunction("+")(
|
||||
unaryOperationFunction("+")(
|
||||
(symbol("x") - (2.toByte() + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||
),
|
||||
|
||||
number(1)
|
||||
) * number(2)
|
||||
}.compile()("x" to 3.toByte())
|
||||
@ -86,10 +72,9 @@ internal class TestAsmAlgebras {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun field() {
|
||||
fun realField() {
|
||||
val res1 = RealField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||
"+",
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
@ -97,8 +82,7 @@ internal class TestAsmAlgebras {
|
||||
}("x" to 2.0)
|
||||
|
||||
val res2 = RealField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
|
||||
"+",
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
@ -107,4 +91,25 @@ internal class TestAsmAlgebras {
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun complexField() {
|
||||
val res1 = ComplexField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}("x" to 2.0.toComplex())
|
||||
|
||||
val res2 = ComplexField.mstInField {
|
||||
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
|
||||
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
|
||||
+ number(1),
|
||||
number(1) / 2 + number(2.0) * one
|
||||
) + zero
|
||||
}.compile()("x" to 2.0.toComplex())
|
||||
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
}
|
@ -1,14 +1,15 @@
|
||||
package kscience.kmath.asm
|
||||
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.ast.mstInExtendedField
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.ast.mstInSpace
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.RealField
|
||||
import kotlin.random.Random
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmExpressions {
|
||||
internal class TestAsmOperationsSupport {
|
||||
@Test
|
||||
fun testUnaryOperationInvocation() {
|
||||
val expression = RealField.mstInSpace { -symbol("x") }.compile()
|
||||
@ -28,4 +29,13 @@ internal class TestAsmExpressions {
|
||||
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)
|
||||
assertEquals(4.0, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMultipleCalls() {
|
||||
val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile()
|
||||
val r = Random(0)
|
||||
var s = 0.0
|
||||
repeat(1000000) { s += e("x" to r.nextDouble()) }
|
||||
println(s)
|
||||
}
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
package kscience.kmath.asm
|
||||
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.RealField
|
||||
@ -10,44 +9,44 @@ import kotlin.test.assertEquals
|
||||
internal class TestAsmSpecialization {
|
||||
@Test
|
||||
fun testUnaryPlus() {
|
||||
val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { unaryOperationFunction("+")(symbol("x")) }.compile()
|
||||
assertEquals(2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testUnaryMinus() {
|
||||
val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { unaryOperationFunction("-")(symbol("x")) }.compile()
|
||||
assertEquals(-2.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { binaryOperationFunction("+")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSine() {
|
||||
val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { unaryOperationFunction("sin")(symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { binaryOperationFunction("-")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(0.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivide() {
|
||||
val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile()
|
||||
val expr = RealField.mstInField { binaryOperationFunction("/")(symbol("x"), symbol("x")) }.compile()
|
||||
assertEquals(1.0, expr("x" to 2.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
val expr = RealField
|
||||
.mstInField { binaryOperation("power", symbol("x"), number(2)) }
|
||||
.mstInField { binaryOperationFunction("pow")(symbol("x"), number(2)) }
|
||||
.compile()
|
||||
|
||||
assertEquals(4.0, expr("x" to 2.0))
|
||||
|
@ -9,14 +9,14 @@ import kotlin.test.assertFailsWith
|
||||
|
||||
internal class TestAsmVariables {
|
||||
@Test
|
||||
fun testVariableWithoutDefault() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }
|
||||
fun testVariable() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }.compile()
|
||||
assertEquals(1.toByte(), expr("x" to 1.toByte()))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testVariableWithoutDefaultFails() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }
|
||||
assertFailsWith<IllegalStateException> { expr() }
|
||||
fun testUndefinedVariableFails() {
|
||||
val expr = ByteRing.mstInRing { symbol("x") }.compile()
|
||||
assertFailsWith<NoSuchElementException> { expr() }
|
||||
}
|
||||
}
|
||||
|
@ -1,25 +0,0 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.asm.expression
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.ast.parseMath
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.Complex
|
||||
import kscience.kmath.operations.ComplexField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class AsmTest {
|
||||
@Test
|
||||
fun `compile MST`() {
|
||||
val res = ComplexField.expression("2+2*(2+2)".parseMath())()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compile MSTExpression`() {
|
||||
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()()
|
||||
assertEquals(Complex(10.0, 0.0), res)
|
||||
}
|
||||
}
|
@ -1,7 +1,5 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.ast.evaluate
|
||||
import kscience.kmath.ast.parseMath
|
||||
import kscience.kmath.operations.Field
|
||||
import kscience.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
|
@ -1,8 +1,5 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.ast.evaluate
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.ast.parseMath
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.Complex
|
||||
@ -45,10 +42,13 @@ internal class ParserTest {
|
||||
val magicalAlgebra = object : Algebra<String> {
|
||||
override fun symbol(value: String): String = value
|
||||
|
||||
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
|
||||
override fun unaryOperationFunction(operation: String): (arg: String) -> String {
|
||||
throw NotImplementedError()
|
||||
}
|
||||
|
||||
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
|
||||
"magic" -> "$left ★ $right"
|
||||
override fun binaryOperationFunction(operation: String): (left: String, right: String) -> String =
|
||||
when (operation) {
|
||||
"magic" -> { left, right -> "$left ★ $right" }
|
||||
else -> throw NotImplementedError()
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package kscience.kmath.commons.expressions
|
||||
|
||||
import kscience.kmath.expressions.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.ExtendedField
|
||||
import kscience.kmath.operations.RingWithNumbers
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
|
||||
/**
|
||||
@ -10,15 +12,18 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
* @property order The derivation order.
|
||||
* @property bindings The map of bindings values. All bindings are considered free parameters
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class DerivativeStructureField(
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, Double>,
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>, RingWithNumbers<DerivativeStructure> {
|
||||
public val numberOfVariables: Int = bindings.size
|
||||
|
||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
|
||||
|
||||
override fun number(value: Number): DerivativeStructure = const(value.toDouble())
|
||||
|
||||
/**
|
||||
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||
*/
|
||||
|
@ -1,41 +1,28 @@
|
||||
package kscience.kmath.commons.linear
|
||||
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.linear.DiagonalFeature
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.linear.origin
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.NDStructure
|
||||
import org.apache.commons.math3.linear.*
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.cast
|
||||
|
||||
public class CMMatrix(public val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||
public inline class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
|
||||
public override val rowNum: Int get() = origin.rowDimension
|
||||
public override val colNum: Int get() = origin.columnDimension
|
||||
|
||||
public override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
||||
if (origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||
}.toHashSet()
|
||||
|
||||
public override fun suggestFeature(vararg features: MatrixFeature): CMMatrix =
|
||||
CMMatrix(origin, this.features + features)
|
||||
@UnstableKMathAPI
|
||||
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
|
||||
DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null
|
||||
else -> null
|
||||
}?.let { type.cast(it) }
|
||||
|
||||
public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
||||
|
||||
public override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
var result = origin.hashCode()
|
||||
result = 31 * result + features.hashCode()
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
||||
this
|
||||
} else {
|
||||
//TODO add feature analysis
|
||||
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
|
||||
CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
|
||||
|
||||
@ -60,6 +47,16 @@ public object CMMatrixContext : MatrixContext<Double, CMMatrix> {
|
||||
return CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun Matrix<Double>.toCM(): CMMatrix = when (val matrix = origin) {
|
||||
is CMMatrix -> matrix
|
||||
else -> {
|
||||
//TODO add feature analysis
|
||||
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
|
||||
CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
}
|
||||
|
||||
public override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
|
||||
CMMatrix(toCM().origin.multiply(other.toCM().origin))
|
||||
|
||||
|
@ -26,7 +26,7 @@ The core features of KMath:
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
|
||||
>
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
|
@ -25,34 +25,34 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||
/**
|
||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||
*/
|
||||
public override fun binaryOperation(
|
||||
operation: String,
|
||||
left: Expression<T>,
|
||||
right: Expression<T>,
|
||||
): Expression<T> = Expression { arguments ->
|
||||
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
|
||||
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
{ left, right ->
|
||||
Expression { arguments ->
|
||||
algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||
*/
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
|
||||
algebra.unaryOperation(operation, arg.invoke(arguments))
|
||||
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
||||
Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [Expression] construction for [Space] algebras.
|
||||
*/
|
||||
public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||
public open class FunctionalExpressionSpace<T, A : Space<T>>(
|
||||
algebra: A,
|
||||
) : FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||
public override val zero: Expression<T> get() = const(algebra.zero)
|
||||
|
||||
/**
|
||||
* Builds an Expression of addition of two another expressions.
|
||||
*/
|
||||
public override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b)
|
||||
|
||||
/**
|
||||
* Builds an Expression of multiplication of expression by number.
|
||||
@ -66,15 +66,16 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
public operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||
public operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.unaryOperationFunction(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionAlgebra>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||
public open class FunctionalExpressionRing<T, A : Ring<T>>(
|
||||
algebra: A,
|
||||
) : FunctionalExpressionSpace<T, A>(algebra), Ring<Expression<T>> {
|
||||
public override val one: Expression<T>
|
||||
get() = const(algebra.one)
|
||||
|
||||
@ -82,68 +83,72 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
||||
* Builds an Expression of multiplication of two expressions.
|
||||
*/
|
||||
public override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
||||
|
||||
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionSpace>.unaryOperationFunction(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionSpace>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
|
||||
where A : Field<T>, A : NumericAlgebra<T> {
|
||||
public open class FunctionalExpressionField<T, A : Field<T>>(
|
||||
algebra: A,
|
||||
) : FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>> {
|
||||
/**
|
||||
* Builds an Expression of division an expression by another one.
|
||||
*/
|
||||
public override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
||||
|
||||
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionRing>.unaryOperationFunction(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionRing>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||
FunctionalExpressionField<T, A>(algebra),
|
||||
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||
public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>(
|
||||
algebra: A,
|
||||
) : FunctionalExpressionField<T, A>(algebra), ExtendedField<Expression<T>> {
|
||||
|
||||
override fun number(value: Number): Expression<T> = const(algebra.number(value))
|
||||
|
||||
public override fun sin(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||
|
||||
public override fun cos(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg)
|
||||
|
||||
public override fun asin(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||
unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg)
|
||||
|
||||
public override fun acos(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||
unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg)
|
||||
|
||||
public override fun atan(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||
unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg)
|
||||
|
||||
public override fun power(arg: Expression<T>, pow: Number): Expression<T> =
|
||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||
|
||||
public override fun exp(arg: Expression<T>): Expression<T> =
|
||||
unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg)
|
||||
|
||||
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
public override fun ln(arg: Expression<T>): Expression<T> =
|
||||
unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg)
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionField>.unaryOperation(operation, arg)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionField>.unaryOperationFunction(operation)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
super<FunctionalExpressionField>.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionField>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public inline fun <T, A : Space<T>> A.expressionInSpace(block: FunctionalExpressionSpace<T, A>.() -> Expression<T>): Expression<T> =
|
||||
|
@ -1,6 +1,7 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
@ -79,10 +80,11 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
/**
|
||||
* Represents field in context of which functions can be derived.
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
public val context: F,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, RingWithNumbers<AutoDiffValue<T>> {
|
||||
public override val zero: AutoDiffValue<T>
|
||||
get() = const(context.zero)
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
@ -21,30 +20,11 @@ public class BufferMatrixContext<T : Any, R : Ring<T>>(
|
||||
public companion object
|
||||
}
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
public object RealMatrixContext : GenericMatrixContext<Double, RealField, BufferMatrix<Double>> {
|
||||
public override val elementContext: RealField
|
||||
get() = RealField
|
||||
|
||||
public override inline fun produce(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (i: Int, j: Int) -> Double,
|
||||
): BufferMatrix<Double> {
|
||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
|
||||
public override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
|
||||
RealBuffer(size, initializer)
|
||||
}
|
||||
|
||||
public class BufferMatrix<T : Any>(
|
||||
public override val rowNum: Int,
|
||||
public override val colNum: Int,
|
||||
public val buffer: Buffer<out T>,
|
||||
public override val features: Set<MatrixFeature> = emptySet(),
|
||||
) : FeaturedMatrix<T> {
|
||||
) : Matrix<T> {
|
||||
|
||||
init {
|
||||
require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" }
|
||||
@ -52,9 +32,6 @@ public class BufferMatrix<T : Any>(
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
||||
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
||||
|
||||
public override operator fun get(index: IntArray): T = get(index[0], index[1])
|
||||
public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
||||
|
||||
@ -66,23 +43,26 @@ public class BufferMatrix<T : Any>(
|
||||
if (this === other) return true
|
||||
|
||||
return when (other) {
|
||||
is NDStructure<*> -> return NDStructure.equals(this, other)
|
||||
is NDStructure<*> -> NDStructure.contentEquals(this, other)
|
||||
else -> false
|
||||
}
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
var result = buffer.hashCode()
|
||||
result = 31 * result + features.hashCode()
|
||||
override fun hashCode(): Int {
|
||||
var result = rowNum
|
||||
result = 31 * result + colNum
|
||||
result = 31 * result + buffer.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
public override fun toString(): String {
|
||||
return if (rowNum <= 5 && colNum <= 5)
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum)\n" +
|
||||
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
|
||||
buffer.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||
}
|
||||
else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
||||
else "Matrix(rowsNum = $rowNum, colNum = $colNum)"
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -1,83 +0,0 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.Structure2D
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* A 2d structure plus optional matrix-specific features
|
||||
*/
|
||||
public interface FeaturedMatrix<T : Any> : Matrix<T> {
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
public val features: Set<MatrixFeature>
|
||||
|
||||
/**
|
||||
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure.
|
||||
*
|
||||
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
|
||||
* add only those features that are valid.
|
||||
*/
|
||||
public fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
||||
|
||||
public companion object
|
||||
}
|
||||
|
||||
public inline fun Structure2D.Companion.real(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (Int, Int) -> Double,
|
||||
): BufferMatrix<Double> = MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
public fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
public val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
|
||||
|
||||
/**
|
||||
* Check if matrix has the given feature class
|
||||
*/
|
||||
public inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
|
||||
features.find { it is T } != null
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
public inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
|
||||
features.filterIsInstance<T>().firstOrNull()
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero }
|
||||
|
||||
public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
/**
|
||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||
*/
|
||||
public fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
|
||||
return getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||
colNum,
|
||||
rowNum,
|
||||
setOf(TransposedFeature(this))
|
||||
) { i, j -> get(j, i) }
|
||||
}
|
@ -1,31 +1,31 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
/**
|
||||
* Common implementation of [LUPDecompositionFeature]
|
||||
* Common implementation of [LupDecompositionFeature].
|
||||
*/
|
||||
public class LUPDecomposition<T : Any>(
|
||||
public val context: MatrixContext<T, FeaturedMatrix<T>>,
|
||||
public class LupDecomposition<T : Any>(
|
||||
public val context: MatrixContext<T, Matrix<T>>,
|
||||
public val elementContext: Field<T>,
|
||||
public val lu: Structure2D<T>,
|
||||
public val lu: Matrix<T>,
|
||||
public val pivot: IntArray,
|
||||
private val even: Boolean,
|
||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||
|
||||
) : LupDecompositionFeature<T>, DeterminantFeature<T> {
|
||||
/**
|
||||
* Returns the matrix L of the decomposition.
|
||||
*
|
||||
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
||||
*/
|
||||
override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
||||
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
when {
|
||||
j < i -> lu[i, j]
|
||||
j == i -> elementContext.one
|
||||
else -> elementContext.zero
|
||||
}
|
||||
}
|
||||
} + LFeature
|
||||
|
||||
|
||||
/**
|
||||
@ -33,9 +33,9 @@ public class LUPDecomposition<T : Any>(
|
||||
*
|
||||
* U is an upper-triangular matrix including the diagonal
|
||||
*/
|
||||
override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
||||
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j >= i) lu[i, j] else elementContext.zero
|
||||
}
|
||||
} + UFeature
|
||||
|
||||
/**
|
||||
* Returns the P rows permutation matrix.
|
||||
@ -43,7 +43,7 @@ public class LUPDecomposition<T : Any>(
|
||||
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
||||
* each row and each column, all other elements being set to [Ring.zero].
|
||||
*/
|
||||
override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j == pivot[i]) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
@ -64,12 +64,12 @@ internal fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, *>.abs
|
||||
/**
|
||||
* Create a lup decomposition of generic matrix.
|
||||
*/
|
||||
public fun <T : Comparable<T>> MatrixContext<T, FeaturedMatrix<T>>.lup(
|
||||
public fun <T : Comparable<T>> MatrixContext<T, Matrix<T>>.lup(
|
||||
factory: MutableBufferFactory<T>,
|
||||
elementContext: Field<T>,
|
||||
matrix: Matrix<T>,
|
||||
checkSingular: (T) -> Boolean,
|
||||
): LUPDecomposition<T> {
|
||||
): LupDecomposition<T> {
|
||||
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
|
||||
val m = matrix.colNum
|
||||
val pivot = IntArray(matrix.rowNum)
|
||||
@ -138,20 +138,23 @@ public fun <T : Comparable<T>> MatrixContext<T, FeaturedMatrix<T>>.lup(
|
||||
for (row in col + 1 until m) lu[row, col] /= luDiag
|
||||
}
|
||||
|
||||
return LUPDecomposition(this@lup, elementContext, lu.collect(), pivot, even)
|
||||
return LupDecomposition(this@lup, elementContext, lu.collect(), pivot, even)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.lup(
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.lup(
|
||||
matrix: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean,
|
||||
): LUPDecomposition<T> = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular)
|
||||
): LupDecomposition<T> = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular)
|
||||
|
||||
public fun MatrixContext<Double, FeaturedMatrix<Double>>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
|
||||
public fun MatrixContext<Double, Matrix<Double>>.lup(matrix: Matrix<Double>): LupDecomposition<Double> =
|
||||
lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 }
|
||||
|
||||
public fun <T : Any> LUPDecomposition<T>.solveWithLUP(factory: MutableBufferFactory<T>, matrix: Matrix<T>): FeaturedMatrix<T> {
|
||||
public fun <T : Any> LupDecomposition<T>.solveWithLUP(
|
||||
factory: MutableBufferFactory<T>,
|
||||
matrix: Matrix<T>,
|
||||
): Matrix<T> {
|
||||
require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" }
|
||||
|
||||
BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run {
|
||||
@ -196,34 +199,41 @@ public fun <T : Any> LUPDecomposition<T>.solveWithLUP(factory: MutableBufferFact
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <reified T : Any> LUPDecomposition<T>.solveWithLUP(matrix: Matrix<T>): Matrix<T> =
|
||||
public inline fun <reified T : Any> LupDecomposition<T>.solveWithLUP(matrix: Matrix<T>): Matrix<T> =
|
||||
solveWithLUP(MutableBuffer.Companion::auto, matrix)
|
||||
|
||||
/**
|
||||
* Solve a linear equation **a*x = b** using LUP decomposition
|
||||
*/
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.solveWithLUP(
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.solveWithLUP(
|
||||
a: Matrix<T>,
|
||||
b: Matrix<T>,
|
||||
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
|
||||
noinline checkSingular: (T) -> Boolean,
|
||||
): FeaturedMatrix<T> {
|
||||
): Matrix<T> {
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular)
|
||||
return decomposition.solveWithLUP(bufferFactory, b)
|
||||
}
|
||||
|
||||
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): FeaturedMatrix<Double> =
|
||||
solveWithLUP(a, b) { it < 1e-11 }
|
||||
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.inverseWithLUP(
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.inverseWithLUP(
|
||||
matrix: Matrix<T>,
|
||||
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
|
||||
noinline checkSingular: (T) -> Boolean,
|
||||
): FeaturedMatrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
|
||||
): Matrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
|
||||
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
|
||||
val decomposition: LupDecomposition<Double> = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 }
|
||||
return decomposition.solveWithLUP(bufferFactory, b)
|
||||
}
|
||||
|
||||
/**
|
||||
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
|
||||
*/
|
||||
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): FeaturedMatrix<Double> =
|
||||
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), Buffer.Companion::real) { it < 1e-11 }
|
||||
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): Matrix<Double> =
|
||||
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum))
|
@ -1,12 +1,9 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.BufferFactory
|
||||
import kscience.kmath.structures.Structure2D
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
public class MatrixBuilder(public val rows: Int, public val columns: Int) {
|
||||
public operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
|
||||
public operator fun <T : Any> invoke(vararg elements: T): Matrix<T> {
|
||||
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
@ -17,7 +14,7 @@ public class MatrixBuilder(public val rows: Int, public val columns: Int) {
|
||||
|
||||
public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns)
|
||||
|
||||
public fun <T : Any> Structure2D.Companion.row(vararg values: T): FeaturedMatrix<T> {
|
||||
public fun <T : Any> Structure2D.Companion.row(vararg values: T): Matrix<T> {
|
||||
val buffer = values.asBuffer()
|
||||
return BufferMatrix(1, values.size, buffer)
|
||||
}
|
||||
@ -26,12 +23,12 @@ public inline fun <reified T : Any> Structure2D.Companion.row(
|
||||
size: Int,
|
||||
factory: BufferFactory<T> = Buffer.Companion::auto,
|
||||
noinline builder: (Int) -> T
|
||||
): FeaturedMatrix<T> {
|
||||
): Matrix<T> {
|
||||
val buffer = factory(size, builder)
|
||||
return BufferMatrix(1, size, buffer)
|
||||
}
|
||||
|
||||
public fun <T : Any> Structure2D.Companion.column(vararg values: T): FeaturedMatrix<T> {
|
||||
public fun <T : Any> Structure2D.Companion.column(vararg values: T): Matrix<T> {
|
||||
val buffer = values.asBuffer()
|
||||
return BufferMatrix(values.size, 1, buffer)
|
||||
}
|
||||
@ -40,7 +37,7 @@ public inline fun <reified T : Any> Structure2D.Companion.column(
|
||||
size: Int,
|
||||
factory: BufferFactory<T> = Buffer.Companion::auto,
|
||||
noinline builder: (Int) -> T
|
||||
): FeaturedMatrix<T> {
|
||||
): Matrix<T> {
|
||||
val buffer = factory(size, builder)
|
||||
return BufferMatrix(size, 1, buffer)
|
||||
}
|
||||
|
@ -18,10 +18,16 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
|
||||
*/
|
||||
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
|
||||
|
||||
/**
|
||||
* Produce a point compatible with matrix space (and possibly optimized for it)
|
||||
*/
|
||||
public fun point(size: Int, initializer: (Int) -> T): Point<T> = Buffer.boxing(size, initializer)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): M = when (operation) {
|
||||
"dot" -> left dot right
|
||||
else -> super.binaryOperation(operation, left, right) as M
|
||||
public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
|
||||
when (operation) {
|
||||
"dot" -> { left, right -> left dot right }
|
||||
else -> super.binaryOperationFunction(operation) as (Matrix<T>, Matrix<T>) -> M
|
||||
}
|
||||
|
||||
/**
|
||||
@ -61,10 +67,6 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
|
||||
public operator fun T.times(m: Matrix<T>): M = m * this
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* Non-boxing double matrix
|
||||
*/
|
||||
public val real: RealMatrixContext = RealMatrixContext
|
||||
|
||||
/**
|
||||
* A structured matrix with custom buffer
|
||||
@ -88,11 +90,6 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> :
|
||||
*/
|
||||
public val elementContext: R
|
||||
|
||||
/**
|
||||
* Produce a point compatible with matrix space
|
||||
*/
|
||||
public fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||
|
||||
public override infix fun Matrix<T>.dot(other: Matrix<T>): M {
|
||||
//TODO add typed error
|
||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||
@ -136,8 +133,6 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> :
|
||||
public override fun multiply(a: Matrix<T>, k: Number): M =
|
||||
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
|
||||
|
||||
public operator fun Number.times(matrix: FeaturedMatrix<T>): M = multiply(matrix, this)
|
||||
|
||||
public override operator fun Matrix<T>.times(value: T): M =
|
||||
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
|
||||
}
|
||||
|
@ -1,62 +1,158 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.structures.Matrix
|
||||
|
||||
/**
|
||||
* A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix
|
||||
* operations performance in some cases.
|
||||
* A marker interface representing some properties of matrices or additional transformations of them. Features are used
|
||||
* to optimize matrix operations performance in some cases or retrieve the APIs.
|
||||
*/
|
||||
public interface MatrixFeature
|
||||
|
||||
/**
|
||||
* The matrix with this feature is considered to have only diagonal non-null elements
|
||||
* Matrices with this feature are considered to have only diagonal non-null elements.
|
||||
*/
|
||||
public object DiagonalFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Matrix with this feature has all zero elements
|
||||
*/
|
||||
public object ZeroFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Matrix with this feature have unit elements on diagonal and zero elements in all other places
|
||||
*/
|
||||
public object UnitFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Inverted matrix feature
|
||||
*/
|
||||
public interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
||||
public val inverse: FeaturedMatrix<T>
|
||||
public interface DiagonalFeature : MatrixFeature{
|
||||
public companion object: DiagonalFeature
|
||||
}
|
||||
|
||||
/**
|
||||
* A determinant container
|
||||
* Matrices with this feature have all zero elements.
|
||||
*/
|
||||
public object ZeroFeature : DiagonalFeature
|
||||
|
||||
/**
|
||||
* Matrices with this feature have unit elements on diagonal and zero elements in all other places.
|
||||
*/
|
||||
public object UnitFeature : DiagonalFeature
|
||||
|
||||
/**
|
||||
* Matrices with this feature can be inverted: [inverse] = `a`<sup>-1</sup> where `a` is the owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The inverse matrix of the matrix that owns this feature.
|
||||
*/
|
||||
public val inverse: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature can compute their determinant.
|
||||
*/
|
||||
public interface DeterminantFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The determinant of the matrix that owns this feature.
|
||||
*/
|
||||
public val determinant: T
|
||||
}
|
||||
|
||||
/**
|
||||
* Produces a [DeterminantFeature] where the [DeterminantFeature.determinant] is [determinant].
|
||||
*
|
||||
* @param determinant the value of determinant.
|
||||
* @return a new [DeterminantFeature].
|
||||
*/
|
||||
@Suppress("FunctionName")
|
||||
public fun <T : Any> DeterminantFeature(determinant: T): DeterminantFeature<T> = object : DeterminantFeature<T> {
|
||||
override val determinant: T = determinant
|
||||
}
|
||||
|
||||
/**
|
||||
* Lower triangular matrix
|
||||
* Matrices with this feature are lower triangular ones.
|
||||
*/
|
||||
public object LFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Upper triangular feature
|
||||
* Matrices with this feature are upper triangular ones.
|
||||
*/
|
||||
public object UFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* TODO add documentation
|
||||
* Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where
|
||||
* *a* is the owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface LUPDecompositionFeature<T : Any> : MatrixFeature {
|
||||
public val l: FeaturedMatrix<T>
|
||||
public val u: FeaturedMatrix<T>
|
||||
public val p: FeaturedMatrix<T>
|
||||
public interface LupDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The lower triangular matrix in this decomposition. It may have [LFeature].
|
||||
*/
|
||||
public val l: Matrix<T>
|
||||
|
||||
/**
|
||||
* The upper triangular matrix in this decomposition. It may have [UFeature].
|
||||
*/
|
||||
public val u: Matrix<T>
|
||||
|
||||
/**
|
||||
* The permutation matrix in this decomposition.
|
||||
*/
|
||||
public val p: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature are orthogonal ones: *a · a<sup>T</sup> = u* where *a* is the owning matrix, *u*
|
||||
* is the unit matrix ([UnitFeature]).
|
||||
*/
|
||||
public object OrthogonalFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Matrices with this feature support QR factorization: *a = [q] · [r]* where *a* is the owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface QRDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The orthogonal matrix in this decomposition. It may have [OrthogonalFeature].
|
||||
*/
|
||||
public val q: Matrix<T>
|
||||
|
||||
/**
|
||||
* The upper triangular matrix in this decomposition. It may have [UFeature].
|
||||
*/
|
||||
public val r: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature support Cholesky factorization: *a = [l] · [l]<sup>H</sup>* where *a* is the
|
||||
* owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface CholeskyDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature].
|
||||
*/
|
||||
public val l: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature support SVD: *a = [u] · [s] · [v]<sup>H</sup>* where *a* is the owning
|
||||
* matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface SingularValueDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The matrix in this decomposition. It is unitary, and it consists from left singular vectors.
|
||||
*/
|
||||
public val u: Matrix<T>
|
||||
|
||||
/**
|
||||
* The matrix in this decomposition. Its main diagonal elements are singular values.
|
||||
*/
|
||||
public val s: Matrix<T>
|
||||
|
||||
/**
|
||||
* The matrix in this decomposition. It is unitary, and it consists from right singular vectors.
|
||||
*/
|
||||
public val v: Matrix<T>
|
||||
|
||||
/**
|
||||
* The buffer of singular values of this SVD.
|
||||
*/
|
||||
public val singularValues: Point<T>
|
||||
}
|
||||
|
||||
//TODO add sparse matrix feature
|
||||
|
@ -0,0 +1,105 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.Structure2D
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kscience.kmath.structures.getFeature
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.safeCast
|
||||
|
||||
/**
|
||||
* A [Matrix] that holds [MatrixFeature] objects.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public class MatrixWrapper<T : Any> internal constructor(
|
||||
public val origin: Matrix<T>,
|
||||
public val features: Set<MatrixFeature>,
|
||||
) : Matrix<T> by origin {
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
override fun <T : Any> getFeature(type: KClass<T>): T? = type.safeCast(features.find { type.isInstance(it) })
|
||||
?: origin.getFeature(type)
|
||||
|
||||
override fun equals(other: Any?): Boolean = origin == other
|
||||
override fun hashCode(): Int = origin.hashCode()
|
||||
override fun toString(): String {
|
||||
return "MatrixWrapper(matrix=$origin, features=$features)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the original matrix. If this is a wrapper, return its origin. If not, this matrix.
|
||||
* Origin does not necessary store all features.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public val <T : Any> Matrix<T>.origin: Matrix<T> get() = (this as? MatrixWrapper)?.origin ?: this
|
||||
|
||||
/**
|
||||
* Add a single feature to a [Matrix]
|
||||
*/
|
||||
public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
||||
MatrixWrapper(origin, features + newFeature)
|
||||
} else {
|
||||
MatrixWrapper(this, setOf(newFeature))
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a collection of features to a [Matrix]
|
||||
*/
|
||||
public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> =
|
||||
if (this is MatrixWrapper) {
|
||||
MatrixWrapper(origin, features + newFeatures)
|
||||
} else {
|
||||
MatrixWrapper(this, newFeatures.toSet())
|
||||
}
|
||||
|
||||
public inline fun Structure2D.Companion.real(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (Int, Int) -> Double,
|
||||
): BufferMatrix<Double> = MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
public fun <T : Any> Structure2D.Companion.square(vararg elements: T): Matrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.one(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix(rows, columns) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
} + UnitFeature
|
||||
|
||||
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + ZeroFeature
|
||||
|
||||
public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
/**
|
||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
|
||||
return getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||
colNum,
|
||||
rowNum,
|
||||
) { i, j -> get(j, i) } + TransposedFeature(this)
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
|
||||
|
||||
public override inline fun produce(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (i: Int, j: Int) -> Double,
|
||||
): BufferMatrix<Double> {
|
||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
|
||||
private fun Matrix<Double>.wrap(): BufferMatrix<Double> = if (this is BufferMatrix) this else {
|
||||
produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||
}
|
||||
|
||||
public fun one(rows: Int, columns: Int): Matrix<Double> = VirtualMatrix(rows, columns) { i, j ->
|
||||
if (i == j) 1.0 else 0.0
|
||||
} + DiagonalFeature
|
||||
|
||||
public override infix fun Matrix<Double>.dot(other: Matrix<Double>): BufferMatrix<Double> {
|
||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||
return produce(rowNum, other.colNum) { i, j ->
|
||||
var res = 0.0
|
||||
for (l in 0 until colNum) {
|
||||
res += get(i, l) * other.get(l, j)
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
public override infix fun Matrix<Double>.dot(vector: Point<Double>): Point<Double> {
|
||||
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||
return RealBuffer(rowNum) { i ->
|
||||
var res = 0.0
|
||||
for (j in 0 until colNum) {
|
||||
res += get(i, j) * vector[j]
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
override fun add(a: Matrix<Double>, b: Matrix<Double>): BufferMatrix<Double> {
|
||||
require(a.rowNum == b.rowNum) { "Row number mismatch in matrix addition. Left side: ${a.rowNum}, right side: ${b.rowNum}" }
|
||||
require(a.colNum == b.colNum) { "Column number mismatch in matrix addition. Left side: ${a.colNum}, right side: ${b.colNum}" }
|
||||
return produce(a.rowNum, a.colNum) { i, j ->
|
||||
a[i, j] + b[i, j]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.times(value: Double): BufferMatrix<Double> =
|
||||
produce(rowNum, colNum) { i, j -> get(i, j) * value }
|
||||
|
||||
|
||||
override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> =
|
||||
produce(a.rowNum, a.colNum) { i, j -> a[i, j] * k.toDouble() }
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Partially optimized real-valued matrix
|
||||
*/
|
||||
public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext
|
@ -5,31 +5,16 @@ import kscience.kmath.structures.Matrix
|
||||
public class VirtualMatrix<T : Any>(
|
||||
override val rowNum: Int,
|
||||
override val colNum: Int,
|
||||
override val features: Set<MatrixFeature> = emptySet(),
|
||||
public val generator: (i: Int, j: Int) -> T
|
||||
) : FeaturedMatrix<T> {
|
||||
public constructor(
|
||||
rowNum: Int,
|
||||
colNum: Int,
|
||||
vararg features: MatrixFeature,
|
||||
generator: (i: Int, j: Int) -> T
|
||||
) : this(
|
||||
rowNum,
|
||||
colNum,
|
||||
setOf(*features),
|
||||
generator
|
||||
)
|
||||
) : Matrix<T> {
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
override operator fun get(i: Int, j: Int): T = generator(i, j)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
|
||||
VirtualMatrix(rowNum, colNum, this.features + features, generator)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is FeaturedMatrix<*>) return false
|
||||
if (other !is Matrix<*>) return false
|
||||
|
||||
if (rowNum != other.rowNum) return false
|
||||
if (colNum != other.colNum) return false
|
||||
@ -40,21 +25,9 @@ public class VirtualMatrix<T : Any>(
|
||||
override fun hashCode(): Int {
|
||||
var result = rowNum
|
||||
result = 31 * result + colNum
|
||||
result = 31 * result + features.hashCode()
|
||||
result = 31 * result + generator.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* Wrap a matrix adding additional features to it
|
||||
*/
|
||||
public fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
|
||||
return if (matrix is VirtualMatrix)
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
||||
else
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> matrix[i, j] }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -13,50 +13,86 @@ public annotation class KMathContext
|
||||
*/
|
||||
public interface Algebra<T> {
|
||||
/**
|
||||
* Wrap raw string or variable
|
||||
* Wraps a raw string to [T] object. This method is designed for three purposes:
|
||||
*
|
||||
* 1. Mathematical constants (`e`, `pi`).
|
||||
* 2. Variables for expression-like contexts (`a`, `b`, `c`...).
|
||||
* 3. Literals (`{1, 2}`, (`(3; 4)`)).
|
||||
*
|
||||
* In case if algebra can't parse the string, this method must throw [kotlin.IllegalStateException].
|
||||
*
|
||||
* @param value the raw string.
|
||||
* @return an object.
|
||||
*/
|
||||
public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
|
||||
|
||||
/**
|
||||
* Dynamic call of unary operation with name [operation] on [arg]
|
||||
*/
|
||||
public fun unaryOperation(operation: String, arg: T): T
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right]
|
||||
*/
|
||||
public fun binaryOperation(operation: String, left: T, right: T): T
|
||||
}
|
||||
|
||||
/**
|
||||
* An algebraic structure where elements can have numeric representation.
|
||||
* Dynamically dispatches an unary operation with the certain name.
|
||||
*
|
||||
* @param T the type of element of this structure.
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with second `unaryOperation` overload:
|
||||
* i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public interface NumericAlgebra<T> : Algebra<T> {
|
||||
/**
|
||||
* Wraps a number.
|
||||
*/
|
||||
public fun number(value: Number): T
|
||||
public fun unaryOperationFunction(operation: String): (arg: T) -> T =
|
||||
error("Unary operation $operation not defined in $this")
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number].
|
||||
* Dynamically invokes an unary operation with the certain name.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with second [unaryOperationFunction] overload:
|
||||
* i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param arg the argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
binaryOperation(operation, number(left), right)
|
||||
public fun unaryOperation(operation: String, arg: T): T = unaryOperationFunction(operation)(arg)
|
||||
|
||||
/**
|
||||
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
|
||||
* Dynamically dispatches a binary operation with the certain name.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with second [binaryOperationFunction] overload:
|
||||
* i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
leftSideNumberOperation(operation, right, left)
|
||||
public fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
|
||||
error("Binary operation $operation not defined in $this")
|
||||
|
||||
/**
|
||||
* Dynamically invokes a binary operation with the certain name.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with second [binaryOperationFunction] overload:
|
||||
* i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param left the first argument of operation.
|
||||
* @param right the second argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun binaryOperation(operation: String, left: T, right: T): T = binaryOperationFunction(operation)(left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Call a block with an [Algebra] as receiver.
|
||||
*/
|
||||
// TODO add contract when KT-32313 is fixed
|
||||
public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = block()
|
||||
public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
|
||||
|
||||
/**
|
||||
* Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as
|
||||
@ -146,26 +182,26 @@ public interface SpaceOperations<T> : Algebra<T> {
|
||||
*/
|
||||
public operator fun Number.times(b: T): T = b * this
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
PLUS_OPERATION -> arg
|
||||
MINUS_OPERATION -> -arg
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||
PLUS_OPERATION -> { arg -> arg }
|
||||
MINUS_OPERATION -> { arg -> -arg }
|
||||
else -> super.unaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
PLUS_OPERATION -> add(left, right)
|
||||
MINUS_OPERATION -> left - right
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
PLUS_OPERATION -> ::add
|
||||
MINUS_OPERATION -> { left, right -> left - right }
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The identifier of addition.
|
||||
* The identifier of addition and unary positive operator.
|
||||
*/
|
||||
public const val PLUS_OPERATION: String = "+"
|
||||
|
||||
/**
|
||||
* The identifier of subtraction (and negation).
|
||||
* The identifier of subtraction and unary negative operator.
|
||||
*/
|
||||
public const val MINUS_OPERATION: String = "-"
|
||||
}
|
||||
@ -207,9 +243,9 @@ public interface RingOperations<T> : SpaceOperations<T> {
|
||||
*/
|
||||
public operator fun T.times(b: T): T = multiply(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
TIMES_OPERATION -> multiply(left, right)
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
TIMES_OPERATION -> ::multiply
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public companion object {
|
||||
@ -226,61 +262,11 @@ public interface RingOperations<T> : SpaceOperations<T> {
|
||||
*
|
||||
* @param T the type of element of this ring.
|
||||
*/
|
||||
public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||
public interface Ring<T> : Space<T>, RingOperations<T> {
|
||||
/**
|
||||
* neutral operation for multiplication
|
||||
*/
|
||||
public val one: T
|
||||
|
||||
override fun number(value: Number): T = one * value.toDouble()
|
||||
|
||||
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.leftSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left + right
|
||||
SpaceOperations.MINUS_OPERATION -> left - right
|
||||
RingOperations.TIMES_OPERATION -> left * right
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Addition of element and scalar.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun T.plus(b: Number): T = this + number(b)
|
||||
|
||||
/**
|
||||
* Addition of scalar and element.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun Number.plus(b: T): T = b + this
|
||||
|
||||
/**
|
||||
* Subtraction of element from number.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun T.minus(b: Number): T = this - number(b)
|
||||
|
||||
/**
|
||||
* Subtraction of number from element.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun Number.minus(b: T): T = -b + this
|
||||
}
|
||||
|
||||
/**
|
||||
@ -308,9 +294,9 @@ public interface FieldOperations<T> : RingOperations<T> {
|
||||
*/
|
||||
public operator fun T.div(b: T): T = divide(this, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) {
|
||||
DIV_OPERATION -> divide(left, right)
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
DIV_OPERATION -> ::divide
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public companion object {
|
||||
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.BigInt.Companion.BASE
|
||||
import kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||
import kscience.kmath.structures.*
|
||||
@ -16,7 +17,8 @@ public typealias TBase = ULong
|
||||
*
|
||||
* @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai)
|
||||
*/
|
||||
public object BigIntField : Field<BigInt> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object BigIntField : Field<BigInt>, RingWithNumbers<BigInt> {
|
||||
override val zero: BigInt = BigInt.ZERO
|
||||
override val one: BigInt = BigInt.ONE
|
||||
|
||||
|
@ -3,6 +3,7 @@ package kscience.kmath.operations
|
||||
import kscience.kmath.memory.MemoryReader
|
||||
import kscience.kmath.memory.MemorySpec
|
||||
import kscience.kmath.memory.MemoryWriter
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.MemoryBuffer
|
||||
import kscience.kmath.structures.MutableBuffer
|
||||
@ -41,7 +42,8 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
|
||||
/**
|
||||
* A field of [Complex].
|
||||
*/
|
||||
public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex>, RingWithNumbers<Complex> {
|
||||
override val zero: Complex = 0.0.toComplex()
|
||||
override val one: Complex = 1.0.toComplex()
|
||||
|
||||
@ -156,7 +158,7 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
|
||||
|
||||
override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg)
|
||||
|
||||
override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
|
||||
override fun symbol(value: String): Complex = if (value == "i") i else super<ExtendedField>.symbol(value)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -0,0 +1,125 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* An algebraic structure where elements can have numeric representation.
|
||||
*
|
||||
* @param T the type of element of this structure.
|
||||
*/
|
||||
public interface NumericAlgebra<T> : Algebra<T> {
|
||||
/**
|
||||
* Wraps a number to [T] object.
|
||||
*
|
||||
* @param value the number to wrap.
|
||||
* @return an object.
|
||||
*/
|
||||
public fun number(value: Number): T
|
||||
|
||||
/**
|
||||
* Dynamically dispatches a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [leftSideNumberOperation] overload:
|
||||
* i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public fun leftSideNumberOperationFunction(operation: String): (left: Number, right: T) -> T =
|
||||
{ l, r -> binaryOperationFunction(operation)(number(l), r) }
|
||||
|
||||
/**
|
||||
* Dynamically invokes a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with second [leftSideNumberOperation] overload:
|
||||
* i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param left the first argument of operation.
|
||||
* @param right the second argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
leftSideNumberOperationFunction(operation)(left, right)
|
||||
|
||||
/**
|
||||
* Dynamically dispatches a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload:
|
||||
* i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T =
|
||||
{ l, r -> binaryOperationFunction(operation)(l, number(r)) }
|
||||
|
||||
/**
|
||||
* Dynamically invokes a binary operation with the certain name with numeric second argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload:
|
||||
* i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param left the first argument of operation.
|
||||
* @param right the second argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
rightSideNumberOperationFunction(operation)(left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* A combination of [NumericAlgebra] and [Ring] that adds intrinsic simple operations on numbers like `T+1`
|
||||
* TODO to be removed and replaced by extensions after multiple receivers are there
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public interface RingWithNumbers<T>: Ring<T>, NumericAlgebra<T>{
|
||||
public override fun number(value: Number): T = one * value
|
||||
|
||||
/**
|
||||
* Addition of element and scalar.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun T.plus(b: Number): T = this + number(b)
|
||||
|
||||
/**
|
||||
* Addition of scalar and element.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun Number.plus(b: T): T = b + this
|
||||
|
||||
/**
|
||||
* Subtraction of element from number.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun T.minus(b: Number): T = this - number(b)
|
||||
|
||||
/**
|
||||
* Subtraction of number from element.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun Number.minus(b: T): T = -b + this
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow as kpow
|
||||
|
||||
/**
|
||||
@ -15,30 +14,30 @@ public interface ExtendedFieldOperations<T> :
|
||||
public override fun tan(arg: T): T = sin(arg) / cos(arg)
|
||||
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
|
||||
|
||||
public override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
TrigonometricOperations.COS_OPERATION -> cos(arg)
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(arg)
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(arg)
|
||||
TrigonometricOperations.ACOS_OPERATION -> acos(arg)
|
||||
TrigonometricOperations.ASIN_OPERATION -> asin(arg)
|
||||
TrigonometricOperations.ATAN_OPERATION -> atan(arg)
|
||||
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
|
||||
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
|
||||
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
|
||||
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
|
||||
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
|
||||
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
|
||||
PowerOperations.SQRT_OPERATION -> sqrt(arg)
|
||||
ExponentialOperations.EXP_OPERATION -> exp(arg)
|
||||
ExponentialOperations.LN_OPERATION -> ln(arg)
|
||||
else -> super.unaryOperation(operation, arg)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||
TrigonometricOperations.COS_OPERATION -> ::cos
|
||||
TrigonometricOperations.SIN_OPERATION -> ::sin
|
||||
TrigonometricOperations.TAN_OPERATION -> ::tan
|
||||
TrigonometricOperations.ACOS_OPERATION -> ::acos
|
||||
TrigonometricOperations.ASIN_OPERATION -> ::asin
|
||||
TrigonometricOperations.ATAN_OPERATION -> ::atan
|
||||
HyperbolicOperations.COSH_OPERATION -> ::cosh
|
||||
HyperbolicOperations.SINH_OPERATION -> ::sinh
|
||||
HyperbolicOperations.TANH_OPERATION -> ::tanh
|
||||
HyperbolicOperations.ACOSH_OPERATION -> ::acosh
|
||||
HyperbolicOperations.ASINH_OPERATION -> ::asinh
|
||||
HyperbolicOperations.ATANH_OPERATION -> ::atanh
|
||||
PowerOperations.SQRT_OPERATION -> ::sqrt
|
||||
ExponentialOperations.EXP_OPERATION -> ::exp
|
||||
ExponentialOperations.LN_OPERATION -> ::ln
|
||||
else -> super<FieldOperations>.unaryOperationFunction(operation)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Advanced Number-like field that implements basic operations.
|
||||
*/
|
||||
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, NumericAlgebra<T> {
|
||||
public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
|
||||
public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
|
||||
public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
|
||||
@ -46,9 +45,10 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
|
||||
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
|
||||
|
||||
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> power(left, right)
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
public override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T =
|
||||
when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.rightSideNumberOperationFunction(operation)
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,9 +80,12 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
public override val one: Double
|
||||
get() = 1.0
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> left pow right
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
override fun number(value: Number): Double = value.toDouble()
|
||||
|
||||
public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double =
|
||||
when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public override inline fun add(a: Double, b: Double): Double = a + b
|
||||
@ -130,9 +133,12 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
public override val one: Float
|
||||
get() = 1.0f
|
||||
|
||||
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> left pow right
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
override fun number(value: Number): Float = value.toFloat()
|
||||
|
||||
public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float =
|
||||
when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public override inline fun add(a: Float, b: Float): Float = a + b
|
||||
@ -173,13 +179,15 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
* A field for [Int] without boxing. Does not produce corresponding ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
|
||||
public override val zero: Int
|
||||
get() = 0
|
||||
|
||||
public override val one: Int
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Int = value.toInt()
|
||||
|
||||
public override inline fun add(a: Int, b: Int): Int = a + b
|
||||
public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
|
||||
|
||||
@ -197,13 +205,15 @@ public object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
* A field for [Short] without boxing. Does not produce appropriate ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short> {
|
||||
public override val zero: Short
|
||||
get() = 0
|
||||
|
||||
public override val one: Short
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Short = value.toShort()
|
||||
|
||||
public override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
|
||||
|
||||
@ -221,13 +231,15 @@ public object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
* A field for [Byte] without boxing. Does not produce appropriate ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
|
||||
public override val zero: Byte
|
||||
get() = 0
|
||||
|
||||
public override val one: Byte
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Byte = value.toByte()
|
||||
|
||||
public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
|
||||
|
||||
@ -245,12 +257,14 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
* A field for [Double] without boxing. Does not produce appropriate ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||
public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
|
||||
public override val zero: Long
|
||||
get() = 0
|
||||
get() = 0L
|
||||
|
||||
public override val one: Long
|
||||
get() = 1
|
||||
get() = 1L
|
||||
|
||||
override fun number(value: Number): Long = value.toLong()
|
||||
|
||||
public override inline fun add(a: Long, b: Long): Long = a + b
|
||||
public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
|
@ -1,9 +1,7 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.operations.Complex
|
||||
import kscience.kmath.operations.ComplexField
|
||||
import kscience.kmath.operations.FieldElement
|
||||
import kscience.kmath.operations.complex
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
@ -12,15 +10,22 @@ public typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField
|
||||
/**
|
||||
* An optimized nd-field for complex numbers
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class ComplexNDField(override val shape: IntArray) :
|
||||
BufferedNDField<Complex, ComplexField>,
|
||||
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
|
||||
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>>,
|
||||
RingWithNumbers<NDBuffer<Complex>>{
|
||||
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
override val elementContext: ComplexField get() = ComplexField
|
||||
override val zero: ComplexNDElement by lazy { produce { zero } }
|
||||
override val one: ComplexNDElement by lazy { produce { one } }
|
||||
|
||||
override fun number(value: Number): NDBuffer<Complex> {
|
||||
val c = value.toComplex()
|
||||
return produce { c }
|
||||
}
|
||||
|
||||
public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
|
||||
Buffer.complex(size) { initializer(it) }
|
||||
|
||||
@ -29,7 +34,7 @@ public class ComplexNDField(override val shape: IntArray) :
|
||||
*/
|
||||
override fun map(
|
||||
arg: NDBuffer<Complex>,
|
||||
transform: ComplexField.(Complex) -> Complex
|
||||
transform: ComplexField.(Complex) -> Complex,
|
||||
): ComplexNDElement {
|
||||
check(arg)
|
||||
val array = buildBuffer(arg.strides.linearSize) { offset -> ComplexField.transform(arg.buffer[offset]) }
|
||||
@ -43,7 +48,7 @@ public class ComplexNDField(override val shape: IntArray) :
|
||||
|
||||
override fun mapIndexed(
|
||||
arg: NDBuffer<Complex>,
|
||||
transform: ComplexField.(index: IntArray, Complex) -> Complex
|
||||
transform: ComplexField.(index: IntArray, Complex) -> Complex,
|
||||
): ComplexNDElement {
|
||||
check(arg)
|
||||
|
||||
@ -60,7 +65,7 @@ public class ComplexNDField(override val shape: IntArray) :
|
||||
override fun combine(
|
||||
a: NDBuffer<Complex>,
|
||||
b: NDBuffer<Complex>,
|
||||
transform: ComplexField.(Complex, Complex) -> Complex
|
||||
transform: ComplexField.(Complex, Complex) -> Complex,
|
||||
): ComplexNDElement {
|
||||
check(a, b)
|
||||
|
||||
@ -141,7 +146,7 @@ public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = Comple
|
||||
|
||||
public fun NDElement.Companion.complex(
|
||||
vararg shape: Int,
|
||||
initializer: ComplexField.(IntArray) -> Complex
|
||||
initializer: ComplexField.(IntArray) -> Complex,
|
||||
): ComplexNDElement = NDField.complex(*shape).produce(initializer)
|
||||
|
||||
/**
|
||||
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kotlin.jvm.JvmName
|
||||
import kotlin.native.concurrent.ThreadLocal
|
||||
import kotlin.reflect.KClass
|
||||
@ -38,14 +39,22 @@ public interface NDStructure<T> {
|
||||
*/
|
||||
public fun elements(): Sequence<Pair<IntArray, T>>
|
||||
|
||||
//force override equality and hash code
|
||||
public override fun equals(other: Any?): Boolean
|
||||
public override fun hashCode(): Int
|
||||
|
||||
/**
|
||||
* Feature is additional property or hint that does not directly affect the structure, but could in some cases help
|
||||
* optimize operations and performance. If the feature is not present, null is defined.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T : Any> getFeature(type: KClass<T>): T? = null
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* Indicates whether some [NDStructure] is equal to another one.
|
||||
*/
|
||||
public fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||
public fun contentEquals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||
if (st1 === st2) return true
|
||||
|
||||
// fast comparison of buffers if possible
|
||||
@ -120,6 +129,9 @@ public interface NDStructure<T> {
|
||||
*/
|
||||
public operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
||||
|
||||
@UnstableKMathAPI
|
||||
public inline fun <reified T : Any> NDStructure<*>.getFeature(): T? = getFeature(T::class)
|
||||
|
||||
/**
|
||||
* Represents mutable [NDStructure].
|
||||
*/
|
||||
@ -133,6 +145,9 @@ public interface MutableNDStructure<T> : NDStructure<T> {
|
||||
public operator fun set(index: IntArray, value: T)
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform a structure element-by element in place.
|
||||
*/
|
||||
public inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
|
||||
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
||||
|
||||
@ -260,7 +275,7 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
|
@ -150,6 +150,8 @@ public class RealBufferField(public val size: Int) : ExtendedField<Buffer<Double
|
||||
public override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||
public override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||
|
||||
override fun number(value: Number): Buffer<Double> = RealBuffer(size) { value.toDouble() }
|
||||
|
||||
public override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.add(a, b)
|
||||
|
@ -1,13 +1,17 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.FieldElement
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.RingWithNumbers
|
||||
|
||||
public typealias RealNDElement = BufferedNDFieldElement<Double, RealField>
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class RealNDField(override val shape: IntArray) :
|
||||
BufferedNDField<Double, RealField>,
|
||||
ExtendedNDField<Double, RealField, NDBuffer<Double>> {
|
||||
ExtendedNDField<Double, RealField, NDBuffer<Double>>,
|
||||
RingWithNumbers<NDBuffer<Double>> {
|
||||
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
|
||||
@ -15,35 +19,36 @@ public class RealNDField(override val shape: IntArray) :
|
||||
override val zero: RealNDElement by lazy { produce { zero } }
|
||||
override val one: RealNDElement by lazy { produce { one } }
|
||||
|
||||
public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||
override fun number(value: Number): NDBuffer<Double> {
|
||||
val d = value.toDouble()
|
||||
return produce { d }
|
||||
}
|
||||
|
||||
/**
|
||||
* Inline transform an NDStructure to
|
||||
*/
|
||||
override fun map(
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun map(
|
||||
arg: NDBuffer<Double>,
|
||||
transform: RealField.(Double) -> Double
|
||||
transform: RealField.(Double) -> Double,
|
||||
): RealNDElement {
|
||||
check(arg)
|
||||
val array = buildBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) }
|
||||
val array = RealBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) }
|
||||
return BufferedNDFieldElement(this, array)
|
||||
}
|
||||
|
||||
override fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement {
|
||||
val array = buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement {
|
||||
val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
|
||||
return BufferedNDFieldElement(this, array)
|
||||
}
|
||||
|
||||
override fun mapIndexed(
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun mapIndexed(
|
||||
arg: NDBuffer<Double>,
|
||||
transform: RealField.(index: IntArray, Double) -> Double
|
||||
transform: RealField.(index: IntArray, Double) -> Double,
|
||||
): RealNDElement {
|
||||
check(arg)
|
||||
|
||||
return BufferedNDFieldElement(
|
||||
this,
|
||||
buildBuffer(arg.strides.linearSize) { offset ->
|
||||
RealBuffer(arg.strides.linearSize) { offset ->
|
||||
elementContext.transform(
|
||||
arg.strides.index(offset),
|
||||
arg.buffer[offset]
|
||||
@ -51,15 +56,17 @@ public class RealNDField(override val shape: IntArray) :
|
||||
})
|
||||
}
|
||||
|
||||
override fun combine(
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun combine(
|
||||
a: NDBuffer<Double>,
|
||||
b: NDBuffer<Double>,
|
||||
transform: RealField.(Double, Double) -> Double
|
||||
transform: RealField.(Double, Double) -> Double,
|
||||
): RealNDElement {
|
||||
check(a, b)
|
||||
return BufferedNDFieldElement(
|
||||
this,
|
||||
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
|
||||
val buffer = RealBuffer(strides.linearSize) { offset ->
|
||||
elementContext.transform(a.buffer[offset], b.buffer[offset])
|
||||
}
|
||||
return BufferedNDFieldElement(this, buffer)
|
||||
}
|
||||
|
||||
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
|
||||
|
@ -1,12 +1,42 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
/**
|
||||
* A structure that is guaranteed to be two-dimensional
|
||||
* A structure that is guaranteed to be two-dimensional.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public interface Structure2D<T> : NDStructure<T> {
|
||||
public val rowNum: Int get() = shape[0]
|
||||
public val colNum: Int get() = shape[1]
|
||||
/**
|
||||
* The number of rows in this structure.
|
||||
*/
|
||||
public val rowNum: Int
|
||||
|
||||
/**
|
||||
* The number of columns in this structure.
|
||||
*/
|
||||
public val colNum: Int
|
||||
|
||||
public override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
/**
|
||||
* The buffer of rows of this structure. It gets elements from the structure dynamically.
|
||||
*/
|
||||
public val rows: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } }
|
||||
|
||||
/**
|
||||
* The buffer of columns of this structure. It gets elements from the structure dynamically.
|
||||
*/
|
||||
public val columns: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
|
||||
|
||||
/**
|
||||
* Retrieves an element from the structure by two indices.
|
||||
*
|
||||
* @param i the first index.
|
||||
* @param j the second index.
|
||||
* @return an element.
|
||||
*/
|
||||
public operator fun get(i: Int, j: Int): T
|
||||
|
||||
override operator fun get(index: IntArray): T {
|
||||
@ -14,15 +44,9 @@ public interface Structure2D<T> : NDStructure<T> {
|
||||
return get(index[0], index[1])
|
||||
}
|
||||
|
||||
public val rows: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } }
|
||||
|
||||
public val columns: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||
for (i in (0 until rowNum))
|
||||
for (j in (0 until colNum)) yield(intArrayOf(i, j) to this@Structure2D[i, j])
|
||||
for (i in 0 until rowNum)
|
||||
for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
|
||||
public companion object
|
||||
@ -34,7 +58,11 @@ public interface Structure2D<T> : NDStructure<T> {
|
||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
|
||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
@ -46,4 +74,9 @@ public fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2)
|
||||
else
|
||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
|
||||
/**
|
||||
* Alias for [Structure2D] with more familiar name.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public typealias Matrix<T> = Structure2D<T>
|
||||
|
@ -7,6 +7,7 @@ import kscience.kmath.structures.as2D
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
class MatrixTest {
|
||||
@Test
|
||||
fun testTranspose() {
|
||||
|
@ -12,6 +12,8 @@ internal class FieldVerifier<T>(override val algebra: Field<T>, a: T, b: T, c: T
|
||||
super.verify()
|
||||
|
||||
algebra {
|
||||
assertEquals(a + b, b + a, "Addition in $algebra is not commutative.")
|
||||
assertEquals(a * b, b * a, "Multiplication in $algebra is not commutative.")
|
||||
assertNotEquals(a / b, b / a, "Division in $algebra is not anti-commutative.")
|
||||
assertNotEquals((a / b) / c, a / (b / c), "Division in $algebra is associative.")
|
||||
assertEquals((a + b) / c, (a / c) + (b / c), "Division in $algebra is not right-distributive.")
|
||||
|
@ -10,7 +10,7 @@ internal open class RingVerifier<T>(override val algebra: Ring<T>, a: T, b: T, c
|
||||
super.verify()
|
||||
|
||||
algebra {
|
||||
assertEquals(a * b, a * b, "Multiplication in $algebra is not commutative.")
|
||||
assertEquals(a + b, b + a, "Addition in $algebra is not commutative.")
|
||||
assertEquals(a * b * c, a * (b * c), "Multiplication in $algebra is not associative.")
|
||||
assertEquals(c * (a + b), (c * a) + (c * b), "Multiplication in $algebra is not distributive.")
|
||||
assertEquals(a * one, one * a, "$one in $algebra is not a neutral multiplication element.")
|
||||
|
@ -15,7 +15,6 @@ internal open class SpaceVerifier<T>(
|
||||
AlgebraicVerifier<T, Space<T>> {
|
||||
override fun verify() {
|
||||
algebra {
|
||||
assertEquals(a + b, b + a, "Addition in $algebra is not commutative.")
|
||||
assertEquals(a + b + c, a + (b + c), "Addition in $algebra is not associative.")
|
||||
assertEquals(x * (a + b), x * a + x * b, "Addition in $algebra is not distributive.")
|
||||
assertEquals((a + b) * x, a * x + b * x, "Addition in $algebra is not distributive.")
|
||||
|
@ -1,10 +1,15 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.operations.internal.FieldVerifier
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class NDFieldTest {
|
||||
@Test
|
||||
fun verify() {
|
||||
NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
|
||||
}
|
||||
|
||||
class NDFieldTest {
|
||||
@Test
|
||||
fun testStrides() {
|
||||
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
|
||||
|
@ -8,6 +8,7 @@ import kotlin.math.pow
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
class NumberNDFieldTest {
|
||||
val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() }
|
||||
val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() }
|
||||
|
@ -7,7 +7,7 @@ import java.math.MathContext
|
||||
/**
|
||||
* A field over [BigInteger].
|
||||
*/
|
||||
public object JBigIntegerField : Field<BigInteger> {
|
||||
public object JBigIntegerField : Field<BigInteger>, NumericAlgebra<BigInteger> {
|
||||
public override val zero: BigInteger
|
||||
get() = BigInteger.ZERO
|
||||
|
||||
@ -28,9 +28,9 @@ public object JBigIntegerField : Field<BigInteger> {
|
||||
*
|
||||
* @property mathContext the [MathContext] to use.
|
||||
*/
|
||||
public abstract class JBigDecimalFieldBase internal constructor(public val mathContext: MathContext = MathContext.DECIMAL64) :
|
||||
Field<BigDecimal>,
|
||||
PowerOperations<BigDecimal> {
|
||||
public abstract class JBigDecimalFieldBase internal constructor(
|
||||
private val mathContext: MathContext = MathContext.DECIMAL64,
|
||||
) : Field<BigDecimal>, PowerOperations<BigDecimal>, NumericAlgebra<BigDecimal> {
|
||||
public override val zero: BigDecimal
|
||||
get() = BigDecimal.ZERO
|
||||
|
||||
|
@ -24,7 +24,7 @@ public class LazyNDStructure<T>(
|
||||
}
|
||||
|
||||
public override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
|
@ -1,11 +1,6 @@
|
||||
package kscience.kmath.dimensions
|
||||
|
||||
import kscience.kmath.linear.GenericMatrixContext
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.linear.transpose
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.Structure2D
|
||||
@ -42,9 +37,11 @@ public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||
* An inline wrapper for a Matrix
|
||||
*/
|
||||
public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||
private val structure: Structure2D<T>
|
||||
private val structure: Structure2D<T>,
|
||||
) : DMatrix<T, R, C> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||
}
|
||||
|
||||
@ -81,7 +78,7 @@ public inline class DPointWrapper<T, D : Dimension>(public val point: Point<T>)
|
||||
/**
|
||||
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
||||
*/
|
||||
public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: GenericMatrixContext<T, Ri, Matrix<T>>) {
|
||||
public inline class DMatrixContext<T : Any>(public val context: MatrixContext<T, Matrix<T>>) {
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
||||
require(rowNum == Dimension.dim<R>().toInt()) {
|
||||
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
|
||||
@ -115,7 +112,7 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
|
||||
}
|
||||
|
||||
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||
other: DMatrix<T, C1, C2>
|
||||
other: DMatrix<T, C1, C2>,
|
||||
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
||||
|
||||
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
||||
@ -139,18 +136,20 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
||||
context { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||
|
||||
public companion object {
|
||||
public val real: DMatrixContext<Double> = DMatrixContext(MatrixContext.real)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A square unit matrix
|
||||
*/
|
||||
public inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
|
||||
if (i == j) context.elementContext.one else context.elementContext.zero
|
||||
public inline fun <reified D : Dimension> DMatrixContext<Double>.one(): DMatrix<Double, D, D> = produce { i, j ->
|
||||
if (i == j) 1.0 else 0.0
|
||||
}
|
||||
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
|
||||
context.elementContext.zero
|
||||
}
|
||||
|
||||
public companion object {
|
||||
public val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
|
||||
}
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double>.zero(): DMatrix<Double, R, C> =
|
||||
produce { _, _ ->
|
||||
0.0
|
||||
}
|
@ -3,8 +3,10 @@ package kscience.dimensions
|
||||
import kscience.kmath.dimensions.D2
|
||||
import kscience.kmath.dimensions.D3
|
||||
import kscience.kmath.dimensions.DMatrixContext
|
||||
import kscience.kmath.dimensions.one
|
||||
import kotlin.test.Test
|
||||
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
internal class DMatrixContextTest {
|
||||
@Test
|
||||
fun testDimensionSafeMatrix() {
|
||||
|
@ -1,12 +1,14 @@
|
||||
package kscience.kmath.ejml
|
||||
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.NDStructure
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import kscience.kmath.linear.DeterminantFeature
|
||||
import kscience.kmath.linear.FeaturedMatrix
|
||||
import kscience.kmath.linear.LUPDecompositionFeature
|
||||
import kscience.kmath.linear.MatrixFeature
|
||||
import kscience.kmath.structures.NDStructure
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.cast
|
||||
|
||||
/**
|
||||
* Represents featured matrix over EJML [SimpleMatrix].
|
||||
@ -14,58 +16,75 @@ import kscience.kmath.structures.NDStructure
|
||||
* @property origin the underlying [SimpleMatrix].
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public class EjmlMatrix(public val origin: SimpleMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||
public override val rowNum: Int
|
||||
get() = origin.numRows()
|
||||
public class EjmlMatrix(
|
||||
public val origin: SimpleMatrix,
|
||||
) : Matrix<Double> {
|
||||
public override val rowNum: Int get() = origin.numRows()
|
||||
|
||||
public override val colNum: Int
|
||||
get() = origin.numCols()
|
||||
public override val colNum: Int get() = origin.numCols()
|
||||
|
||||
public override val shape: IntArray
|
||||
get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||
@UnstableKMathAPI
|
||||
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
|
||||
InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
|
||||
override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
|
||||
}
|
||||
DeterminantFeature::class -> object : DeterminantFeature<Double> {
|
||||
override val determinant: Double by lazy(origin::determinant)
|
||||
}
|
||||
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
|
||||
private val svd by lazy {
|
||||
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
|
||||
.apply { decompose(origin.ddrm.copy()) }
|
||||
}
|
||||
|
||||
public override val features: Set<MatrixFeature> = setOf(
|
||||
object : LUPDecompositionFeature<Double>, DeterminantFeature<Double> {
|
||||
override val determinant: Double
|
||||
get() = origin.determinant()
|
||||
override val u: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) }
|
||||
override val s: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) }
|
||||
override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
|
||||
override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) }
|
||||
}
|
||||
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
|
||||
private val qr by lazy {
|
||||
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
|
||||
}
|
||||
|
||||
override val q: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) }
|
||||
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) }
|
||||
}
|
||||
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
|
||||
override val l: Matrix<Double> by lazy {
|
||||
val cholesky =
|
||||
DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) }
|
||||
|
||||
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
|
||||
}
|
||||
}
|
||||
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
|
||||
private val lup by lazy {
|
||||
val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
|
||||
.also { it.decompose(origin.ddrm.copy()) }
|
||||
|
||||
Triple(
|
||||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))),
|
||||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))),
|
||||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))),
|
||||
)
|
||||
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) }
|
||||
}
|
||||
|
||||
override val l: FeaturedMatrix<Double>
|
||||
get() = lup.second
|
||||
|
||||
override val u: FeaturedMatrix<Double>
|
||||
get() = lup.third
|
||||
|
||||
override val p: FeaturedMatrix<Double>
|
||||
get() = lup.first
|
||||
override val l: Matrix<Double> by lazy {
|
||||
EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature
|
||||
}
|
||||
) union features.orEmpty()
|
||||
|
||||
public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix =
|
||||
EjmlMatrix(origin, this.features + features)
|
||||
override val u: Matrix<Double> by lazy {
|
||||
EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature
|
||||
}
|
||||
|
||||
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
|
||||
}
|
||||
else -> null
|
||||
}?.let { type.cast(it) }
|
||||
|
||||
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||
|
||||
public override fun equals(other: Any?): Boolean {
|
||||
if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0)
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is Matrix<*>) return false
|
||||
return NDStructure.contentEquals(this, other)
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
var result = origin.hashCode()
|
||||
result = 31 * result + features.hashCode()
|
||||
return result
|
||||
}
|
||||
override fun hashCode(): Int = origin.hashCode()
|
||||
|
||||
|
||||
public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)"
|
||||
}
|
||||
|
@ -1,16 +1,14 @@
|
||||
package kscience.kmath.ejml
|
||||
|
||||
import kscience.kmath.linear.InverseMatrixFeature
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.linear.origin
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.getFeature
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
|
||||
/**
|
||||
* Converts this matrix to EJML one.
|
||||
*/
|
||||
public fun Matrix<Double>.toEjml(): EjmlMatrix =
|
||||
if (this is EjmlMatrix) this else EjmlMatrixContext.produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||
|
||||
/**
|
||||
* Represents context of basic operations operating with [EjmlMatrix].
|
||||
*
|
||||
@ -18,6 +16,15 @@ public fun Matrix<Double>.toEjml(): EjmlMatrix =
|
||||
*/
|
||||
public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> {
|
||||
|
||||
/**
|
||||
* Converts this matrix to EJML one.
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun Matrix<Double>.toEjml(): EjmlMatrix = when (val matrix = origin) {
|
||||
is EjmlMatrix -> matrix
|
||||
else -> produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts this vector to EJML one.
|
||||
*/
|
||||
@ -33,6 +40,11 @@ public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> {
|
||||
}
|
||||
})
|
||||
|
||||
override fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
|
||||
EjmlVector(SimpleMatrix(size, 1).also {
|
||||
(0 until it.numRows()).forEach { row -> it[row, 0] = initializer(row) }
|
||||
})
|
||||
|
||||
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
|
||||
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
|
||||
|
||||
@ -74,11 +86,7 @@ public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMa
|
||||
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
|
||||
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
|
||||
|
||||
/**
|
||||
* Returns the inverse of given matrix: b = a^(-1).
|
||||
*
|
||||
* @param a the matrix.
|
||||
* @return the inverse of this matrix.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature<InverseMatrixFeature<Double>>()!!.inverse as EjmlMatrix
|
||||
|
||||
public fun EjmlMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> = matrix.toEjml().inverted()
|
@ -1,9 +1,11 @@
|
||||
package kscience.kmath.ejml
|
||||
|
||||
import kscience.kmath.linear.DeterminantFeature
|
||||
import kscience.kmath.linear.LUPDecompositionFeature
|
||||
import kscience.kmath.linear.LupDecompositionFeature
|
||||
import kscience.kmath.linear.MatrixFeature
|
||||
import kscience.kmath.linear.getFeature
|
||||
import kscience.kmath.linear.plus
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.getFeature
|
||||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import kotlin.random.Random
|
||||
@ -38,13 +40,14 @@ internal class EjmlMatrixTest {
|
||||
assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList())
|
||||
}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
@Test
|
||||
fun features() {
|
||||
val m = randomMatrix
|
||||
val w = EjmlMatrix(m)
|
||||
val det = w.getFeature<DeterminantFeature<Double>>() ?: fail()
|
||||
assertEquals(m.determinant(), det.determinant)
|
||||
val lup = w.getFeature<LUPDecompositionFeature<Double>>() ?: fail()
|
||||
val lup = w.getFeature<LupDecompositionFeature<Double>>() ?: fail()
|
||||
|
||||
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols())
|
||||
.also { it.decompose(m.ddrm.copy()) }
|
||||
@ -56,9 +59,10 @@ internal class EjmlMatrixTest {
|
||||
|
||||
private object SomeFeature : MatrixFeature {}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
@Test
|
||||
fun suggestFeature() {
|
||||
assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature<SomeFeature>())
|
||||
assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>())
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -21,7 +21,7 @@
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
|
||||
>
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
|
@ -1,14 +1,12 @@
|
||||
package kscience.kmath.real
|
||||
|
||||
import kscience.kmath.linear.FeaturedMatrix
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.RealMatrixContext.elementContext
|
||||
import kscience.kmath.linear.VirtualMatrix
|
||||
import kscience.kmath.linear.inverseWithLUP
|
||||
import kscience.kmath.linear.real
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.operations.sum
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
import kscience.kmath.structures.asIterable
|
||||
import kotlin.math.pow
|
||||
@ -25,7 +23,7 @@ import kotlin.math.pow
|
||||
* Functions that help create a real (Double) matrix
|
||||
*/
|
||||
|
||||
public typealias RealMatrix = FeaturedMatrix<Double>
|
||||
public typealias RealMatrix = Matrix<Double>
|
||||
|
||||
public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
||||
MatrixContext.real.produce(rowNum, colNum, initializer)
|
||||
@ -122,8 +120,7 @@ public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix =
|
||||
extractColumns(columnIndex..columnIndex)
|
||||
|
||||
public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
val column = columns[j]
|
||||
elementContext { sum(column.asIterable()) }
|
||||
columns[j].asIterable().sum()
|
||||
}
|
||||
|
||||
public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
|
@ -1,6 +1,5 @@
|
||||
package kaceince.kmath.real
|
||||
|
||||
import kscience.kmath.linear.VirtualMatrix
|
||||
import kscience.kmath.linear.build
|
||||
import kscience.kmath.real.*
|
||||
import kscience.kmath.structures.Matrix
|
||||
@ -42,7 +41,7 @@ internal class RealMatrixTest {
|
||||
1.0, 0.0, 0.0,
|
||||
0.0, 1.0, 2.0
|
||||
)
|
||||
assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3))
|
||||
assertEquals(matrix2, matrix1.repeatStackVertical(3))
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -23,7 +23,7 @@ This subproject implements the following features:
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
|
||||
>
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.NDAlgebra
|
||||
import kscience.kmath.structures.NDField
|
||||
@ -35,7 +36,7 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>>
|
||||
|
||||
public override fun mapIndexed(
|
||||
arg: Nd4jArrayStructure<T>,
|
||||
transform: C.(index: IntArray, T) -> T
|
||||
transform: C.(index: IntArray, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
check(arg)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
@ -46,7 +47,7 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>>
|
||||
public override fun combine(
|
||||
a: Nd4jArrayStructure<T>,
|
||||
b: Nd4jArrayStructure<T>,
|
||||
transform: C.(T, T) -> T
|
||||
transform: C.(T, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
@ -61,8 +62,8 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>>
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param S the type of space of structure elements.
|
||||
*/
|
||||
public interface Nd4jArraySpace<T, S> : NDSpace<T, S, Nd4jArrayStructure<T>>,
|
||||
Nd4jArrayAlgebra<T, S> where S : Space<T> {
|
||||
public interface Nd4jArraySpace<T, S : Space<T>> : NDSpace<T, S, Nd4jArrayStructure<T>>, Nd4jArrayAlgebra<T, S> {
|
||||
|
||||
public override val zero: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
|
||||
@ -103,7 +104,9 @@ public interface Nd4jArraySpace<T, S> : NDSpace<T, S, Nd4jArrayStructure<T>>,
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param R the type of ring of structure elements.
|
||||
*/
|
||||
public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4jArraySpace<T, R> where R : Ring<T> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public interface Nd4jArrayRing<T, R : Ring<T>> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4jArraySpace<T, R> {
|
||||
|
||||
public override val one: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
|
||||
@ -111,21 +114,21 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
|
||||
check(a, b)
|
||||
return a.ndArray.mul(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.sub(b).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.plus(b: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.add(b).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Number.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(b)
|
||||
return b.ndArray.rsub(this).wrap()
|
||||
}
|
||||
//
|
||||
// public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||
// check(this)
|
||||
// return ndArray.sub(b).wrap()
|
||||
// }
|
||||
//
|
||||
// public override operator fun Nd4jArrayStructure<T>.plus(b: Number): Nd4jArrayStructure<T> {
|
||||
// check(this)
|
||||
// return ndArray.add(b).wrap()
|
||||
// }
|
||||
//
|
||||
// public override operator fun Number.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
// check(b)
|
||||
// return b.ndArray.rsub(this).wrap()
|
||||
// }
|
||||
|
||||
public companion object {
|
||||
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
|
||||
@ -165,7 +168,8 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
|
||||
* @param N the type of ND structure.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd4jArrayRing<T, F> where F : Field<T> {
|
||||
public interface Nd4jArrayField<T, F : Field<T>> : NDField<T, F, Nd4jArrayStructure<T>>, Nd4jArrayRing<T, F> {
|
||||
|
||||
public override fun divide(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
return a.ndArray.div(b.ndArray).wrap()
|
||||
|
@ -62,6 +62,7 @@ class MCScopeTest {
|
||||
}
|
||||
|
||||
|
||||
@OptIn(ObsoleteCoroutinesApi::class)
|
||||
fun compareResult(test: ATest) {
|
||||
val res1 = runBlocking(Dispatchers.Default) { test() }
|
||||
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
||||
|
@ -8,8 +8,8 @@ pluginManagement {
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
}
|
||||
|
||||
val toolsVersion = "0.7.0"
|
||||
val kotlinVersion = "1.4.20"
|
||||
val toolsVersion = "0.7.3-1.4.30-RC"
|
||||
val kotlinVersion = "1.4.30-RC"
|
||||
|
||||
plugins {
|
||||
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
||||
|
Loading…
Reference in New Issue
Block a user