diff --git a/CHANGELOG.md b/CHANGELOG.md index e542d210c..aa70e6116 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,27 +4,28 @@ ### Added - `fun` annotation for SAM interfaces in library - Explicit `public` visibility for all public APIs -- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140). +- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140) - Automatic README generation for features (#139) - Native support for `memory`, `core` and `dimensions` -- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136). +- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136) - A separate `Symbol` entity, which is used for global unbound symbol. - A `Symbol` indexing scope. - Basic optimization API for Commons-math. - Chi squared optimization for array-like data in CM - `Fitting` utility object in prob/stat -- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`. -- Coroutine-deterministic Monte-Carlo scope with a random number generator. -- Some minor utilities to `kmath-for-real`. +- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray` +- Coroutine-deterministic Monte-Carlo scope with a random number generator +- Some minor utilities to `kmath-for-real` - Generic operation result parameter to `MatrixContext` +- New `MatrixFeature` interfaces for matrix decompositions ### Changed -- Package changed from `scientifik` to `kscience.kmath`. -- Gradle version: 6.6 -> 6.7.1 +- Package changed from `scientifik` to `kscience.kmath` +- Gradle version: 6.6 -> 6.8 - Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`) -- `Polynomial` secondary constructor made function. -- Kotlin version: 1.3.72 -> 1.4.20 -- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library. +- `Polynomial` secondary constructor made function +- Kotlin version: 1.3.72 -> 1.4.21 +- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library - Full autodiff refactoring based on `Symbol` - `kmath-prob` renamed to `kmath-stat` - Grid generators moved to `kmath-for-real` @@ -32,6 +33,8 @@ - Optimized dot product for buffer matrices moved to `kmath-for-real` - EjmlMatrix context is an object - Matrix LUP `inverse` renamed to `inverseWithLUP` +- `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it). +- Features moved to NDStructure and became transparent. ### Deprecated diff --git a/README.md b/README.md index 50a916d2c..0899f77cc 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,16 @@ submit a feature request if you want something to be implemented first. * ### [kmath-ast](kmath-ast) > > -> **Maturity**: EXPERIMENTAL +> **Maturity**: PROTOTYPE +> +> **Features:** +> - [expression-language](kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt) : Expression language and its parser +> - [mst](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation +> - [mst-building](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt) : MST building algebraic structure +> - [mst-interpreter](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST interpreter +> - [mst-jvm-codegen](kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler +> - [mst-js-codegen](kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler +
* ### [kmath-commons](kmath-commons) @@ -122,7 +131,7 @@ submit a feature request if you want something to be implemented first. * ### [kmath-dimensions](kmath-dimensions) > > -> **Maturity**: EXPERIMENTAL +> **Maturity**: PROTOTYPE
* ### [kmath-ejml](kmath-ejml) diff --git a/build.gradle.kts b/build.gradle.kts index 561a2212b..d171bd608 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -4,7 +4,7 @@ plugins { id("ru.mipt.npm.project") } -internal val kmathVersion: String by extra("0.2.0-dev-4") +internal val kmathVersion: String by extra("0.2.0-dev-5") internal val bintrayRepo: String by extra("kscience") internal val githubProject: String by extra("kmath") @@ -38,3 +38,7 @@ readme { apiValidation { validationDisabled = true } + +ksciencePublish { + spaceRepo = "https://maven.pkg.jetbrains.space/mipt-npm/p/sci/maven" +} diff --git a/docs/templates/ARTIFACT-TEMPLATE.md b/docs/templates/ARTIFACT-TEMPLATE.md index c77948d4b..d46a431bd 100644 --- a/docs/templates/ARTIFACT-TEMPLATE.md +++ b/docs/templates/ARTIFACT-TEMPLATE.md @@ -14,7 +14,7 @@ > maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/hotkeytlt/maven' } - +> > } > > dependencies { diff --git a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt similarity index 57% rename from examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index a4806ed68..c5edcdedf 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -4,13 +4,19 @@ import kscience.kmath.asm.compile import kscience.kmath.expressions.Expression import kscience.kmath.expressions.expressionInField import kscience.kmath.expressions.invoke +import kscience.kmath.expressions.symbol import kscience.kmath.operations.Field import kscience.kmath.operations.RealField +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State import kotlin.random.Random -import kotlin.system.measureTimeMillis +@State(Scope.Benchmark) internal class ExpressionsInterpretersBenchmark { private val algebra: Field = RealField + + @Benchmark fun functionalExpression() { val expr = algebra.expressionInField { symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0) @@ -19,22 +25,31 @@ internal class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } + @Benchmark fun mstExpression() { val expr = algebra.mstInField { - symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0 } invokeAndSum(expr) } + @Benchmark fun asmExpression() { val expr = algebra.mstInField { - symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0 }.compile() invokeAndSum(expr) } + @Benchmark + fun rawExpression() { + val x by symbol + val expr = Expression { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 } + invokeAndSum(expr) + } + private fun invokeAndSum(expr: Expression) { val random = Random(0) var sum = 0.0 @@ -46,35 +61,3 @@ internal class ExpressionsInterpretersBenchmark { println(sum) } } - -/** - * This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and - * core FunctionalExpressions API. - * - * The expected rating is: - * - * 1. ASM. - * 2. MST. - * 3. FE. - */ -fun main() { - val benchmark = ExpressionsInterpretersBenchmark() - - val fe = measureTimeMillis { - benchmark.functionalExpression() - } - - println("fe=$fe") - - val mst = measureTimeMillis { - benchmark.mstExpression() - } - - println("mst=$mst") - - val asm = measureTimeMillis { - benchmark.asmExpression() - } - - println("asm=$asm") -} diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ArrayBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ArrayBenchmark.kt index 8c44135fb..ebf31a590 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ArrayBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ArrayBenchmark.kt @@ -16,17 +16,13 @@ internal class ArrayBenchmark { @Benchmark fun benchmarkBufferRead() { var res = 0 - for (i in 1..size) res += arrayBuffer.get( - size - i - ) + for (i in 1..size) res += arrayBuffer[size - i] } @Benchmark fun nativeBufferRead() { var res = 0 - for (i in 1..size) res += nativeBuffer.get( - size - i - ) + for (i in 1..size) res += nativeBuffer[size - i] } companion object { diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/MultiplicationBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt similarity index 56% rename from examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/MultiplicationBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt index 9d2b02245..5c59afaee 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/MultiplicationBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt @@ -2,19 +2,21 @@ package kscience.kmath.benchmarks import kotlinx.benchmark.Benchmark import kscience.kmath.commons.linear.CMMatrixContext -import kscience.kmath.commons.linear.CMMatrixContext.dot -import kscience.kmath.commons.linear.toCM import kscience.kmath.ejml.EjmlMatrixContext -import kscience.kmath.ejml.toEjml + +import kscience.kmath.linear.BufferMatrixContext +import kscience.kmath.linear.RealMatrixContext import kscience.kmath.linear.real +import kscience.kmath.operations.RealField import kscience.kmath.operations.invoke +import kscience.kmath.structures.Buffer import kscience.kmath.structures.Matrix import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State import kotlin.random.Random @State(Scope.Benchmark) -class MultiplicationBenchmark { +class DotBenchmark { companion object { val random = Random(12224) val dim = 1000 @@ -23,38 +25,48 @@ class MultiplicationBenchmark { val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } - val cmMatrix1 = matrix1.toCM() - val cmMatrix2 = matrix2.toCM() + val cmMatrix1 = CMMatrixContext { matrix1.toCM() } + val cmMatrix2 = CMMatrixContext { matrix2.toCM() } - val ejmlMatrix1 = matrix1.toEjml() - val ejmlMatrix2 = matrix2.toEjml() + val ejmlMatrix1 = EjmlMatrixContext { matrix1.toEjml() } + val ejmlMatrix2 = EjmlMatrixContext { matrix2.toEjml() } } @Benchmark fun commonsMathMultiplication() { - CMMatrixContext.invoke { + CMMatrixContext { cmMatrix1 dot cmMatrix2 } } @Benchmark fun ejmlMultiplication() { - EjmlMatrixContext.invoke { + EjmlMatrixContext { ejmlMatrix1 dot ejmlMatrix2 } } @Benchmark fun ejmlMultiplicationwithConversion() { - val ejmlMatrix1 = matrix1.toEjml() - val ejmlMatrix2 = matrix2.toEjml() - EjmlMatrixContext.invoke { + EjmlMatrixContext { + val ejmlMatrix1 = matrix1.toEjml() + val ejmlMatrix2 = matrix2.toEjml() + ejmlMatrix1 dot ejmlMatrix2 } } @Benchmark fun bufferedMultiplication() { - matrix1 dot matrix2 + BufferMatrixContext(RealField, Buffer.Companion::real).invoke { + matrix1 dot matrix2 + } + } + + @Benchmark + fun realMultiplication() { + RealMatrixContext { + matrix1 dot matrix2 + } } } \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt index ec8714617..5ff43ef80 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt @@ -5,10 +5,8 @@ import kotlinx.benchmark.Benchmark import kscience.kmath.commons.linear.CMMatrixContext import kscience.kmath.commons.linear.CMMatrixContext.dot import kscience.kmath.commons.linear.inverse -import kscience.kmath.commons.linear.toCM import kscience.kmath.ejml.EjmlMatrixContext import kscience.kmath.ejml.inverse -import kscience.kmath.ejml.toEjml import kscience.kmath.operations.invoke import kscience.kmath.structures.Matrix import org.openjdk.jmh.annotations.Scope @@ -35,16 +33,14 @@ class LinearAlgebraBenchmark { @Benchmark fun cmLUPInversion() { CMMatrixContext { - val cm = matrix.toCM() //avoid overhead on conversion - inverse(cm) + inverse(matrix) } } @Benchmark fun ejmlInverse() { EjmlMatrixContext { - val km = matrix.toEjml() //avoid overhead on conversion - inverse(km) + inverse(matrix) } } } \ No newline at end of file diff --git a/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt index 4478903d9..92b6f046d 100644 --- a/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt @@ -63,4 +63,6 @@ fun main(): Unit = runBlocking(Dispatchers.Default) { val directJob = async { runApacheDirect() } println("KMath Chained: ${chainJob.await()}") println("Apache Direct: ${directJob.await()}") + val normal = GaussianSampler.of(7.0, 2.0) + val chain = normal.sample(generator).blocking() } diff --git a/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt index aa4b10ef2..b69590473 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt @@ -11,7 +11,7 @@ fun main() { val n = 1000 val realField = NDField.real(dim, dim) - val complexField = NDField.complex(dim, dim) + val complexField: ComplexNDField = NDField.complex(dim, dim) val realTime = measureTimeMillis { realField { diff --git a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt index e53af0dee..b5130c92b 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt @@ -33,7 +33,7 @@ fun main() { measureAndPrint("Automatic field addition") { autoField { var res: NDBuffer = one - repeat(n) { res += number(1.0) } + repeat(n) { res += 1.0 } } } @@ -52,7 +52,7 @@ fun main() { measureAndPrint("Nd4j specialized addition") { nd4jField { var res = one - repeat(n) { res += 1.0 as Number } + repeat(n) { res += 1.0 } } } @@ -73,7 +73,7 @@ fun main() { genericField { var res: NDBuffer = one repeat(n) { - res += one // couldn't avoid using `one` due to resolution ambiguity } + res += 1.0 // couldn't avoid using `one` due to resolution ambiguity } } } } diff --git a/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt index 987eea16f..96684f7dc 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt @@ -4,9 +4,8 @@ import kscience.kmath.dimensions.D2 import kscience.kmath.dimensions.D3 import kscience.kmath.dimensions.DMatrixContext import kscience.kmath.dimensions.Dimension -import kscience.kmath.operations.RealField -private fun DMatrixContext.simple() { +private fun DMatrixContext.simple() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i + j).toDouble() } @@ -18,7 +17,7 @@ private object D5 : Dimension { override val dim: UInt = 5u } -private fun DMatrixContext.custom() { +private fun DMatrixContext.custom() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i - j).toDouble() } val m3 = produce { i, j -> (i - j).toDouble() } diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index be52383ef..da9702f9e 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.8-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 043224800..19e9ee4a9 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -2,72 +2,85 @@ This subproject implements the following features: -- Expression Language and its parser. -- MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation. -- Type-safe builder for MST. -- Evaluating expressions by traversing MST. + - [expression-language](src/jvmMain/kotlin/kscience/kmath/ast/parser.kt) : Expression language and its parser + - [mst](src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation + - [mst-building](src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt) : MST building algebraic structure + - [mst-interpreter](src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST interpreter + - [mst-jvm-codegen](src/jvmMain/kotlin/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler + - [mst-js-codegen](src/jsMain/kotlin/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler + > #### Artifact: -> This module is distributed in the artifact `kscience.kmath:kmath-ast:0.1.4-dev-8`. -> +> +> This module artifact: `kscience.kmath:kmath-ast:0.2.0-dev-4`. +> +> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-ast/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-ast/_latestVersion) +> +> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-ast/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-ast/_latestVersion) +> > **Gradle:** > > ```gradle > repositories { +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } > maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } -> maven { url https://dl.bintray.com/hotkeytlt/maven' } +> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } +> > } > > dependencies { -> implementation 'kscience.kmath:kmath-ast:0.1.4-dev-8' +> implementation 'kscience.kmath:kmath-ast:0.2.0-dev-4' > } > ``` > **Gradle Kotlin DSL:** > > ```kotlin > repositories { +> maven("https://dl.bintray.com/kotlin/kotlin-eap") > maven("https://dl.bintray.com/mipt-npm/kscience") > maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/hotkeytlt/maven") > } > > dependencies { -> implementation("kscience.kmath:kmath-ast:0.1.4-dev-8") +> implementation("kscience.kmath:kmath-ast:0.2.0-dev-4") > } > ``` -> -## Dynamic Expression Code Generation with ObjectWeb ASM +## Dynamic expression code generation -`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds -a special implementation of `Expression` with implemented `invoke` function. +### On JVM -For example, the following builder: +`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds +a special implementation of `Expression` with implemented `invoke` function. + +For example, the following builder: ```kotlin RealField.mstInField { symbol("x") + 2 }.compile() ``` -… leads to generation of bytecode, which can be decompiled to the following Java class: +… leads to generation of bytecode, which can be decompiled to the following Java class: ```java package kscience.kmath.asm.generated; import java.util.Map; +import kotlin.jvm.functions.Function2; import kscience.kmath.asm.internal.MapIntrinsics; import kscience.kmath.expressions.Expression; -import kscience.kmath.operations.RealField; +import kscience.kmath.expressions.Symbol; -public final class AsmCompiledExpression_1073786867_0 implements Expression { - private final RealField algebra; +public final class AsmCompiledExpression_45045_0 implements Expression { + private final Object[] constants; - public final Double invoke(Map arguments) { - return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D); + public final Double invoke(Map arguments) { + return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2); } - public AsmCompiledExpression_1073786867_0(RealField algebra) { - this.algebra = algebra; + public AsmCompiledExpression_45045_0(Object[] constants) { + this.constants = constants; } } @@ -75,17 +88,35 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression` with implemented `invoke` function. + +For example, the following builder: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +``` + +… leads to generation of bytecode, which can be decompiled to the following Java class: + +```java +package kscience.kmath.asm.generated; + +import java.util.Map; +import kotlin.jvm.functions.Function2; +import kscience.kmath.asm.internal.MapIntrinsics; +import kscience.kmath.expressions.Expression; +import kscience.kmath.expressions.Symbol; + +public final class AsmCompiledExpression_45045_0 implements Expression { + private final Object[] constants; + + public final Double invoke(Map arguments) { + return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2); + } + + public AsmCompiledExpression_45045_0(Object[] constants) { + this.constants = constants; + } +} + +``` + +### Example Usage + +This API extends MST and MstExpression, so you may optimize as both of them: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +RealField.expression("x+2".parseMath()) +``` + +#### Known issues + +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid + class loading overhead. +- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. + +### On JS + +A similar feature is also available on JS. + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +``` + +The code above returns expression implemented with such a JS function: + +```js +var executable = function (constants, arguments) { + return constants[1](constants[0](arguments, "x"), 2); +}; +``` + +#### Known issues + +- This feature uses `eval` which can be unavailable in several environments. diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt index f312323b9..212fd0d0b 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt @@ -2,10 +2,9 @@ package kscience.kmath.ast import kscience.kmath.operations.Algebra import kscience.kmath.operations.NumericAlgebra -import kscience.kmath.operations.RealField /** - * A Mathematical Syntax Tree node for mathematical expressions. + * A Mathematical Syntax Tree (MST) node for mathematical expressions. * * @author Alexander Nozik */ @@ -55,24 +54,25 @@ public sealed class MST { public fun Algebra.evaluate(node: MST): T = when (node) { is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this") + is MST.Symbolic -> symbol(node.value) - is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) + + is MST.Unary -> when { + this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value)) + else -> unaryOperationFunction(node.operation)(evaluate(node.value)) + } + is MST.Binary -> when { - this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> + binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value)) - node.left is MST.Numeric && node.right is MST.Numeric -> { - val number = RealField.binaryOperation( - node.operation, - node.left.value.toDouble(), - node.right.value.toDouble() - ) + this is NumericAlgebra && node.left is MST.Numeric -> + leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right)) - number(number) - } + this is NumericAlgebra && node.right is MST.Numeric -> + rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value) - node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) - node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) - else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) + else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) } } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 6ee6ab9af..eadbc85ee 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -1,64 +1,83 @@ package kscience.kmath.ast +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* /** * [Algebra] over [MST] nodes. */ public object MstAlgebra : NumericAlgebra { - override fun number(value: Number): MST.Numeric = MST.Numeric(value) + public override fun number(value: Number): MST.Numeric = MST.Numeric(value) + public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) - override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + { arg -> MST.Unary(operation, arg) } - override fun unaryOperation(operation: String, arg: MST): MST.Unary = - MST.Unary(operation, arg) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MST.Binary(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + { left, right -> MST.Binary(operation, left, right) } } /** * [Space] over [MST] nodes. */ public object MstSpace : Space, NumericAlgebra { - override val zero: MST.Numeric by lazy { number(0.0) } + public override val zero: MST.Numeric by lazy { number(0.0) } - override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) - override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) + public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) + public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) + public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b) + public override operator fun MST.unaryPlus(): MST.Unary = + unaryOperationFunction(SpaceOperations.PLUS_OPERATION)(this) - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstAlgebra.binaryOperation(operation, left, right) + public override operator fun MST.unaryMinus(): MST.Unary = + unaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg) + public override operator fun MST.minus(b: MST): MST.Binary = + binaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this, b) + + public override fun multiply(a: MST, k: Number): MST.Binary = + binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(k)) + + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstAlgebra.binaryOperationFunction(operation) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperationFunction(operation) } /** * [Ring] over [MST] nodes. */ -public object MstRing : Ring, NumericAlgebra { - override val zero: MST.Numeric +@OptIn(UnstableKMathAPI::class) +public object MstRing : Ring, RingWithNumbers { + public override val zero: MST.Numeric get() = MstSpace.zero - override val one: MST.Numeric by lazy { number(1.0) } + public override val one: MST.Numeric by lazy { number(1.0) } - override fun number(value: Number): MST.Numeric = MstSpace.number(value) - override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) - override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b) + public override fun number(value: Number): MST.Numeric = MstSpace.number(value) + public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) + public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = + binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstSpace.binaryOperation(operation, left, right) + public override operator fun MST.unaryPlus(): MST.Unary = MstSpace { +this@unaryPlus } + public override operator fun MST.unaryMinus(): MST.Unary = MstSpace { -this@unaryMinus } + public override operator fun MST.minus(b: MST): MST.Binary = MstSpace { this@minus - b } - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg) + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstSpace.binaryOperationFunction(operation) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperationFunction(operation) } /** * [Field] over [MST] nodes. */ -public object MstField : Field { +@OptIn(UnstableKMathAPI::class) +public object MstField : Field, RingWithNumbers { public override val zero: MST.Numeric get() = MstRing.zero @@ -70,51 +89,61 @@ public object MstField : Field { public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b) + public override fun divide(a: MST, b: MST): MST.Binary = + binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) - public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstRing.binaryOperation(operation, left, right) + public override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } + public override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } + public override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b } - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg) + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstRing.binaryOperationFunction(operation) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstRing.unaryOperationFunction(operation) } /** * [ExtendedField] over [MST] nodes. */ -public object MstExtendedField : ExtendedField { - override val zero: MST.Numeric +public object MstExtendedField : ExtendedField, NumericAlgebra { + public override val zero: MST.Numeric get() = MstField.zero - override val one: MST.Numeric + public override val one: MST.Numeric get() = MstField.one - override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) - override fun number(value: Number): MST.Numeric = MstField.number(value) - override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) - override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) - override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) - override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) - override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) - override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) - override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) - override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) - override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) - override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) - override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + public override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) + public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) + public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) + public override fun tan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.TAN_OPERATION)(arg) + public override fun asin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg) + public override fun acos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg) + public override fun atan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg) + public override fun sinh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.SINH_OPERATION)(arg) + public override fun cosh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.COSH_OPERATION)(arg) + public override fun tanh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.TANH_OPERATION)(arg) + public override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ASINH_OPERATION)(arg) + public override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ACOSH_OPERATION)(arg) + public override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ATANH_OPERATION)(arg) + public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + public override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } + public override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } + public override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b } - override fun power(arg: MST, pow: Number): MST.Binary = - binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + public override fun power(arg: MST, pow: Number): MST.Binary = + binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) - override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) - override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + public override fun exp(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) + public override fun ln(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) - override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary = - MstField.binaryOperation(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstField.binaryOperationFunction(operation) - override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg) + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstField.unaryOperationFunction(operation) } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt index f68e3f5f8..03d33aa2b 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -15,11 +15,14 @@ import kotlin.contracts.contract */ public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { - override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) + override fun symbol(value: String): T = try { + algebra.symbol(value) + } catch (ignored: IllegalStateException) { + null + } ?: arguments.getValue(StringSymbol(value)) - override fun binaryOperation(operation: String, left: T, right: T): T = - algebra.binaryOperation(operation, left, right) + override fun unaryOperationFunction(operation: String): (arg: T) -> T = algebra.unaryOperationFunction(operation) + override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = algebra.binaryOperationFunction(operation) @Suppress("UNCHECKED_CAST") override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt new file mode 100644 index 000000000..5c08ada31 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt @@ -0,0 +1,82 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.MST +import kscience.kmath.ast.MST.* +import kscience.kmath.ast.MstExpression +import kscience.kmath.estree.internal.ESTreeBuilder +import kscience.kmath.estree.internal.estree.BaseExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.operations.Algebra +import kscience.kmath.operations.NumericAlgebra + +@PublishedApi +internal fun MST.compileWith(algebra: Algebra): Expression { + fun ESTreeBuilder.visit(node: MST): BaseExpression = when (node) { + is Symbolic -> { + val symbol = try { + algebra.symbol(node.value) + } catch (ignored: IllegalStateException) { + null + } + + if (symbol != null) + constant(symbol) + else + variable(node.value) + } + + is Numeric -> constant(node.value) + + is Unary -> when { + algebra is NumericAlgebra && node.value is Numeric -> constant( + algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value))) + + else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value)) + } + + is Binary -> when { + algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant( + algebra + .binaryOperationFunction(node.operation) + .invoke(algebra.number(node.left.value), algebra.number(node.right.value)) + ) + + algebra is NumericAlgebra && node.left is Numeric -> call( + algebra.leftSideNumberOperationFunction(node.operation), + visit(node.left), + visit(node.right), + ) + + algebra is NumericAlgebra && node.right is Numeric -> call( + algebra.rightSideNumberOperationFunction(node.operation), + visit(node.left), + visit(node.right), + ) + + else -> call( + algebra.binaryOperationFunction(node.operation), + visit(node.left), + visit(node.right), + ) + } + } + + return ESTreeBuilder { visit(this@compileWith) }.instance +} + + +/** + * Compiles an [MST] to ESTree generated expression using given algebra. + * + * @author Alexander Nozik. + */ +public fun Algebra.expression(mst: MST): Expression = + mst.compileWith(this) + +/** + * Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression. + * + * @author Alexander Nozik. + */ +public fun MstExpression>.compile(): Expression = + mst.compileWith(algebra) diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/ESTreeBuilder.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/ESTreeBuilder.kt new file mode 100644 index 000000000..e1823813a --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/ESTreeBuilder.kt @@ -0,0 +1,79 @@ +package kscience.kmath.estree.internal + +import kscience.kmath.estree.internal.astring.generate +import kscience.kmath.estree.internal.estree.* +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol + +internal class ESTreeBuilder(val bodyCallback: ESTreeBuilder.() -> BaseExpression) { + private class GeneratedExpression(val executable: dynamic, val constants: Array) : Expression { + @Suppress("UNUSED_VARIABLE") + override fun invoke(arguments: Map): T { + val e = executable + val c = constants + val a = js("{}") + arguments.forEach { (key, value) -> a[key.identity] = value } + return js("e(c, a)").unsafeCast() + } + } + + val instance: Expression by lazy { + val node = Program( + sourceType = "script", + VariableDeclaration( + kind = "var", + VariableDeclarator( + id = Identifier("executable"), + init = FunctionExpression( + params = arrayOf(Identifier("constants"), Identifier("arguments")), + body = BlockStatement(ReturnStatement(bodyCallback())), + ), + ), + ), + ) + + eval(generate(node)) + GeneratedExpression(js("executable"), constants.toTypedArray()) + } + + private val constants = mutableListOf() + + fun constant(value: Any?) = when { + value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" -> + SimpleLiteral(value) + + jsTypeOf(value) == "undefined" -> Identifier("undefined") + + else -> { + val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex + + MemberExpression( + computed = true, + optional = false, + `object` = Identifier("constants"), + property = SimpleLiteral(idx), + ) + } + } + + fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name)) + + fun call(function: Function, vararg args: BaseExpression): BaseExpression = SimpleCallExpression( + optional = false, + callee = constant(function), + *args, + ) + + private companion object { + @Suppress("UNUSED_VARIABLE") + val getOrFail: (`object`: dynamic, key: String) -> dynamic = { `object`, key -> + val k = key + val o = `object` + + if (!(js("k in o") as Boolean)) + throw NoSuchElementException("Key $key is missing in the map.") + + js("o[k]") + } + } +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.kt new file mode 100644 index 000000000..cf0a8de25 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.kt @@ -0,0 +1,33 @@ +@file:JsModule("astring") +@file:JsNonModule + +package kscience.kmath.estree.internal.astring + +import kscience.kmath.estree.internal.estree.BaseNode + +internal external interface Options { + var indent: String? + get() = definedExternally + set(value) = definedExternally + var lineEnd: String? + get() = definedExternally + set(value) = definedExternally + var startingIndentLevel: Number? + get() = definedExternally + set(value) = definedExternally + var comments: Boolean? + get() = definedExternally + set(value) = definedExternally + var generator: Any? + get() = definedExternally + set(value) = definedExternally + var sourceMap: Any? + get() = definedExternally + set(value) = definedExternally +} + +internal external fun generate(node: BaseNode, options: Options /* Options & `T$0` */ = definedExternally): String + +internal external fun generate(node: BaseNode): String + +internal external var baseGenerator: Generator diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.typealises.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.typealises.kt new file mode 100644 index 000000000..5a7fe4f16 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.typealises.kt @@ -0,0 +1,3 @@ +package kscience.kmath.estree.internal.astring + +internal typealias Generator = Any diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/emitter/emitter.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/emitter/emitter.kt new file mode 100644 index 000000000..1e0a95a16 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/emitter/emitter.kt @@ -0,0 +1,13 @@ +package kscience.kmath.estree.internal.emitter + +internal open external class Emitter { + constructor(obj: Any) + constructor() + + open fun on(event: String, fn: () -> Unit) + open fun off(event: String, fn: () -> Unit) + open fun once(event: String, fn: () -> Unit) + open fun emit(event: String, vararg any: Any) + open fun listeners(event: String): Array<() -> Unit> + open fun hasListeners(event: String): Boolean +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.extensions.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.extensions.kt new file mode 100644 index 000000000..5bc197d0c --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.extensions.kt @@ -0,0 +1,62 @@ +package kscience.kmath.estree.internal.estree + +internal fun Program(sourceType: String, vararg body: dynamic) = object : Program { + override var type = "Program" + override var sourceType = sourceType + override var body = body +} + +internal fun VariableDeclaration(kind: String, vararg declarations: VariableDeclarator) = object : VariableDeclaration { + override var type = "VariableDeclaration" + override var declarations = declarations.toList().toTypedArray() + override var kind = kind +} + +internal fun VariableDeclarator(id: dynamic, init: dynamic) = object : VariableDeclarator { + override var type = "VariableDeclarator" + override var id = id + override var init = init +} + +internal fun Identifier(name: String) = object : Identifier { + override var type = "Identifier" + override var name = name +} + +internal fun FunctionExpression(params: Array, body: BlockStatement) = object : FunctionExpression { + override var params = params + override var type = "FunctionExpression" + override var body = body +} + +internal fun BlockStatement(vararg body: dynamic) = object : BlockStatement { + override var type = "BlockStatement" + override var body = body +} + +internal fun ReturnStatement(argument: dynamic) = object : ReturnStatement { + override var type = "ReturnStatement" + override var argument = argument +} + +internal fun SimpleLiteral(value: dynamic) = object : SimpleLiteral { + override var type = "Literal" + override var value = value +} + +internal fun MemberExpression(computed: Boolean, optional: Boolean, `object`: dynamic, property: dynamic) = + object : MemberExpression { + override var type = "MemberExpression" + override var computed = computed + override var optional = optional + override var `object` = `object` + override var property = property + } + +internal fun SimpleCallExpression(optional: Boolean, callee: dynamic, vararg arguments: dynamic) = + object : SimpleCallExpression { + override var type = "CallExpression" + override var optional = optional + override var callee = callee + override var arguments = arguments + } diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.kt new file mode 100644 index 000000000..a5385d1ee --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.kt @@ -0,0 +1,644 @@ +package kscience.kmath.estree.internal.estree + +import kotlin.js.RegExp + +internal external interface BaseNodeWithoutComments { + var type: String + var loc: SourceLocation? + get() = definedExternally + set(value) = definedExternally + var range: dynamic /* JsTuple */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseNode : BaseNodeWithoutComments { + var leadingComments: Array? + get() = definedExternally + set(value) = definedExternally + var trailingComments: Array? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface Comment : BaseNodeWithoutComments { + override var type: String /* "Line" | "Block" */ + var value: String +} + +internal external interface SourceLocation { + var source: String? + get() = definedExternally + set(value) = definedExternally + var start: Position + var end: Position +} + +internal external interface Position { + var line: Number + var column: Number +} + +internal external interface Program : BaseNode { + override var type: String /* "Program" */ + var sourceType: String /* "script" | "module" */ + var body: Array + var comments: Array? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface Directive : BaseNode { + override var type: String /* "ExpressionStatement" */ + var expression: dynamic /* SimpleLiteral | RegExpLiteral */ + get() = definedExternally + set(value) = definedExternally + var directive: String +} + +internal external interface BaseFunction : BaseNode { + var params: Array + var generator: Boolean? + get() = definedExternally + set(value) = definedExternally + var async: Boolean? + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* BlockStatement | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseStatement : BaseNode + +internal external interface EmptyStatement : BaseStatement { + override var type: String /* "EmptyStatement" */ +} + +internal external interface BlockStatement : BaseStatement { + override var type: String /* "BlockStatement" */ + var body: Array + var innerComments: Array? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ExpressionStatement : BaseStatement { + override var type: String /* "ExpressionStatement" */ + var expression: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface IfStatement : BaseStatement { + override var type: String /* "IfStatement" */ + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var consequent: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally + var alternate: dynamic /* ExpressionStatement? | BlockStatement? | EmptyStatement? | DebuggerStatement? | WithStatement? | ReturnStatement? | LabeledStatement? | BreakStatement? | ContinueStatement? | IfStatement? | SwitchStatement? | ThrowStatement? | TryStatement? | WhileStatement? | DoWhileStatement? | ForStatement? | ForInStatement? | ForOfStatement? | FunctionDeclaration? | VariableDeclaration? | ClassDeclaration? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface LabeledStatement : BaseStatement { + override var type: String /* "LabeledStatement" */ + var label: Identifier + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BreakStatement : BaseStatement { + override var type: String /* "BreakStatement" */ + var label: Identifier? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ContinueStatement : BaseStatement { + override var type: String /* "ContinueStatement" */ + var label: Identifier? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface WithStatement : BaseStatement { + override var type: String /* "WithStatement" */ + var `object`: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface SwitchStatement : BaseStatement { + override var type: String /* "SwitchStatement" */ + var discriminant: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var cases: Array +} + +internal external interface ReturnStatement : BaseStatement { + override var type: String /* "ReturnStatement" */ + var argument: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ThrowStatement : BaseStatement { + override var type: String /* "ThrowStatement" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface TryStatement : BaseStatement { + override var type: String /* "TryStatement" */ + var block: BlockStatement + var handler: CatchClause? + get() = definedExternally + set(value) = definedExternally + var finalizer: BlockStatement? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface WhileStatement : BaseStatement { + override var type: String /* "WhileStatement" */ + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface DoWhileStatement : BaseStatement { + override var type: String /* "DoWhileStatement" */ + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ForStatement : BaseStatement { + override var type: String /* "ForStatement" */ + var init: dynamic /* VariableDeclaration? | ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var test: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var update: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseForXStatement : BaseStatement { + var left: dynamic /* VariableDeclaration | Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ForInStatement : BaseForXStatement { + override var type: String /* "ForInStatement" */ +} + +internal external interface DebuggerStatement : BaseStatement { + override var type: String /* "DebuggerStatement" */ +} + +internal external interface BaseDeclaration : BaseStatement + +internal external interface FunctionDeclaration : BaseFunction, BaseDeclaration { + override var type: String /* "FunctionDeclaration" */ + var id: Identifier? + override var body: BlockStatement +} + +internal external interface VariableDeclaration : BaseDeclaration { + override var type: String /* "VariableDeclaration" */ + var declarations: Array + var kind: String /* "var" | "let" | "const" */ +} + +internal external interface VariableDeclarator : BaseNode { + override var type: String /* "VariableDeclarator" */ + var id: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var init: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseExpression : BaseNode + +internal external interface ChainExpression : BaseExpression { + override var type: String /* "ChainExpression" */ + var expression: dynamic /* SimpleCallExpression | MemberExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ThisExpression : BaseExpression { + override var type: String /* "ThisExpression" */ +} + +internal external interface ArrayExpression : BaseExpression { + override var type: String /* "ArrayExpression" */ + var elements: Array +} + +internal external interface ObjectExpression : BaseExpression { + override var type: String /* "ObjectExpression" */ + var properties: Array +} + +internal external interface Property : BaseNode { + override var type: String /* "Property" */ + var key: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var value: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern */ + get() = definedExternally + set(value) = definedExternally + var kind: String /* "init" | "get" | "set" */ + var method: Boolean + var shorthand: Boolean + var computed: Boolean +} + +internal external interface FunctionExpression : BaseFunction, BaseExpression { + var id: Identifier? + get() = definedExternally + set(value) = definedExternally + override var type: String /* "FunctionExpression" */ + override var body: BlockStatement +} + +internal external interface SequenceExpression : BaseExpression { + override var type: String /* "SequenceExpression" */ + var expressions: Array +} + +internal external interface UnaryExpression : BaseExpression { + override var type: String /* "UnaryExpression" */ + var operator: String /* "-" | "+" | "!" | "~" | "typeof" | "void" | "delete" */ + var prefix: Boolean + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BinaryExpression : BaseExpression { + override var type: String /* "BinaryExpression" */ + var operator: String /* "==" | "!=" | "===" | "!==" | "<" | "<=" | ">" | ">=" | "<<" | ">>" | ">>>" | "+" | "-" | "*" | "/" | "%" | "**" | "|" | "^" | "&" | "in" | "instanceof" */ + var left: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface AssignmentExpression : BaseExpression { + override var type: String /* "AssignmentExpression" */ + var operator: String /* "=" | "+=" | "-=" | "*=" | "/=" | "%=" | "**=" | "<<=" | ">>=" | ">>>=" | "|=" | "^=" | "&=" */ + var left: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface UpdateExpression : BaseExpression { + override var type: String /* "UpdateExpression" */ + var operator: String /* "++" | "--" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var prefix: Boolean +} + +internal external interface LogicalExpression : BaseExpression { + override var type: String /* "LogicalExpression" */ + var operator: String /* "||" | "&&" | "??" */ + var left: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ConditionalExpression : BaseExpression { + override var type: String /* "ConditionalExpression" */ + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var alternate: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var consequent: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseCallExpression : BaseExpression { + var callee: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | Super */ + get() = definedExternally + set(value) = definedExternally + var arguments: Array +} + +internal external interface SimpleCallExpression : BaseCallExpression { + override var type: String /* "CallExpression" */ + var optional: Boolean +} + +internal external interface NewExpression : BaseCallExpression { + override var type: String /* "NewExpression" */ +} + +internal external interface MemberExpression : BaseExpression, BasePattern { + override var type: String /* "MemberExpression" */ + var `object`: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | Super */ + get() = definedExternally + set(value) = definedExternally + var property: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var computed: Boolean + var optional: Boolean +} + +internal external interface BasePattern : BaseNode + +internal external interface SwitchCase : BaseNode { + override var type: String /* "SwitchCase" */ + var test: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var consequent: Array +} + +internal external interface CatchClause : BaseNode { + override var type: String /* "CatchClause" */ + var param: dynamic /* Identifier? | ObjectPattern? | ArrayPattern? | RestElement? | AssignmentPattern? | MemberExpression? */ + get() = definedExternally + set(value) = definedExternally + var body: BlockStatement +} + +internal external interface Identifier : BaseNode, BaseExpression, BasePattern { + override var type: String /* "Identifier" */ + var name: String +} + +internal external interface SimpleLiteral : BaseNode, BaseExpression { + override var type: String /* "Literal" */ + var value: dynamic /* String? | Boolean? | Number? */ + get() = definedExternally + set(value) = definedExternally + var raw: String? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface `T$1` { + var pattern: String + var flags: String +} + +internal external interface RegExpLiteral : BaseNode, BaseExpression { + override var type: String /* "Literal" */ + var value: RegExp? + get() = definedExternally + set(value) = definedExternally + var regex: `T$1` + var raw: String? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ForOfStatement : BaseForXStatement { + override var type: String /* "ForOfStatement" */ + var await: Boolean +} + +internal external interface Super : BaseNode { + override var type: String /* "Super" */ +} + +internal external interface SpreadElement : BaseNode { + override var type: String /* "SpreadElement" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ArrowFunctionExpression : BaseExpression, BaseFunction { + override var type: String /* "ArrowFunctionExpression" */ + var expression: Boolean + override var body: dynamic /* BlockStatement | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface YieldExpression : BaseExpression { + override var type: String /* "YieldExpression" */ + var argument: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var delegate: Boolean +} + +internal external interface TemplateLiteral : BaseExpression { + override var type: String /* "TemplateLiteral" */ + var quasis: Array + var expressions: Array +} + +internal external interface TaggedTemplateExpression : BaseExpression { + override var type: String /* "TaggedTemplateExpression" */ + var tag: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var quasi: TemplateLiteral +} + +internal external interface `T$2` { + var cooked: String + var raw: String +} + +internal external interface TemplateElement : BaseNode { + override var type: String /* "TemplateElement" */ + var tail: Boolean + var value: `T$2` +} + +internal external interface AssignmentProperty : Property { + override var value: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + override var kind: String /* "init" */ + override var method: Boolean +} + +internal external interface ObjectPattern : BasePattern { + override var type: String /* "ObjectPattern" */ + var properties: Array +} + +internal external interface ArrayPattern : BasePattern { + override var type: String /* "ArrayPattern" */ + var elements: Array +} + +internal external interface RestElement : BasePattern { + override var type: String /* "RestElement" */ + var argument: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface AssignmentPattern : BasePattern { + override var type: String /* "AssignmentPattern" */ + var left: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseClass : BaseNode { + var superClass: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var body: ClassBody +} + +internal external interface ClassBody : BaseNode { + override var type: String /* "ClassBody" */ + var body: Array +} + +internal external interface MethodDefinition : BaseNode { + override var type: String /* "MethodDefinition" */ + var key: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var value: FunctionExpression + var kind: String /* "constructor" | "method" | "get" | "set" */ + var computed: Boolean + var static: Boolean +} + +internal external interface ClassDeclaration : BaseClass, BaseDeclaration { + override var type: String /* "ClassDeclaration" */ + var id: Identifier? +} + +internal external interface ClassExpression : BaseClass, BaseExpression { + override var type: String /* "ClassExpression" */ + var id: Identifier? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface MetaProperty : BaseExpression { + override var type: String /* "MetaProperty" */ + var meta: Identifier + var property: Identifier +} + +internal external interface BaseModuleDeclaration : BaseNode + +internal external interface BaseModuleSpecifier : BaseNode { + var local: Identifier +} + +internal external interface ImportDeclaration : BaseModuleDeclaration { + override var type: String /* "ImportDeclaration" */ + var specifiers: Array + var source: dynamic /* SimpleLiteral | RegExpLiteral */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ImportSpecifier : BaseModuleSpecifier { + override var type: String /* "ImportSpecifier" */ + var imported: Identifier +} + +internal external interface ImportExpression : BaseExpression { + override var type: String /* "ImportExpression" */ + var source: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ImportDefaultSpecifier : BaseModuleSpecifier { + override var type: String /* "ImportDefaultSpecifier" */ +} + +internal external interface ImportNamespaceSpecifier : BaseModuleSpecifier { + override var type: String /* "ImportNamespaceSpecifier" */ +} + +internal external interface ExportNamedDeclaration : BaseModuleDeclaration { + override var type: String /* "ExportNamedDeclaration" */ + var declaration: dynamic /* FunctionDeclaration? | VariableDeclaration? | ClassDeclaration? */ + get() = definedExternally + set(value) = definedExternally + var specifiers: Array + var source: dynamic /* SimpleLiteral? | RegExpLiteral? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ExportSpecifier : BaseModuleSpecifier { + override var type: String /* "ExportSpecifier" */ + var exported: Identifier +} + +internal external interface ExportDefaultDeclaration : BaseModuleDeclaration { + override var type: String /* "ExportDefaultDeclaration" */ + var declaration: dynamic /* FunctionDeclaration | VariableDeclaration | ClassDeclaration | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ExportAllDeclaration : BaseModuleDeclaration { + override var type: String /* "ExportAllDeclaration" */ + var source: dynamic /* SimpleLiteral | RegExpLiteral */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface AwaitExpression : BaseExpression { + override var type: String /* "AwaitExpression" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/stream/stream.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/stream/stream.kt new file mode 100644 index 000000000..b3c65a758 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/stream/stream.kt @@ -0,0 +1,7 @@ +package kscience.kmath.estree.internal.stream + +import kscience.kmath.estree.internal.emitter.Emitter + +internal open external class Stream : Emitter { + open fun pipe(dest: Any, options: Any): Any +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es2015.iterable.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es2015.iterable.kt new file mode 100644 index 000000000..22d4dd8e0 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es2015.iterable.kt @@ -0,0 +1,25 @@ +package kscience.kmath.estree.internal.tsstdlib + +internal external interface IteratorYieldResult { + var done: Boolean? + get() = definedExternally + set(value) = definedExternally + var value: TYield +} + +internal external interface IteratorReturnResult { + var done: Boolean + var value: TReturn +} + +internal external interface Iterator { + fun next(vararg args: Any /* JsTuple<> | JsTuple */): dynamic /* IteratorYieldResult | IteratorReturnResult */ + val `return`: ((value: TReturn) -> dynamic)? + val `throw`: ((e: Any) -> dynamic)? +} + +internal typealias Iterator__1 = Iterator + +internal external interface Iterable + +internal external interface IterableIterator : Iterator__1 diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es5.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es5.kt new file mode 100644 index 000000000..70f6d9702 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es5.kt @@ -0,0 +1,82 @@ +@file:Suppress("UNUSED_TYPEALIAS_PARAMETER", "DEPRECATION") + +package kscience.kmath.estree.internal.tsstdlib + +import kotlin.js.RegExp + +internal typealias RegExpMatchArray = Array + +internal typealias RegExpExecArray = Array + +internal external interface RegExpConstructor { + @nativeInvoke + operator fun invoke(pattern: RegExp, flags: String = definedExternally): RegExp + + @nativeInvoke + operator fun invoke(pattern: RegExp): RegExp + + @nativeInvoke + operator fun invoke(pattern: String, flags: String = definedExternally): RegExp + + @nativeInvoke + operator fun invoke(pattern: String): RegExp + var prototype: RegExp + var `$1`: String + var `$2`: String + var `$3`: String + var `$4`: String + var `$5`: String + var `$6`: String + var `$7`: String + var `$8`: String + var `$9`: String + var lastMatch: String +} + +internal external interface ConcatArray { + var length: Number + + @nativeGetter + operator fun get(n: Number): T? + + @nativeSetter + operator fun set(n: Number, value: T) + fun join(separator: String = definedExternally): String + fun slice(start: Number = definedExternally, end: Number = definedExternally): Array +} + +internal external interface ArrayConstructor { + fun from(iterable: Iterable): Array + fun from(iterable: ArrayLike): Array + fun from(iterable: Iterable, mapfn: (v: T, k: Number) -> U, thisArg: Any = definedExternally): Array + fun from(iterable: Iterable, mapfn: (v: T, k: Number) -> U): Array + fun from(iterable: ArrayLike, mapfn: (v: T, k: Number) -> U, thisArg: Any = definedExternally): Array + fun from(iterable: ArrayLike, mapfn: (v: T, k: Number) -> U): Array + fun of(vararg items: T): Array + + @nativeInvoke + operator fun invoke(arrayLength: Number = definedExternally): Array + + @nativeInvoke + operator fun invoke(): Array + + @nativeInvoke + operator fun invoke(arrayLength: Number): Array + + @nativeInvoke + operator fun invoke(vararg items: T): Array + fun isArray(arg: Any): Boolean + var prototype: Array +} + +internal external interface ArrayLike { + var length: Number + + @nativeGetter + operator fun get(n: Number): T? + + @nativeSetter + operator fun set(n: Number, value: T) +} + +internal typealias Extract = Any diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt new file mode 100644 index 000000000..b9be02d49 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt @@ -0,0 +1,115 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.* +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.ByteRing +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.RealField +import kscience.kmath.operations.toComplex +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestESTreeConsistencyWithInterpreter { + @Test + fun mstSpace() { + val res1 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }("x" to MST.Numeric(2)) + + val res2 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to MST.Numeric(2)) + + assertEquals(res1, res2) + } + + @Test + fun byteRing() { + val res1 = ByteRing.mstInRing { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }("x" to 3.toByte()) + + val res2 = ByteRing.mstInRing { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + number(1) + ) * number(2) + }.compile()("x" to 3.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun realField() { + val res1 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0) + + val res2 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0) + + assertEquals(res1, res2) + } + + @Test + fun complexField() { + val res1 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0.toComplex()) + + val res2 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0.toComplex()) + + assertEquals(res1, res2) + } +} diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeOperationsSupport.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeOperationsSupport.kt new file mode 100644 index 000000000..72a4669d9 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeOperationsSupport.kt @@ -0,0 +1,41 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.mstInExtendedField +import kscience.kmath.ast.mstInField +import kscience.kmath.ast.mstInSpace +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestESTreeOperationsSupport { + @Test + fun testUnaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") }.compile() + val res = expression("x" to 2.0) + assertEquals(-2.0, res) + } + + @Test + fun testBinaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile() + val res = expression("x" to 2.0) + assertEquals(-1.0, res) + } + + @Test + fun testConstProductInvocation() { + val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) + assertEquals(4.0, res) + } + + @Test + fun testMultipleCalls() { + val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile() + val r = Random(0) + var s = 0.0 + repeat(1000000) { s += e("x" to r.nextDouble()) } + println(s) + } +} diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeSpecialization.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeSpecialization.kt new file mode 100644 index 000000000..9d0d17e58 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeSpecialization.kt @@ -0,0 +1,54 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.mstInField +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestESTreeSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperationFunction("+")(symbol("x")) }.compile() + assertEquals(2.0, expr("x" to 2.0)) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperationFunction("-")(symbol("x")) }.compile() + assertEquals(-2.0, expr("x" to 2.0)) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperationFunction("+")(symbol("x"), symbol("x")) }.compile() + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperationFunction("sin")(symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperationFunction("-")(symbol("x"), symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 2.0)) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperationFunction("/")(symbol("x"), symbol("x")) }.compile() + assertEquals(1.0, expr("x" to 2.0)) + } + + @Test + fun testPower() { + val expr = RealField + .mstInField { binaryOperationFunction("pow")(symbol("x"), number(2)) } + .compile() + + assertEquals(4.0, expr("x" to 2.0)) + } +} diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeVariables.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeVariables.kt new file mode 100644 index 000000000..846120ee2 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeVariables.kt @@ -0,0 +1,22 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.mstInRing +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestESTreeVariables { + @Test + fun testVariable() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testUndefinedVariableFails() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() + assertFailsWith { expr() } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt index 9ccfa464c..55cdec243 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt @@ -1,13 +1,13 @@ package kscience.kmath.asm import kscience.kmath.asm.internal.AsmBuilder -import kscience.kmath.asm.internal.MstType -import kscience.kmath.asm.internal.buildAlgebraOperationCall import kscience.kmath.asm.internal.buildName import kscience.kmath.ast.MST +import kscience.kmath.ast.MST.* import kscience.kmath.ast.MstExpression import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra +import kscience.kmath.operations.NumericAlgebra /** * Compiles given MST to an Expression using AST compiler. @@ -20,40 +20,54 @@ import kscience.kmath.operations.Algebra @PublishedApi internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { fun AsmBuilder.visit(node: MST): Unit = when (node) { - is MST.Symbolic -> { + is Symbolic -> { val symbol = try { algebra.symbol(node.value) - } catch (ignored: Throwable) { + } catch (ignored: IllegalStateException) { null } if (symbol != null) - loadTConstant(symbol) + loadObjectConstant(symbol as Any) else loadVariable(node.value) } - is MST.Numeric -> loadNumeric(node.value) + is Numeric -> loadNumberConstant(node.value) - is MST.Unary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "unaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.value)) - ) { visit(node.value) } + is Unary -> when { + algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( + algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value))) - is MST.Binary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "binaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right)) - ) { - visit(node.left) - visit(node.right) + else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) } + } + + is Binary -> when { + algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant( + algebra.binaryOperationFunction(node.operation) + .invoke(algebra.number(node.left.value), algebra.number(node.right.value)) + ) + + algebra is NumericAlgebra && node.left is Numeric -> buildCall( + algebra.leftSideNumberOperationFunction(node.operation)) { + visit(node.left) + visit(node.right) + } + + algebra is NumericAlgebra && node.right is Numeric -> buildCall( + algebra.rightSideNumberOperationFunction(node.operation)) { + visit(node.left) + visit(node.right) + } + + else -> buildCall(algebra.binaryOperationFunction(node.operation)) { + visit(node.left) + visit(node.right) + } } } - return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() + return AsmBuilder(type, buildName(this)) { visit(this@compileWith) }.instance } /** diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index a1e482103..93d8d1143 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -3,29 +3,30 @@ package kscience.kmath.asm.internal import kscience.kmath.asm.internal.AsmBuilder.ClassLoader import kscience.kmath.ast.MST import kscience.kmath.expressions.Expression -import kscience.kmath.operations.Algebra -import kscience.kmath.operations.NumericAlgebra import org.objectweb.asm.* import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.Type.* import org.objectweb.asm.commons.InstructionAdapter -import java.util.* -import java.util.stream.Collectors +import java.lang.invoke.MethodHandles +import java.lang.invoke.MethodType +import java.lang.reflect.Modifier +import java.util.stream.Collectors.toMap +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @property T the type of AsmExpression to unwrap. - * @property algebra the algebra the applied AsmExpressions use. * @property className the unique class name of new loaded class. - * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. + * @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0. * @author Iaroslav Postovalov */ -internal class AsmBuilder internal constructor( - private val classOfT: Class<*>, - private val algebra: Algebra, +internal class AsmBuilder( + classOfT: Class<*>, private val className: String, - private val invokeLabel0Visitor: AsmBuilder.() -> Unit, + private val callbackAtInvokeL0: AsmBuilder.() -> Unit, ) { /** * Internal classloader of [AsmBuilder] with alias to define class from byte array. @@ -39,20 +40,15 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - /** - * ASM Type for [algebra]. - */ - private val tAlgebraType: Type = algebra.javaClass.asm - /** * ASM type for [T]. */ - internal val tType: Type = classOfT.asm + private val tType: Type = classOfT.asm /** * ASM type for new class. */ - private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/')) /** * List of constants to provide to the subclass. @@ -64,55 +60,14 @@ internal class AsmBuilder internal constructor( */ private lateinit var invokeMethodVisitor: InstructionAdapter - /** - * States whether this [AsmBuilder] needs to generate constants field. - */ - private var hasConstants: Boolean = true - - /** - * States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. - */ - internal var primitiveMode: Boolean = false - - /** - * Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMask: Type = OBJECT_TYPE - - /** - * Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMaskBoxed: Type = OBJECT_TYPE - - /** - * Stack of useful objects types on stack to verify types. - */ - private val typeStack: ArrayDeque = ArrayDeque() - - /** - * Stack of useful objects types on stack expected by algebra calls. - */ - internal val expectationStack: ArrayDeque = ArrayDeque(1).also { it.push(tType) } - - /** - * The cache for instance built by this builder. - */ - private var generatedInstance: Expression? = null - /** * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - internal fun getInstance(): Expression { - generatedInstance?.let { return it } - - if (SIGNATURE_LETTERS.containsKey(classOfT)) { - primitiveMode = true - primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) - primitiveMaskBoxed = tType - } + val instance: Expression by lazy { + val hasConstants: Boolean val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( @@ -121,20 +76,20 @@ internal class AsmBuilder internal constructor( classType.internalName, "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", OBJECT_TYPE.internalName, - arrayOf(EXPRESSION_TYPE.internalName) + arrayOf(EXPRESSION_TYPE.internalName), ) visitMethod( ACC_PUBLIC or ACC_FINAL, "invoke", - Type.getMethodDescriptor(tType, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", - null + getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}${if (Modifier.isFinal(classOfT.modifiers)) "" else "+"}${tType.descriptor}>;)${tType.descriptor}", + null, ).instructionAdapter { invokeMethodVisitor = this visitCode() val l0 = label() - invokeLabel0Visitor() + callbackAtInvokeL0() areturn(tType) val l1 = label() @@ -144,7 +99,7 @@ internal class AsmBuilder internal constructor( null, l0, l1, - invokeThisVar + 0, ) visitLocalVariable( @@ -153,7 +108,7 @@ internal class AsmBuilder internal constructor( "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", l0, l1, - invokeArgumentsVar + 1, ) visitMaxs(0, 2) @@ -163,17 +118,15 @@ internal class AsmBuilder internal constructor( visitMethod( ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, "invoke", - Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, null, - null ).instructionAdapter { - val thisVar = 0 - val argumentsVar = 1 visitCode() val l0 = label() - load(thisVar, OBJECT_TYPE) - load(argumentsVar, MAP_TYPE) - invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + load(0, OBJECT_TYPE) + load(1, MAP_TYPE) + invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false) areturn(tType) val l1 = label() @@ -183,7 +136,7 @@ internal class AsmBuilder internal constructor( null, l0, l1, - thisVar + 0, ) visitMaxs(0, 2) @@ -192,15 +145,6 @@ internal class AsmBuilder internal constructor( hasConstants = constants.isNotEmpty() - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "algebra", - descriptor = tAlgebraType.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd - ) - if (hasConstants) visitField( access = ACC_PRIVATE or ACC_FINAL, @@ -208,55 +152,36 @@ internal class AsmBuilder internal constructor( descriptor = OBJECT_ARRAY_TYPE.descriptor, signature = null, value = null, - block = FieldVisitor::visitEnd + block = FieldVisitor::visitEnd, ) visitMethod( ACC_PUBLIC, "", - - Type.getMethodDescriptor( - Type.VOID_TYPE, - tAlgebraType, - *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), - + getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), + null, null, - null ).instructionAdapter { - val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 val l0 = label() - load(thisVar, classType) - invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + load(0, classType) + invokespecial(OBJECT_TYPE.internalName, "", getMethodDescriptor(VOID_TYPE), false) label() - load(thisVar, classType) - load(algebraVar, tAlgebraType) - putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + load(0, classType) if (hasConstants) { label() - load(thisVar, classType) - load(constantsVar, OBJECT_ARRAY_TYPE) + load(0, classType) + load(1, OBJECT_ARRAY_TYPE) putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) } label() visitInsn(RETURN) val l4 = label() - visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) - - visitLocalVariable( - "algebra", - tAlgebraType.descriptor, - null, - l0, - l4, - algebraVar - ) + visitLocalVariable("this", classType.descriptor, null, l0, l4, 0) if (hasConstants) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1) visitMaxs(0, 3) visitEnd() @@ -265,296 +190,156 @@ internal class AsmBuilder internal constructor( visitEnd() } - val new = classLoader - .defineClass(className, classWriter.toByteArray()) - .constructors - .first() - .newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression + val cls = classLoader.defineClass(className, classWriter.toByteArray()) + // java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + val l = MethodHandles.publicLookup() - generatedInstance = new - return new + if (hasConstants) + l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array::class.java)) + .invoke(constants.toTypedArray()) as Expression + else + l.findConstructor(cls, MethodType.methodType(Void.TYPE)).invoke() as Expression } - /** - * Loads a [T] constant from [constants]. - */ - internal fun loadTConstant(value: T) { - if (classOfT in INLINABLE_NUMBERS) { - val expectedType = expectationStack.pop() - val mustBeBoxed = expectedType.sort == Type.OBJECT - loadNumberConstant(value as Number, mustBeBoxed) - - if (mustBeBoxed) - invokeMethodVisitor.checkcast(tType) - - if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) - return - } - - loadObjectConstant(value as Any, tType) - } - - /** - * Boxes the current value and pushes it. - */ - private fun box(primitive: Type) { - val r = PRIMITIVES_TO_BOXED.getValue(primitive) - - invokeMethodVisitor.invokestatic( - r.internalName, - "valueOf", - Type.getMethodDescriptor(r, primitive), - false - ) - } - - /** - * Unboxes the current boxed value and pushes it. - */ - private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual( - NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(primitive), - Type.getMethodDescriptor(primitive), - false - ) - /** * Loads [java.lang.Object] constant from constants. */ - private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { - val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - loadThis() + fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { + val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex + invokeMethodVisitor.load(0, classType) getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) iconst(idx) visitInsn(AALOAD) - checkcast(type) + if (type != OBJECT_TYPE) checkcast(type) } - internal fun loadNumeric(value: Number) { - if (expectationStack.peek() == NUMBER_TYPE) { - loadNumberConstant(value, true) - expectationStack.pop() - typeStack.push(NUMBER_TYPE) - } else (algebra as? NumericAlgebra)?.number(value)?.let { loadTConstant(it) } - ?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.") - } - - /** - * Loads this variable. - */ - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) - /** * Either loads a numeric constant [value] from the class's constants field or boxes a primitive - * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded - * from it). + * constant from the constant pool. */ - private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { + fun loadNumberConstant(value: Number) { val boxed = value.javaClass.asm val primitive = BOXED_TO_PRIMITIVES[boxed] if (primitive != null) { when (primitive) { - Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) - Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) - Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) - Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) } - if (mustBeBoxed) - box(primitive) + val r = PRIMITIVES_TO_BOXED.getValue(primitive) + + invokeMethodVisitor.invokestatic( + r.internalName, + "valueOf", + getMethodDescriptor(r, primitive), + false, + ) return } loadObjectConstant(value, boxed) - - if (!mustBeBoxed) - unboxTo(primitiveMask) } /** - * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be - * provided. + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. */ - internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run { - load(invokeArgumentsVar, MAP_TYPE) + fun loadVariable(name: String): Unit = invokeMethodVisitor.run { + load(1, MAP_TYPE) aconst(name) invokestatic( MAP_INTRINSICS_TYPE.internalName, "getOrFail", - Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), - false + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), + false, ) checkcast(tType) - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } } - /** - * Loads algebra from according field of the class and casts it to class of [algebra] provided. - */ - internal fun loadAlgebra() { - loadThis() - invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) - } + inline fun buildCall(function: Function, parameters: AsmBuilder.() -> Unit) { + contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } + val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces } - /** - * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is - * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be - * called before the arguments and this operation. - * - * The result is casted to [T] automatically. - */ - internal fun invokeAlgebraOperation( - owner: String, - method: String, - descriptor: String, - expectedArity: Int, - opcode: Int = INVOKEINTERFACE, - ) { - run loop@{ - repeat(expectedArity) { - if (typeStack.isEmpty()) return@loop - typeStack.pop() - } - } + val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount + ?: error("Provided function object doesn't contain invoke method") - invokeMethodVisitor.visitMethodInsn( - opcode, - owner, - method, - descriptor, - opcode == INVOKEINTERFACE + val type = getType(`interface`) + loadObjectConstant(function, type) + parameters(this) + + invokeMethodVisitor.invokeinterface( + type.internalName, + "invoke", + getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }), ) invokeMethodVisitor.checkcast(tType) - val isLastExpr = expectationStack.size == 1 - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } } - /** - * Writes a LDC Instruction with string constant provided. - */ - internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) - - internal companion object { - /** - * Index of `this` variable in invoke method of the built subclass. - */ - private const val invokeThisVar: Int = 0 - - /** - * Index of `arguments` variable in invoke method of the built subclass. - */ - private const val invokeArgumentsVar: Int = 1 - - /** - * Maps JVM primitive numbers boxed types to their primitive ASM types. - */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { - hashMapOf( - java.lang.Byte::class.java to Type.BYTE_TYPE, - java.lang.Short::class.java to Type.SHORT_TYPE, - java.lang.Integer::class.java to Type.INT_TYPE, - java.lang.Long::class.java to Type.LONG_TYPE, - java.lang.Float::class.java to Type.FLOAT_TYPE, - java.lang.Double::class.java to Type.DOUBLE_TYPE - ) - } - + companion object { /** * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. */ - private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + private val BOXED_TO_PRIMITIVES: Map by lazy { + hashMapOf( + Byte::class.java.asm to BYTE_TYPE, + Short::class.java.asm to SHORT_TYPE, + Integer::class.java.asm to INT_TYPE, + Long::class.java.asm to LONG_TYPE, + Float::class.java.asm to FLOAT_TYPE, + Double::class.java.asm to DOUBLE_TYPE, + ) + } /** * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. */ private val PRIMITIVES_TO_BOXED: Map by lazy { BOXED_TO_PRIMITIVES.entries.stream().collect( - Collectors.toMap( - Map.Entry::value, - Map.Entry::key - ) + toMap(Map.Entry::value, Map.Entry::key), ) } - /** - * Maps primitive ASM types to [Number] functions unboxing them. - */ - private val NUMBER_CONVERTER_METHODS: Map by lazy { - hashMapOf( - Type.BYTE_TYPE to "byteValue", - Type.SHORT_TYPE to "shortValue", - Type.INT_TYPE to "intValue", - Type.LONG_TYPE to "longValue", - Type.FLOAT_TYPE to "floatValue", - Type.DOUBLE_TYPE to "doubleValue" - ) - } - - /** - * Provides boxed number types values of which can be stored in JVM bytecode constant pool. - */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - /** * ASM type for [Expression]. */ - internal val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/expressions/Expression") } - - /** - * ASM type for [java.lang.Number]. - */ - internal val NUMBER_TYPE: Type by lazy { Type.getObjectType("java/lang/Number") } + val EXPRESSION_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Expression") } /** * ASM type for [java.util.Map]. */ - internal val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") } + val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") } /** * ASM type for [java.lang.Object]. */ - internal val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") } + val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") } /** * ASM type for array of [java.lang.Object]. */ - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type by lazy { Type.getType("[Ljava/lang/Object;") } - - /** - * ASM type for [Algebra]. - */ - internal val ALGEBRA_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/operations/Algebra") } + val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } /** * ASM type for [java.lang.String]. */ - internal val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") } + val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") } /** * ASM type for MapIntrinsics. */ - internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("kscience/kmath/asm/internal/MapIntrinsics") } + val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("kscience/kmath/asm/internal/MapIntrinsics") } + + /** + * ASM Type for [kscience.kmath.expressions.Symbol]. + */ + val SYMBOL_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Symbol") } } } diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt deleted file mode 100644 index 526c27305..000000000 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt +++ /dev/null @@ -1,20 +0,0 @@ -package kscience.kmath.asm.internal - -import kscience.kmath.ast.MST - -/** - * Represents types known in [MST], numbers and general values. - */ -internal enum class MstType { - GENERAL, - NUMBER; - - companion object { - fun fromMst(mst: MST): MstType { - if (mst is MST.Numeric) - return NUMBER - - return GENERAL - } - } -} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt index ef9751502..6d5d19d42 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt @@ -2,29 +2,11 @@ package kscience.kmath.asm.internal import kscience.kmath.ast.MST import kscience.kmath.expressions.Expression -import kscience.kmath.operations.Algebra -import kscience.kmath.operations.FieldOperations -import kscience.kmath.operations.RingOperations -import kscience.kmath.operations.SpaceOperations import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.INVOKEVIRTUAL import org.objectweb.asm.commons.InstructionAdapter -import java.lang.reflect.Method -import java.util.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract -private val methodNameAdapters: Map, String> by lazy { - hashMapOf( - SpaceOperations.PLUS_OPERATION to 2 to "add", - RingOperations.TIMES_OPERATION to 2 to "multiply", - FieldOperations.DIV_OPERATION to 2 to "divide", - SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus", - SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus", - SpaceOperations.MINUS_OPERATION to 2 to "minus" - ) -} - /** * Returns ASM [Type] for given [Class]. * @@ -109,107 +91,3 @@ internal inline fun ClassWriter.visitField( contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return visitField(access, name, descriptor, signature, value).apply(block) } - -private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = - context.javaClass.methods.find { method -> - val nameValid = method.name == name - val arityValid = method.parameters.size == parameterTypes.size - val notBridgeInPrimitive = !(primitiveMode && method.isBridge) - - val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) -> - !(mstType != MstType.NUMBER && type == java.lang.Number::class.java) - } - - nameValid && arityValid && notBridgeInPrimitive && paramsValid - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds - * type expectation stack for needed arity. - * - * @author Iaroslav Postovalov - */ -private fun AsmBuilder.buildExpectationStack( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes) - - if (specific != null) - mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) } - else - expectationStack.addAll(Collections.nCopies(arity, tType)) - - return specific != null -} - -private fun AsmBuilder.mapTypes(method: Method, parameterTypes: Array): List = method - .parameterTypes - .zip(parameterTypes) - .map { (type, mstType) -> - when { - type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE - else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed - } - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts - * [AsmBuilder.invokeAlgebraOperation] of this method. - * - * @author Iaroslav Postovalov - */ -private fun AsmBuilder.tryInvokeSpecific( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val theName = methodNameAdapters[name to arity] ?: name - val spec = findSpecific(context, theName, parameterTypes) ?: return false - val owner = context.javaClass.asm - - invokeAlgebraOperation( - owner = owner.internalName, - method = theName, - descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()), - expectedArity = arity, - opcode = INVOKEVIRTUAL - ) - - return true -} - -/** - * Builds specialized [context] call with option to fallback to generic algebra operation accepting [String]. - * - * @author Iaroslav Postovalov - */ -internal inline fun AsmBuilder.buildAlgebraOperationCall( - context: Algebra, - name: String, - fallbackMethodName: String, - parameterTypes: Array, - parameters: AsmBuilder.() -> Unit -) { - contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } - val arity = parameterTypes.size - loadAlgebra() - if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) - parameters() - - if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = fallbackMethodName, - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - *Array(arity) { AsmBuilder.OBJECT_TYPE } - ), - - expectedArity = arity - ) -} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt index 79cf77a6a..0b66e2c31 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt @@ -1,3 +1,5 @@ +// TODO move to common when https://github.com/h0tk3y/better-parse/pull/33 is merged + package kscience.kmath.ast import com.github.h0tk3y.betterParse.combinators.* @@ -17,7 +19,8 @@ import kscience.kmath.operations.RingOperations import kscience.kmath.operations.SpaceOperations /** - * TODO move to common after IR version is released + * better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4. + * * @author Alexander Nozik and Iaroslav Postovalov */ public object ArithmeticsEvaluator : Grammar() { @@ -83,7 +86,7 @@ public object ArithmeticsEvaluator : Grammar() { } /** - * Tries to parse the string into [MST]. Returns [ParseResult] representing expression or error. + * Tries to parse the string into [MST] using [ArithmeticsEvaluator]. Returns [ParseResult] representing expression or error. * * @receiver the string to parse. * @return the [MST] node. @@ -91,7 +94,7 @@ public object ArithmeticsEvaluator : Grammar() { public fun String.tryParseMath(): ParseResult = ArithmeticsEvaluator.tryParseToEnd(this) /** - * Parses the string into [MST]. + * Parses the string into [MST] using [ArithmeticsEvaluator]. * * @receiver the string to parse. * @return the [MST] node. diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt similarity index 56% rename from kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt rename to kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt index 5eebfe43d..ae180bf3f 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt @@ -1,24 +1,20 @@ package kscience.kmath.asm -import kscience.kmath.ast.mstInField -import kscience.kmath.ast.mstInRing -import kscience.kmath.ast.mstInSpace +import kscience.kmath.ast.* import kscience.kmath.expressions.invoke import kscience.kmath.operations.ByteRing +import kscience.kmath.operations.ComplexField import kscience.kmath.operations.RealField +import kscience.kmath.operations.toComplex import kotlin.test.Test import kotlin.test.assertEquals -internal class TestAsmAlgebras { - +internal class TestAsmConsistencyWithInterpreter { @Test - fun space() { - val res1 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + fun mstSpace() { + val res1 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -27,14 +23,11 @@ internal class TestAsmAlgebras { number(1) ) + symbol("x") + zero - }("x" to 2.toByte()) + }("x" to MST.Numeric(2)) - val res2 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + val res2 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -43,19 +36,16 @@ internal class TestAsmAlgebras { number(1) ) + symbol("x") + zero - }.compile()("x" to 2.toByte()) + }.compile()("x" to MST.Numeric(2)) assertEquals(res1, res2) } @Test - fun ring() { + fun byteRing() { val res1 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperationFunction("+")( + unaryOperationFunction("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 @@ -67,17 +57,13 @@ internal class TestAsmAlgebras { }("x" to 3.toByte()) val res2 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperationFunction("+")( + unaryOperationFunction("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 ) + 1.toByte()))) * 3.0 - 1.toByte() ), - number(1) ) * number(2) }.compile()("x" to 3.toByte()) @@ -86,10 +72,9 @@ internal class TestAsmAlgebras { } @Test - fun field() { + fun realField() { val res1 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one @@ -97,8 +82,7 @@ internal class TestAsmAlgebras { }("x" to 2.0) val res2 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one @@ -107,4 +91,25 @@ internal class TestAsmAlgebras { assertEquals(res1, res2) } + + @Test + fun complexField() { + val res1 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0.toComplex()) + + val res2 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0.toComplex()) + + assertEquals(res1, res2) + } } diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmOperationsSupport.kt similarity index 67% rename from kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt rename to kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmOperationsSupport.kt index 7f4d453f7..2ce52aa87 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmOperationsSupport.kt @@ -1,14 +1,15 @@ package kscience.kmath.asm -import kscience.kmath.asm.compile +import kscience.kmath.ast.mstInExtendedField import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInSpace import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField +import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals -internal class TestAsmExpressions { +internal class TestAsmOperationsSupport { @Test fun testUnaryOperationInvocation() { val expression = RealField.mstInSpace { -symbol("x") }.compile() @@ -28,4 +29,13 @@ internal class TestAsmExpressions { val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) assertEquals(4.0, res) } + + @Test + fun testMultipleCalls() { + val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile() + val r = Random(0) + var s = 0.0 + repeat(1000000) { s += e("x" to r.nextDouble()) } + println(s) + } } diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt index 8f8175acd..602c54651 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt @@ -1,6 +1,5 @@ package kscience.kmath.asm -import kscience.kmath.asm.compile import kscience.kmath.ast.mstInField import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField @@ -10,44 +9,44 @@ import kotlin.test.assertEquals internal class TestAsmSpecialization { @Test fun testUnaryPlus() { - val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperationFunction("+")(symbol("x")) }.compile() assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { - val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperationFunction("-")(symbol("x")) }.compile() assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { - val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperationFunction("+")(symbol("x"), symbol("x")) }.compile() assertEquals(4.0, expr("x" to 2.0)) } @Test fun testSine() { - val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() + val expr = RealField.mstInField { unaryOperationFunction("sin")(symbol("x")) }.compile() assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { - val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperationFunction("-")(symbol("x"), symbol("x")) }.compile() assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { - val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + val expr = RealField.mstInField { binaryOperationFunction("/")(symbol("x"), symbol("x")) }.compile() assertEquals(1.0, expr("x" to 2.0)) } @Test fun testPower() { val expr = RealField - .mstInField { binaryOperation("power", symbol("x"), number(2)) } + .mstInField { binaryOperationFunction("pow")(symbol("x"), number(2)) } .compile() assertEquals(4.0, expr("x" to 2.0)) diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt index 7b89c74fa..c91568dbf 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt @@ -9,14 +9,14 @@ import kotlin.test.assertFailsWith internal class TestAsmVariables { @Test - fun testVariableWithoutDefault() { - val expr = ByteRing.mstInRing { symbol("x") } + fun testVariable() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() assertEquals(1.toByte(), expr("x" to 1.toByte())) } @Test - fun testVariableWithoutDefaultFails() { - val expr = ByteRing.mstInRing { symbol("x") } - assertFailsWith { expr() } + fun testUndefinedVariableFails() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() + assertFailsWith { expr() } } } diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/AsmTest.kt deleted file mode 100644 index 1ca8aa5dd..000000000 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/AsmTest.kt +++ /dev/null @@ -1,25 +0,0 @@ -package kscience.kmath.ast - -import kscience.kmath.asm.compile -import kscience.kmath.asm.expression -import kscience.kmath.ast.mstInField -import kscience.kmath.ast.parseMath -import kscience.kmath.expressions.invoke -import kscience.kmath.operations.Complex -import kscience.kmath.operations.ComplexField -import kotlin.test.Test -import kotlin.test.assertEquals - -internal class AsmTest { - @Test - fun `compile MST`() { - val res = ComplexField.expression("2+2*(2+2)".parseMath())() - assertEquals(Complex(10.0, 0.0), res) - } - - @Test - fun `compile MSTExpression`() { - val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()() - assertEquals(Complex(10.0, 0.0), res) - } -} diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserPrecedenceTest.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserPrecedenceTest.kt index 1844a0991..561fe51bd 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserPrecedenceTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserPrecedenceTest.kt @@ -1,7 +1,5 @@ package kscience.kmath.ast -import kscience.kmath.ast.evaluate -import kscience.kmath.ast.parseMath import kscience.kmath.operations.Field import kscience.kmath.operations.RealField import kotlin.test.Test diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt index 2dc24597e..3aa5392c8 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt @@ -1,8 +1,5 @@ package kscience.kmath.ast -import kscience.kmath.ast.evaluate -import kscience.kmath.ast.mstInField -import kscience.kmath.ast.parseMath import kscience.kmath.expressions.invoke import kscience.kmath.operations.Algebra import kscience.kmath.operations.Complex @@ -45,12 +42,15 @@ internal class ParserTest { val magicalAlgebra = object : Algebra { override fun symbol(value: String): String = value - override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() - - override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) { - "magic" -> "$left ★ $right" - else -> throw NotImplementedError() + override fun unaryOperationFunction(operation: String): (arg: String) -> String { + throw NotImplementedError() } + + override fun binaryOperationFunction(operation: String): (left: String, right: String) -> String = + when (operation) { + "magic" -> { left, right -> "$left ★ $right" } + else -> throw NotImplementedError() + } } val mst = "magic(a, b)".parseMath() diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 345babe8b..2912ddc4c 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -1,7 +1,9 @@ package kscience.kmath.commons.expressions import kscience.kmath.expressions.* +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.ExtendedField +import kscience.kmath.operations.RingWithNumbers import org.apache.commons.math3.analysis.differentiation.DerivativeStructure /** @@ -10,15 +12,18 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure * @property order The derivation order. * @property bindings The map of bindings values. All bindings are considered free parameters */ +@OptIn(UnstableKMathAPI::class) public class DerivativeStructureField( public val order: Int, bindings: Map, -) : ExtendedField, ExpressionAlgebra { +) : ExtendedField, ExpressionAlgebra, RingWithNumbers { public val numberOfVariables: Int = bindings.size public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } + override fun number(value: Number): DerivativeStructure = const(value.toDouble()) + /** * A class that implements both [DerivativeStructure] and a [Symbol] */ diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt index 712927400..850446afa 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -1,41 +1,28 @@ package kscience.kmath.commons.linear -import kscience.kmath.linear.* +import kscience.kmath.linear.DiagonalFeature +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.Point +import kscience.kmath.linear.origin +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix -import kscience.kmath.structures.NDStructure import org.apache.commons.math3.linear.* +import kotlin.reflect.KClass +import kotlin.reflect.cast -public class CMMatrix(public val origin: RealMatrix, features: Set? = null) : FeaturedMatrix { +public inline class CMMatrix(public val origin: RealMatrix) : Matrix { public override val rowNum: Int get() = origin.rowDimension public override val colNum: Int get() = origin.columnDimension - public override val features: Set = features ?: sequence { - if (origin is DiagonalMatrix) yield(DiagonalFeature) - }.toHashSet() - - public override fun suggestFeature(vararg features: MatrixFeature): CMMatrix = - CMMatrix(origin, this.features + features) + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = when (type) { + DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null + else -> null + }?.let { type.cast(it) } public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) - - public override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) - } - - public override fun hashCode(): Int { - var result = origin.hashCode() - result = 31 * result + features.hashCode() - return result - } } -public fun Matrix.toCM(): CMMatrix = if (this is CMMatrix) { - this -} else { - //TODO add feature analysis - val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } - CMMatrix(Array2DRowRealMatrix(array)) -} public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this) @@ -60,6 +47,16 @@ public object CMMatrixContext : MatrixContext { return CMMatrix(Array2DRowRealMatrix(array)) } + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toCM(): CMMatrix = when (val matrix = origin) { + is CMMatrix -> matrix + else -> { + //TODO add feature analysis + val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } + CMMatrix(Array2DRowRealMatrix(array)) + } + } + public override fun Matrix.dot(other: Matrix): CMMatrix = CMMatrix(toCM().origin.multiply(other.toCM().origin)) diff --git a/kmath-core/README.md b/kmath-core/README.md index 3a919e85a..7882e5252 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -26,7 +26,7 @@ The core features of KMath: > maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/hotkeytlt/maven' } - +> > } > > dependencies { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 0630e8e4b..1a3668855 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -25,34 +25,34 @@ public abstract class FunctionalExpressionAlgebra>( /** * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. */ - public override fun binaryOperation( - operation: String, - left: Expression, - right: Expression, - ): Expression = Expression { arguments -> - algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments)) - } + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + { left, right -> + Expression { arguments -> + algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments)) + } + } /** * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. */ - public override fun unaryOperation(operation: String, arg: Expression): Expression = Expression { arguments -> - algebra.unaryOperation(operation, arg.invoke(arguments)) + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg -> + Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) } } } /** * A context class for [Expression] construction for [Space] algebras. */ -public open class FunctionalExpressionSpace>(algebra: A) : - FunctionalExpressionAlgebra(algebra), Space> { +public open class FunctionalExpressionSpace>( + algebra: A, +) : FunctionalExpressionAlgebra(algebra), Space> { public override val zero: Expression get() = const(algebra.zero) /** * Builds an Expression of addition of two another expressions. */ public override fun add(a: Expression, b: Expression): Expression = - binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) + binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b) /** * Builds an Expression of multiplication of expression by number. @@ -66,15 +66,16 @@ public open class FunctionalExpressionSpace>(algebra: A) : public operator fun T.plus(arg: Expression): Expression = arg + this public operator fun T.minus(arg: Expression): Expression = arg - this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) } -public open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { +public open class FunctionalExpressionRing>( + algebra: A, +) : FunctionalExpressionSpace(algebra), Ring> { public override val one: Expression get() = const(algebra.one) @@ -82,68 +83,72 @@ public open class FunctionalExpressionRing(algebra: A) : FunctionalExpress * Builds an Expression of multiplication of two expressions. */ public override fun multiply(a: Expression, b: Expression): Expression = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) + binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) public operator fun Expression.times(arg: T): Expression = this * const(arg) public operator fun T.times(arg: Expression): Expression = arg * this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) } -public open class FunctionalExpressionField(algebra: A) : - FunctionalExpressionRing(algebra), Field> - where A : Field, A : NumericAlgebra { +public open class FunctionalExpressionField>( + algebra: A, +) : FunctionalExpressionRing(algebra), Field> { /** * Builds an Expression of division an expression by another one. */ public override fun divide(a: Expression, b: Expression): Expression = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) + binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) public operator fun Expression.div(arg: T): Expression = this / const(arg) public operator fun T.div(arg: Expression): Expression = arg / this - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) } -public open class FunctionalExpressionExtendedField(algebra: A) : - FunctionalExpressionField(algebra), - ExtendedField> where A : ExtendedField, A : NumericAlgebra { +public open class FunctionalExpressionExtendedField>( + algebra: A, +) : FunctionalExpressionField(algebra), ExtendedField> { + + override fun number(value: Number): Expression = const(algebra.number(value)) + public override fun sin(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) + unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) public override fun cos(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.COS_OPERATION, arg) + unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) public override fun asin(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) + unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg) public override fun acos(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) + unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg) public override fun atan(arg: Expression): Expression = - unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) + unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg) public override fun power(arg: Expression, pow: Number): Expression = - binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) + binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) public override fun exp(arg: Expression): Expression = - unaryOperation(ExponentialOperations.EXP_OPERATION, arg) + unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION, arg) + public override fun ln(arg: Expression): Expression = + unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) - public override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) - public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) } public inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index e8a894d23..0621e82bd 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -1,6 +1,7 @@ package kscience.kmath.expressions import kscience.kmath.linear.Point +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* import kscience.kmath.structures.asBuffer import kotlin.contracts.InvocationKind @@ -79,10 +80,11 @@ public fun > F.simpleAutoDiff( /** * Represents field in context of which functions can be derived. */ +@OptIn(UnstableKMathAPI::class) public open class SimpleAutoDiffField>( public val context: F, bindings: Map, -) : Field>, ExpressionAlgebra> { +) : Field>, ExpressionAlgebra>, RingWithNumbers> { public override val zero: AutoDiffValue get() = const(context.zero) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt index 8b50bbe33..a74d948fc 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt @@ -1,6 +1,5 @@ package kscience.kmath.linear -import kscience.kmath.operations.RealField import kscience.kmath.operations.Ring import kscience.kmath.structures.* @@ -21,30 +20,11 @@ public class BufferMatrixContext>( public companion object } -@Suppress("OVERRIDE_BY_INLINE") -public object RealMatrixContext : GenericMatrixContext> { - public override val elementContext: RealField - get() = RealField - - public override inline fun produce( - rows: Int, - columns: Int, - initializer: (i: Int, j: Int) -> Double, - ): BufferMatrix { - val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } - return BufferMatrix(rows, columns, buffer) - } - - public override inline fun point(size: Int, initializer: (Int) -> Double): Point = - RealBuffer(size, initializer) -} - public class BufferMatrix( public override val rowNum: Int, public override val colNum: Int, public val buffer: Buffer, - public override val features: Set = emptySet(), -) : FeaturedMatrix { +) : Matrix { init { require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" } @@ -52,9 +32,6 @@ public class BufferMatrix( override val shape: IntArray get() = intArrayOf(rowNum, colNum) - public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix = - BufferMatrix(rowNum, colNum, buffer, this.features + features) - public override operator fun get(index: IntArray): T = get(index[0], index[1]) public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] @@ -66,23 +43,26 @@ public class BufferMatrix( if (this === other) return true return when (other) { - is NDStructure<*> -> return NDStructure.equals(this, other) + is NDStructure<*> -> NDStructure.contentEquals(this, other) else -> false } } - public override fun hashCode(): Int { - var result = buffer.hashCode() - result = 31 * result + features.hashCode() + override fun hashCode(): Int { + var result = rowNum + result = 31 * result + colNum + result = 31 * result + buffer.hashCode() return result } public override fun toString(): String { return if (rowNum <= 5 && colNum <= 5) - "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + + "Matrix(rowsNum = $rowNum, colNum = $colNum)\n" + rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer -> buffer.asSequence().joinToString(separator = "\t") { it.toString() } } - else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" + else "Matrix(rowsNum = $rowNum, colNum = $colNum)" } + + } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt deleted file mode 100644 index 68272203c..000000000 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt +++ /dev/null @@ -1,83 +0,0 @@ -package kscience.kmath.linear - -import kscience.kmath.operations.Ring -import kscience.kmath.structures.Matrix -import kscience.kmath.structures.Structure2D -import kscience.kmath.structures.asBuffer -import kotlin.math.sqrt - -/** - * A 2d structure plus optional matrix-specific features - */ -public interface FeaturedMatrix : Matrix { - override val shape: IntArray get() = intArrayOf(rowNum, colNum) - public val features: Set - - /** - * Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure. - * - * The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to - * add only those features that are valid. - */ - public fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix - - public companion object -} - -public inline fun Structure2D.Companion.real( - rows: Int, - columns: Int, - initializer: (Int, Int) -> Double, -): BufferMatrix = MatrixContext.real.produce(rows, columns, initializer) - -/** - * Build a square matrix from given elements. - */ -public fun Structure2D.Companion.square(vararg elements: T): FeaturedMatrix { - val size: Int = sqrt(elements.size.toDouble()).toInt() - require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } - val buffer = elements.asBuffer() - return BufferMatrix(size, size, buffer) -} - -public val Matrix<*>.features: Set get() = (this as? FeaturedMatrix)?.features ?: emptySet() - -/** - * Check if matrix has the given feature class - */ -public inline fun Matrix<*>.hasFeature(): Boolean = - features.find { it is T } != null - -/** - * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria - */ -public inline fun Matrix<*>.getFeature(): T? = - features.filterIsInstance().firstOrNull() - -/** - * Diagonal matrix of ones. The matrix is virtual no actual matrix is created - */ -public fun > GenericMatrixContext.one(rows: Int, columns: Int): FeaturedMatrix = - VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> - if (i == j) elementContext.one else elementContext.zero - } - - -/** - * A virtual matrix of zeroes - */ -public fun > GenericMatrixContext.zero(rows: Int, columns: Int): FeaturedMatrix = - VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } - -public class TransposedFeature(public val original: Matrix) : MatrixFeature - -/** - * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` - */ -public fun Matrix.transpose(): Matrix { - return getFeature>()?.original ?: VirtualMatrix( - colNum, - rowNum, - setOf(TransposedFeature(this)) - ) { i, j -> get(j, i) } -} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt similarity index 78% rename from kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt index 099fa1909..5cf7c8f70 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt @@ -1,31 +1,31 @@ package kscience.kmath.linear +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* import kscience.kmath.structures.* /** - * Common implementation of [LUPDecompositionFeature] + * Common implementation of [LupDecompositionFeature]. */ -public class LUPDecomposition( - public val context: MatrixContext>, +public class LupDecomposition( + public val context: MatrixContext>, public val elementContext: Field, - public val lu: Structure2D, + public val lu: Matrix, public val pivot: IntArray, private val even: Boolean, -) : LUPDecompositionFeature, DeterminantFeature { - +) : LupDecompositionFeature, DeterminantFeature { /** * Returns the matrix L of the decomposition. * * L is a lower-triangular matrix with [Ring.one] in diagonal */ - override val l: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> + override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> when { j < i -> lu[i, j] j == i -> elementContext.one else -> elementContext.zero } - } + } + LFeature /** @@ -33,9 +33,9 @@ public class LUPDecomposition( * * U is an upper-triangular matrix including the diagonal */ - override val u: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> + override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j >= i) lu[i, j] else elementContext.zero - } + } + UFeature /** * Returns the P rows permutation matrix. @@ -43,7 +43,7 @@ public class LUPDecomposition( * P is a sparse matrix with exactly one element set to [Ring.one] in * each row and each column, all other elements being set to [Ring.zero]. */ - override val p: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j == pivot[i]) elementContext.one else elementContext.zero } @@ -64,12 +64,12 @@ internal fun , F : Field> GenericMatrixContext.abs /** * Create a lup decomposition of generic matrix. */ -public fun > MatrixContext>.lup( +public fun > MatrixContext>.lup( factory: MutableBufferFactory, elementContext: Field, matrix: Matrix, checkSingular: (T) -> Boolean, -): LUPDecomposition { +): LupDecomposition { require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" } val m = matrix.colNum val pivot = IntArray(matrix.rowNum) @@ -138,20 +138,23 @@ public fun > MatrixContext>.lup( for (row in col + 1 until m) lu[row, col] /= luDiag } - return LUPDecomposition(this@lup, elementContext, lu.collect(), pivot, even) + return LupDecomposition(this@lup, elementContext, lu.collect(), pivot, even) } } } -public inline fun , F : Field> GenericMatrixContext>.lup( +public inline fun , F : Field> GenericMatrixContext>.lup( matrix: Matrix, noinline checkSingular: (T) -> Boolean, -): LUPDecomposition = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular) +): LupDecomposition = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular) -public fun MatrixContext>.lup(matrix: Matrix): LUPDecomposition = +public fun MatrixContext>.lup(matrix: Matrix): LupDecomposition = lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 } -public fun LUPDecomposition.solveWithLUP(factory: MutableBufferFactory, matrix: Matrix): FeaturedMatrix { +public fun LupDecomposition.solveWithLUP( + factory: MutableBufferFactory, + matrix: Matrix, +): Matrix { require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run { @@ -196,34 +199,41 @@ public fun LUPDecomposition.solveWithLUP(factory: MutableBufferFact } } -public inline fun LUPDecomposition.solveWithLUP(matrix: Matrix): Matrix = +public inline fun LupDecomposition.solveWithLUP(matrix: Matrix): Matrix = solveWithLUP(MutableBuffer.Companion::auto, matrix) /** * Solve a linear equation **a*x = b** using LUP decomposition */ -public inline fun , F : Field> GenericMatrixContext>.solveWithLUP( +@OptIn(UnstableKMathAPI::class) +public inline fun , F : Field> GenericMatrixContext>.solveWithLUP( a: Matrix, b: Matrix, noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, noinline checkSingular: (T) -> Boolean, -): FeaturedMatrix { +): Matrix { // Use existing decomposition if it is provided by matrix val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular) return decomposition.solveWithLUP(bufferFactory, b) } -public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): FeaturedMatrix = - solveWithLUP(a, b) { it < 1e-11 } - -public inline fun , F : Field> GenericMatrixContext>.inverseWithLUP( +public inline fun , F : Field> GenericMatrixContext>.inverseWithLUP( matrix: Matrix, noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, noinline checkSingular: (T) -> Boolean, -): FeaturedMatrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) +): Matrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) + + +@OptIn(UnstableKMathAPI::class) +public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): Matrix { + // Use existing decomposition if it is provided by matrix + val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real + val decomposition: LupDecomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 } + return decomposition.solveWithLUP(bufferFactory, b) +} /** * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. */ -public fun RealMatrixContext.inverseWithLUP(matrix: Matrix): FeaturedMatrix = - solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), Buffer.Companion::real) { it < 1e-11 } +public fun RealMatrixContext.inverseWithLUP(matrix: Matrix): Matrix = + solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum)) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt index 91c1ec824..c0c209248 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt @@ -1,12 +1,9 @@ package kscience.kmath.linear -import kscience.kmath.structures.Buffer -import kscience.kmath.structures.BufferFactory -import kscience.kmath.structures.Structure2D -import kscience.kmath.structures.asBuffer +import kscience.kmath.structures.* public class MatrixBuilder(public val rows: Int, public val columns: Int) { - public operator fun invoke(vararg elements: T): FeaturedMatrix { + public operator fun invoke(vararg elements: T): Matrix { require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } val buffer = elements.asBuffer() return BufferMatrix(rows, columns, buffer) @@ -17,7 +14,7 @@ public class MatrixBuilder(public val rows: Int, public val columns: Int) { public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) -public fun Structure2D.Companion.row(vararg values: T): FeaturedMatrix { +public fun Structure2D.Companion.row(vararg values: T): Matrix { val buffer = values.asBuffer() return BufferMatrix(1, values.size, buffer) } @@ -26,12 +23,12 @@ public inline fun Structure2D.Companion.row( size: Int, factory: BufferFactory = Buffer.Companion::auto, noinline builder: (Int) -> T -): FeaturedMatrix { +): Matrix { val buffer = factory(size, builder) return BufferMatrix(1, size, buffer) } -public fun Structure2D.Companion.column(vararg values: T): FeaturedMatrix { +public fun Structure2D.Companion.column(vararg values: T): Matrix { val buffer = values.asBuffer() return BufferMatrix(values.size, 1, buffer) } @@ -40,7 +37,7 @@ public inline fun Structure2D.Companion.column( size: Int, factory: BufferFactory = Buffer.Companion::auto, noinline builder: (Int) -> T -): FeaturedMatrix { +): Matrix { val buffer = factory(size, builder) return BufferMatrix(size, 1, buffer) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index d9dc57b0f..59a41f840 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -18,11 +18,17 @@ public interface MatrixContext> : SpaceOperations T): M + /** + * Produce a point compatible with matrix space (and possibly optimized for it) + */ + public fun point(size: Int, initializer: (Int) -> T): Point = Buffer.boxing(size, initializer) + @Suppress("UNCHECKED_CAST") - public override fun binaryOperation(operation: String, left: Matrix, right: Matrix): M = when (operation) { - "dot" -> left dot right - else -> super.binaryOperation(operation, left, right) as M - } + public override fun binaryOperationFunction(operation: String): (left: Matrix, right: Matrix) -> M = + when (operation) { + "dot" -> { left, right -> left dot right } + else -> super.binaryOperationFunction(operation) as (Matrix, Matrix) -> M + } /** * Computes the dot product of this matrix and another one. @@ -61,10 +67,6 @@ public interface MatrixContext> : SpaceOperations): M = m * this public companion object { - /** - * Non-boxing double matrix - */ - public val real: RealMatrixContext = RealMatrixContext /** * A structured matrix with custom buffer @@ -88,11 +90,6 @@ public interface GenericMatrixContext, out M : Matrix> : */ public val elementContext: R - /** - * Produce a point compatible with matrix space - */ - public fun point(size: Int, initializer: (Int) -> T): Point - public override infix fun Matrix.dot(other: Matrix): M { //TODO add typed error require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } @@ -136,8 +133,6 @@ public interface GenericMatrixContext, out M : Matrix> : public override fun multiply(a: Matrix, k: Number): M = produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } - public operator fun Number.times(matrix: FeaturedMatrix): M = multiply(matrix, this) - public override operator fun Matrix.times(value: T): M = produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt index a82032e50..e61feec6c 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt @@ -1,62 +1,158 @@ package kscience.kmath.linear +import kscience.kmath.structures.Matrix + /** - * A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix - * operations performance in some cases. + * A marker interface representing some properties of matrices or additional transformations of them. Features are used + * to optimize matrix operations performance in some cases or retrieve the APIs. */ public interface MatrixFeature /** - * The matrix with this feature is considered to have only diagonal non-null elements + * Matrices with this feature are considered to have only diagonal non-null elements. */ -public object DiagonalFeature : MatrixFeature - -/** - * Matrix with this feature has all zero elements - */ -public object ZeroFeature : MatrixFeature - -/** - * Matrix with this feature have unit elements on diagonal and zero elements in all other places - */ -public object UnitFeature : MatrixFeature - -/** - * Inverted matrix feature - */ -public interface InverseMatrixFeature : MatrixFeature { - public val inverse: FeaturedMatrix +public interface DiagonalFeature : MatrixFeature{ + public companion object: DiagonalFeature } /** - * A determinant container + * Matrices with this feature have all zero elements. + */ +public object ZeroFeature : DiagonalFeature + +/** + * Matrices with this feature have unit elements on diagonal and zero elements in all other places. + */ +public object UnitFeature : DiagonalFeature + +/** + * Matrices with this feature can be inverted: [inverse] = `a`-1 where `a` is the owning matrix. + * + * @param T the type of matrices' items. + */ +public interface InverseMatrixFeature : MatrixFeature { + /** + * The inverse matrix of the matrix that owns this feature. + */ + public val inverse: Matrix +} + +/** + * Matrices with this feature can compute their determinant. */ public interface DeterminantFeature : MatrixFeature { + /** + * The determinant of the matrix that owns this feature. + */ public val determinant: T } +/** + * Produces a [DeterminantFeature] where the [DeterminantFeature.determinant] is [determinant]. + * + * @param determinant the value of determinant. + * @return a new [DeterminantFeature]. + */ @Suppress("FunctionName") public fun DeterminantFeature(determinant: T): DeterminantFeature = object : DeterminantFeature { override val determinant: T = determinant } /** - * Lower triangular matrix + * Matrices with this feature are lower triangular ones. */ public object LFeature : MatrixFeature /** - * Upper triangular feature + * Matrices with this feature are upper triangular ones. */ public object UFeature : MatrixFeature /** - * TODO add documentation + * Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where + * *a* is the owning matrix. + * + * @param T the type of matrices' items. */ -public interface LUPDecompositionFeature : MatrixFeature { - public val l: FeaturedMatrix - public val u: FeaturedMatrix - public val p: FeaturedMatrix +public interface LupDecompositionFeature : MatrixFeature { + /** + * The lower triangular matrix in this decomposition. It may have [LFeature]. + */ + public val l: Matrix + + /** + * The upper triangular matrix in this decomposition. It may have [UFeature]. + */ + public val u: Matrix + + /** + * The permutation matrix in this decomposition. + */ + public val p: Matrix +} + +/** + * Matrices with this feature are orthogonal ones: *a · aT = u* where *a* is the owning matrix, *u* + * is the unit matrix ([UnitFeature]). + */ +public object OrthogonalFeature : MatrixFeature + +/** + * Matrices with this feature support QR factorization: *a = [q] · [r]* where *a* is the owning matrix. + * + * @param T the type of matrices' items. + */ +public interface QRDecompositionFeature : MatrixFeature { + /** + * The orthogonal matrix in this decomposition. It may have [OrthogonalFeature]. + */ + public val q: Matrix + + /** + * The upper triangular matrix in this decomposition. It may have [UFeature]. + */ + public val r: Matrix +} + +/** + * Matrices with this feature support Cholesky factorization: *a = [l] · [l]H* where *a* is the + * owning matrix. + * + * @param T the type of matrices' items. + */ +public interface CholeskyDecompositionFeature : MatrixFeature { + /** + * The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature]. + */ + public val l: Matrix +} + +/** + * Matrices with this feature support SVD: *a = [u] · [s] · [v]H* where *a* is the owning + * matrix. + * + * @param T the type of matrices' items. + */ +public interface SingularValueDecompositionFeature : MatrixFeature { + /** + * The matrix in this decomposition. It is unitary, and it consists from left singular vectors. + */ + public val u: Matrix + + /** + * The matrix in this decomposition. Its main diagonal elements are singular values. + */ + public val s: Matrix + + /** + * The matrix in this decomposition. It is unitary, and it consists from right singular vectors. + */ + public val v: Matrix + + /** + * The buffer of singular values of this SVD. + */ + public val singularValues: Point } //TODO add sparse matrix feature diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt new file mode 100644 index 000000000..362db1fe7 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt @@ -0,0 +1,105 @@ +package kscience.kmath.linear + +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.Ring +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.Structure2D +import kscience.kmath.structures.asBuffer +import kscience.kmath.structures.getFeature +import kotlin.math.sqrt +import kotlin.reflect.KClass +import kotlin.reflect.safeCast + +/** + * A [Matrix] that holds [MatrixFeature] objects. + * + * @param T the type of items. + */ +public class MatrixWrapper internal constructor( + public val origin: Matrix, + public val features: Set, +) : Matrix by origin { + + /** + * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria + */ + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = type.safeCast(features.find { type.isInstance(it) }) + ?: origin.getFeature(type) + + override fun equals(other: Any?): Boolean = origin == other + override fun hashCode(): Int = origin.hashCode() + override fun toString(): String { + return "MatrixWrapper(matrix=$origin, features=$features)" + } +} + +/** + * Return the original matrix. If this is a wrapper, return its origin. If not, this matrix. + * Origin does not necessary store all features. + */ +@UnstableKMathAPI +public val Matrix.origin: Matrix get() = (this as? MatrixWrapper)?.origin ?: this + +/** + * Add a single feature to a [Matrix] + */ +public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { + MatrixWrapper(origin, features + newFeature) +} else { + MatrixWrapper(this, setOf(newFeature)) +} + +/** + * Add a collection of features to a [Matrix] + */ +public operator fun Matrix.plus(newFeatures: Collection): MatrixWrapper = + if (this is MatrixWrapper) { + MatrixWrapper(origin, features + newFeatures) + } else { + MatrixWrapper(this, newFeatures.toSet()) + } + +public inline fun Structure2D.Companion.real( + rows: Int, + columns: Int, + initializer: (Int, Int) -> Double, +): BufferMatrix = MatrixContext.real.produce(rows, columns, initializer) + +/** + * Build a square matrix from given elements. + */ +public fun Structure2D.Companion.square(vararg elements: T): Matrix { + val size: Int = sqrt(elements.size.toDouble()).toInt() + require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } + val buffer = elements.asBuffer() + return BufferMatrix(size, size, buffer) +} + +/** + * Diagonal matrix of ones. The matrix is virtual no actual matrix is created + */ +public fun > GenericMatrixContext.one(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { i, j -> + if (i == j) elementContext.one else elementContext.zero + } + UnitFeature + + +/** + * A virtual matrix of zeroes + */ +public fun > GenericMatrixContext.zero(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + ZeroFeature + +public class TransposedFeature(public val original: Matrix) : MatrixFeature + +/** + * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` + */ +@OptIn(UnstableKMathAPI::class) +public fun Matrix.transpose(): Matrix { + return getFeature>()?.original ?: VirtualMatrix( + colNum, + rowNum, + ) { i, j -> get(j, i) } + TransposedFeature(this) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt new file mode 100644 index 000000000..8e197672f --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt @@ -0,0 +1,68 @@ +package kscience.kmath.linear + +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.RealBuffer + +@Suppress("OVERRIDE_BY_INLINE") +public object RealMatrixContext : MatrixContext> { + + public override inline fun produce( + rows: Int, + columns: Int, + initializer: (i: Int, j: Int) -> Double, + ): BufferMatrix { + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + return BufferMatrix(rows, columns, buffer) + } + + private fun Matrix.wrap(): BufferMatrix = if (this is BufferMatrix) this else { + produce(rowNum, colNum) { i, j -> get(i, j) } + } + + public fun one(rows: Int, columns: Int): Matrix = VirtualMatrix(rows, columns) { i, j -> + if (i == j) 1.0 else 0.0 + } + DiagonalFeature + + public override infix fun Matrix.dot(other: Matrix): BufferMatrix { + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + return produce(rowNum, other.colNum) { i, j -> + var res = 0.0 + for (l in 0 until colNum) { + res += get(i, l) * other.get(l, j) + } + res + } + } + + public override infix fun Matrix.dot(vector: Point): Point { + require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } + return RealBuffer(rowNum) { i -> + var res = 0.0 + for (j in 0 until colNum) { + res += get(i, j) * vector[j] + } + res + } + } + + override fun add(a: Matrix, b: Matrix): BufferMatrix { + require(a.rowNum == b.rowNum) { "Row number mismatch in matrix addition. Left side: ${a.rowNum}, right side: ${b.rowNum}" } + require(a.colNum == b.colNum) { "Column number mismatch in matrix addition. Left side: ${a.colNum}, right side: ${b.colNum}" } + return produce(a.rowNum, a.colNum) { i, j -> + a[i, j] + b[i, j] + } + } + + override fun Matrix.times(value: Double): BufferMatrix = + produce(rowNum, colNum) { i, j -> get(i, j) * value } + + + override fun multiply(a: Matrix, k: Number): BufferMatrix = + produce(a.rowNum, a.colNum) { i, j -> a[i, j] * k.toDouble() } +} + + +/** + * Partially optimized real-valued matrix + */ +public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt index e0a1d0026..0269a64d1 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt @@ -5,31 +5,16 @@ import kscience.kmath.structures.Matrix public class VirtualMatrix( override val rowNum: Int, override val colNum: Int, - override val features: Set = emptySet(), public val generator: (i: Int, j: Int) -> T -) : FeaturedMatrix { - public constructor( - rowNum: Int, - colNum: Int, - vararg features: MatrixFeature, - generator: (i: Int, j: Int) -> T - ) : this( - rowNum, - colNum, - setOf(*features), - generator - ) +) : Matrix { override val shape: IntArray get() = intArrayOf(rowNum, colNum) override operator fun get(i: Int, j: Int): T = generator(i, j) - override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix = - VirtualMatrix(rowNum, colNum, this.features + features, generator) - override fun equals(other: Any?): Boolean { if (this === other) return true - if (other !is FeaturedMatrix<*>) return false + if (other !is Matrix<*>) return false if (rowNum != other.rowNum) return false if (colNum != other.colNum) return false @@ -40,21 +25,9 @@ public class VirtualMatrix( override fun hashCode(): Int { var result = rowNum result = 31 * result + colNum - result = 31 * result + features.hashCode() result = 31 * result + generator.hashCode() return result } - public companion object { - /** - * Wrap a matrix adding additional features to it - */ - public fun wrap(matrix: Matrix, vararg features: MatrixFeature): FeaturedMatrix { - return if (matrix is VirtualMatrix) - VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator) - else - VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> matrix[i, j] } - } - } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 12a45615a..e7eb2770d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -13,50 +13,86 @@ public annotation class KMathContext */ public interface Algebra { /** - * Wrap raw string or variable + * Wraps a raw string to [T] object. This method is designed for three purposes: + * + * 1. Mathematical constants (`e`, `pi`). + * 2. Variables for expression-like contexts (`a`, `b`, `c`...). + * 3. Literals (`{1, 2}`, (`(3; 4)`)). + * + * In case if algebra can't parse the string, this method must throw [kotlin.IllegalStateException]. + * + * @param value the raw string. + * @return an object. */ public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") /** - * Dynamic call of unary operation with name [operation] on [arg] + * Dynamically dispatches an unary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second `unaryOperation` overload: + * i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. + * + * @param operation the name of operation. + * @return an operation. */ - public fun unaryOperation(operation: String, arg: T): T + public fun unaryOperationFunction(operation: String): (arg: T) -> T = + error("Unary operation $operation not defined in $this") /** - * Dynamic call of binary operation [operation] on [left] and [right] + * Dynamically invokes an unary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [unaryOperationFunction] overload: + * i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. + * + * @param operation the name of operation. + * @param arg the argument of operation. + * @return a result of operation. */ - public fun binaryOperation(operation: String, left: T, right: T): T -} - -/** - * An algebraic structure where elements can have numeric representation. - * - * @param T the type of element of this structure. - */ -public interface NumericAlgebra : Algebra { - /** - * Wraps a number. - */ - public fun number(value: Number): T + public fun unaryOperation(operation: String, arg: T): T = unaryOperationFunction(operation)(arg) /** - * Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. + * Dynamically dispatches a binary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [binaryOperationFunction] overload: + * i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. + * + * @param operation the name of operation. + * @return an operation. */ - public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = - binaryOperation(operation, number(left), right) + public fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = + error("Binary operation $operation not defined in $this") /** - * Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. + * Dynamically invokes a binary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [binaryOperationFunction] overload: + * i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. */ - public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = - leftSideNumberOperation(operation, right, left) + public fun binaryOperation(operation: String, left: T, right: T): T = binaryOperationFunction(operation)(left, right) } /** * Call a block with an [Algebra] as receiver. */ // TODO add contract when KT-32313 is fixed -public inline operator fun , R> A.invoke(block: A.() -> R): R = block() +public inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as @@ -146,26 +182,26 @@ public interface SpaceOperations : Algebra { */ public operator fun Number.times(b: T): T = b * this - override fun unaryOperation(operation: String, arg: T): T = when (operation) { - PLUS_OPERATION -> arg - MINUS_OPERATION -> -arg - else -> error("Unary operation $operation not defined in $this") + public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { + PLUS_OPERATION -> { arg -> arg } + MINUS_OPERATION -> { arg -> -arg } + else -> super.unaryOperationFunction(operation) } - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - PLUS_OPERATION -> add(left, right) - MINUS_OPERATION -> left - right - else -> error("Binary operation $operation not defined in $this") + public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + PLUS_OPERATION -> ::add + MINUS_OPERATION -> { left, right -> left - right } + else -> super.binaryOperationFunction(operation) } public companion object { /** - * The identifier of addition. + * The identifier of addition and unary positive operator. */ public const val PLUS_OPERATION: String = "+" /** - * The identifier of subtraction (and negation). + * The identifier of subtraction and unary negative operator. */ public const val MINUS_OPERATION: String = "-" } @@ -207,9 +243,9 @@ public interface RingOperations : SpaceOperations { */ public operator fun T.times(b: T): T = multiply(this, b) - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - TIMES_OPERATION -> multiply(left, right) - else -> super.binaryOperation(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + TIMES_OPERATION -> ::multiply + else -> super.binaryOperationFunction(operation) } public companion object { @@ -226,61 +262,11 @@ public interface RingOperations : SpaceOperations { * * @param T the type of element of this ring. */ -public interface Ring : Space, RingOperations, NumericAlgebra { +public interface Ring : Space, RingOperations { /** * neutral operation for multiplication */ public val one: T - - override fun number(value: Number): T = one * value.toDouble() - - override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.leftSideNumberOperation(operation, left, right) - } - - override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.rightSideNumberOperation(operation, left, right) - } - - /** - * Addition of element and scalar. - * - * @receiver the addend. - * @param b the augend. - */ - public operator fun T.plus(b: Number): T = this + number(b) - - /** - * Addition of scalar and element. - * - * @receiver the addend. - * @param b the augend. - */ - public operator fun Number.plus(b: T): T = b + this - - /** - * Subtraction of element from number. - * - * @receiver the minuend. - * @param b the subtrahend. - * @receiver the difference. - */ - public operator fun T.minus(b: Number): T = this - number(b) - - /** - * Subtraction of number from element. - * - * @receiver the minuend. - * @param b the subtrahend. - * @receiver the difference. - */ - public operator fun Number.minus(b: T): T = -b + this } /** @@ -308,9 +294,9 @@ public interface FieldOperations : RingOperations { */ public operator fun T.div(b: T): T = divide(this, b) - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - DIV_OPERATION -> divide(left, right) - else -> super.binaryOperation(operation, left, right) + public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + DIV_OPERATION -> ::divide + else -> super.binaryOperationFunction(operation) } public companion object { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt index 7af207edc..ec64282d5 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt @@ -1,5 +1,6 @@ package kscience.kmath.operations +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.BigInt.Companion.BASE import kscience.kmath.operations.BigInt.Companion.BASE_SIZE import kscience.kmath.structures.* @@ -16,7 +17,8 @@ public typealias TBase = ULong * * @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai) */ -public object BigIntField : Field { +@OptIn(UnstableKMathAPI::class) +public object BigIntField : Field, RingWithNumbers { override val zero: BigInt = BigInt.ZERO override val one: BigInt = BigInt.ONE diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt index 703931c7c..c6409c015 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt @@ -3,6 +3,7 @@ package kscience.kmath.operations import kscience.kmath.memory.MemoryReader import kscience.kmath.memory.MemorySpec import kscience.kmath.memory.MemoryWriter +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Buffer import kscience.kmath.structures.MemoryBuffer import kscience.kmath.structures.MutableBuffer @@ -41,7 +42,8 @@ private val PI_DIV_2 = Complex(PI / 2, 0) /** * A field of [Complex]. */ -public object ComplexField : ExtendedField, Norm { +@OptIn(UnstableKMathAPI::class) +public object ComplexField : ExtendedField, Norm, RingWithNumbers { override val zero: Complex = 0.0.toComplex() override val one: Complex = 1.0.toComplex() @@ -156,7 +158,7 @@ public object ComplexField : ExtendedField, Norm { override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) - override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) + override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt new file mode 100644 index 000000000..26f93fae8 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt @@ -0,0 +1,125 @@ +package kscience.kmath.operations + +import kscience.kmath.misc.UnstableKMathAPI + +/** + * An algebraic structure where elements can have numeric representation. + * + * @param T the type of element of this structure. + */ +public interface NumericAlgebra : Algebra { + /** + * Wraps a number to [T] object. + * + * @param value the number to wrap. + * @return an object. + */ + public fun number(value: Number): T + + /** + * Dynamically dispatches a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [leftSideNumberOperation] overload: + * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun leftSideNumberOperationFunction(operation: String): (left: Number, right: T) -> T = + { l, r -> binaryOperationFunction(operation)(number(l), r) } + + /** + * Dynamically invokes a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [leftSideNumberOperation] overload: + * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. + */ + public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = + leftSideNumberOperationFunction(operation)(left, right) + + /** + * Dynamically dispatches a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: + * i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = + { l, r -> binaryOperationFunction(operation)(l, number(r)) } + + /** + * Dynamically invokes a binary operation with the certain name with numeric second argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: + * i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. + */ + public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = + rightSideNumberOperationFunction(operation)(left, right) +} + +/** + * A combination of [NumericAlgebra] and [Ring] that adds intrinsic simple operations on numbers like `T+1` + * TODO to be removed and replaced by extensions after multiple receivers are there + */ +@UnstableKMathAPI +public interface RingWithNumbers: Ring, NumericAlgebra{ + public override fun number(value: Number): T = one * value + + /** + * Addition of element and scalar. + * + * @receiver the addend. + * @param b the augend. + */ + public operator fun T.plus(b: Number): T = this + number(b) + + /** + * Addition of scalar and element. + * + * @receiver the addend. + * @param b the augend. + */ + public operator fun Number.plus(b: T): T = b + this + + /** + * Subtraction of element from number. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + public operator fun T.minus(b: Number): T = this - number(b) + + /** + * Subtraction of number from element. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + public operator fun Number.minus(b: T): T = -b + this +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt similarity index 79% rename from kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt index a2b33a0c4..0440d74e8 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt @@ -1,6 +1,5 @@ package kscience.kmath.operations -import kotlin.math.abs import kotlin.math.pow as kpow /** @@ -15,30 +14,30 @@ public interface ExtendedFieldOperations : public override fun tan(arg: T): T = sin(arg) / cos(arg) public override fun tanh(arg: T): T = sinh(arg) / cosh(arg) - public override fun unaryOperation(operation: String, arg: T): T = when (operation) { - TrigonometricOperations.COS_OPERATION -> cos(arg) - TrigonometricOperations.SIN_OPERATION -> sin(arg) - TrigonometricOperations.TAN_OPERATION -> tan(arg) - TrigonometricOperations.ACOS_OPERATION -> acos(arg) - TrigonometricOperations.ASIN_OPERATION -> asin(arg) - TrigonometricOperations.ATAN_OPERATION -> atan(arg) - HyperbolicOperations.COSH_OPERATION -> cosh(arg) - HyperbolicOperations.SINH_OPERATION -> sinh(arg) - HyperbolicOperations.TANH_OPERATION -> tanh(arg) - HyperbolicOperations.ACOSH_OPERATION -> acosh(arg) - HyperbolicOperations.ASINH_OPERATION -> asinh(arg) - HyperbolicOperations.ATANH_OPERATION -> atanh(arg) - PowerOperations.SQRT_OPERATION -> sqrt(arg) - ExponentialOperations.EXP_OPERATION -> exp(arg) - ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { + TrigonometricOperations.COS_OPERATION -> ::cos + TrigonometricOperations.SIN_OPERATION -> ::sin + TrigonometricOperations.TAN_OPERATION -> ::tan + TrigonometricOperations.ACOS_OPERATION -> ::acos + TrigonometricOperations.ASIN_OPERATION -> ::asin + TrigonometricOperations.ATAN_OPERATION -> ::atan + HyperbolicOperations.COSH_OPERATION -> ::cosh + HyperbolicOperations.SINH_OPERATION -> ::sinh + HyperbolicOperations.TANH_OPERATION -> ::tanh + HyperbolicOperations.ACOSH_OPERATION -> ::acosh + HyperbolicOperations.ASINH_OPERATION -> ::asinh + HyperbolicOperations.ATANH_OPERATION -> ::atanh + PowerOperations.SQRT_OPERATION -> ::sqrt + ExponentialOperations.EXP_OPERATION -> ::exp + ExponentialOperations.LN_OPERATION -> ::ln + else -> super.unaryOperationFunction(operation) } } /** * Advanced Number-like field that implements basic operations. */ -public interface ExtendedField : ExtendedFieldOperations, Field { +public interface ExtendedField : ExtendedFieldOperations, Field, NumericAlgebra { public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2 public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2 public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) @@ -46,10 +45,11 @@ public interface ExtendedField : ExtendedFieldOperations, Field { public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 - public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - PowerOperations.POW_OPERATION -> power(left, right) - else -> super.rightSideNumberOperation(operation, left, right) - } + public override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.rightSideNumberOperationFunction(operation) + } } /** @@ -80,10 +80,13 @@ public object RealField : ExtendedField, Norm { public override val one: Double get() = 1.0 - public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) - } + override fun number(value: Number): Double = value.toDouble() + + public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperationFunction(operation) + } public override inline fun add(a: Double, b: Double): Double = a + b public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() @@ -130,10 +133,13 @@ public object FloatField : ExtendedField, Norm { public override val one: Float get() = 1.0f - public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) - } + override fun number(value: Number): Float = value.toFloat() + + public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperationFunction(operation) + } public override inline fun add(a: Float, b: Float): Float = a + b public override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() @@ -173,13 +179,15 @@ public object FloatField : ExtendedField, Norm { * A field for [Int] without boxing. Does not produce corresponding ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object IntRing : Ring, Norm { +public object IntRing : Ring, Norm, NumericAlgebra { public override val zero: Int get() = 0 public override val one: Int get() = 1 + override fun number(value: Number): Int = value.toInt() + public override inline fun add(a: Int, b: Int): Int = a + b public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a @@ -197,13 +205,15 @@ public object IntRing : Ring, Norm { * A field for [Short] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object ShortRing : Ring, Norm { +public object ShortRing : Ring, Norm, NumericAlgebra { public override val zero: Short get() = 0 public override val one: Short get() = 1 + override fun number(value: Number): Short = value.toShort() + public override inline fun add(a: Short, b: Short): Short = (a + b).toShort() public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() @@ -221,13 +231,15 @@ public object ShortRing : Ring, Norm { * A field for [Byte] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object ByteRing : Ring, Norm { +public object ByteRing : Ring, Norm, NumericAlgebra { public override val zero: Byte get() = 0 public override val one: Byte get() = 1 + override fun number(value: Number): Byte = value.toByte() + public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() @@ -245,12 +257,14 @@ public object ByteRing : Ring, Norm { * A field for [Double] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object LongRing : Ring, Norm { +public object LongRing : Ring, Norm, NumericAlgebra { public override val zero: Long - get() = 0 + get() = 0L public override val one: Long - get() = 1 + get() = 1L + + override fun number(value: Number): Long = value.toLong() public override inline fun add(a: Long, b: Long): Long = a + b public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt index f1f1074e5..6de69cabe 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt @@ -1,9 +1,7 @@ package kscience.kmath.structures -import kscience.kmath.operations.Complex -import kscience.kmath.operations.ComplexField -import kscience.kmath.operations.FieldElement -import kscience.kmath.operations.complex +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -12,15 +10,22 @@ public typealias ComplexNDElement = BufferedNDFieldElement, - ExtendedNDField> { + ExtendedNDField>, + RingWithNumbers>{ override val strides: Strides = DefaultStrides(shape) override val elementContext: ComplexField get() = ComplexField override val zero: ComplexNDElement by lazy { produce { zero } } override val one: ComplexNDElement by lazy { produce { one } } + override fun number(value: Number): NDBuffer { + val c = value.toComplex() + return produce { c } + } + public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer = Buffer.complex(size) { initializer(it) } @@ -29,7 +34,7 @@ public class ComplexNDField(override val shape: IntArray) : */ override fun map( arg: NDBuffer, - transform: ComplexField.(Complex) -> Complex + transform: ComplexField.(Complex) -> Complex, ): ComplexNDElement { check(arg) val array = buildBuffer(arg.strides.linearSize) { offset -> ComplexField.transform(arg.buffer[offset]) } @@ -43,7 +48,7 @@ public class ComplexNDField(override val shape: IntArray) : override fun mapIndexed( arg: NDBuffer, - transform: ComplexField.(index: IntArray, Complex) -> Complex + transform: ComplexField.(index: IntArray, Complex) -> Complex, ): ComplexNDElement { check(arg) @@ -60,7 +65,7 @@ public class ComplexNDField(override val shape: IntArray) : override fun combine( a: NDBuffer, b: NDBuffer, - transform: ComplexField.(Complex, Complex) -> Complex + transform: ComplexField.(Complex, Complex) -> Complex, ): ComplexNDElement { check(a, b) @@ -141,7 +146,7 @@ public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = Comple public fun NDElement.Companion.complex( vararg shape: Int, - initializer: ComplexField.(IntArray) -> Complex + initializer: ComplexField.(IntArray) -> Complex, ): ComplexNDElement = NDField.complex(*shape).produce(initializer) /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt index 08160adf4..e7d89ca7e 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt @@ -1,5 +1,6 @@ package kscience.kmath.structures +import kscience.kmath.misc.UnstableKMathAPI import kotlin.jvm.JvmName import kotlin.native.concurrent.ThreadLocal import kotlin.reflect.KClass @@ -38,14 +39,22 @@ public interface NDStructure { */ public fun elements(): Sequence> + //force override equality and hash code public override fun equals(other: Any?): Boolean public override fun hashCode(): Int + /** + * Feature is additional property or hint that does not directly affect the structure, but could in some cases help + * optimize operations and performance. If the feature is not present, null is defined. + */ + @UnstableKMathAPI + public fun getFeature(type: KClass): T? = null + public companion object { /** * Indicates whether some [NDStructure] is equal to another one. */ - public fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { + public fun contentEquals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { if (st1 === st2) return true // fast comparison of buffers if possible @@ -120,6 +129,9 @@ public interface NDStructure { */ public operator fun NDStructure.get(vararg index: Int): T = get(index) +@UnstableKMathAPI +public inline fun NDStructure<*>.getFeature(): T? = getFeature(T::class) + /** * Represents mutable [NDStructure]. */ @@ -133,6 +145,9 @@ public interface MutableNDStructure : NDStructure { public operator fun set(index: IntArray, value: T) } +/** + * Transform a structure element-by element in place. + */ public inline fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T): Unit = elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } @@ -260,7 +275,7 @@ public abstract class NDBuffer : NDStructure { override fun elements(): Sequence> = strides.indices().map { it to this[it] } override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false) } override fun hashCode(): Int { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt index f7b2ee31e..3f4d15c4d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt @@ -150,6 +150,8 @@ public class RealBufferField(public val size: Int) : ExtendedField by lazy { RealBuffer(size) { 0.0 } } public override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } + override fun number(value: Number): Buffer = RealBuffer(size) { value.toDouble() } + public override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt index ed28fb9f2..60e6de440 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt @@ -1,13 +1,17 @@ package kscience.kmath.structures +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.FieldElement import kscience.kmath.operations.RealField +import kscience.kmath.operations.RingWithNumbers public typealias RealNDElement = BufferedNDFieldElement +@OptIn(UnstableKMathAPI::class) public class RealNDField(override val shape: IntArray) : BufferedNDField, - ExtendedNDField> { + ExtendedNDField>, + RingWithNumbers> { override val strides: Strides = DefaultStrides(shape) @@ -15,35 +19,36 @@ public class RealNDField(override val shape: IntArray) : override val zero: RealNDElement by lazy { produce { zero } } override val one: RealNDElement by lazy { produce { one } } - public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - RealBuffer(DoubleArray(size) { initializer(it) }) + override fun number(value: Number): NDBuffer { + val d = value.toDouble() + return produce { d } + } - /** - * Inline transform an NDStructure to - */ - override fun map( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun map( arg: NDBuffer, - transform: RealField.(Double) -> Double + transform: RealField.(Double) -> Double, ): RealNDElement { check(arg) - val array = buildBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } + val array = RealBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } return BufferedNDFieldElement(this, array) } - override fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement { - val array = buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } + @Suppress("OVERRIDE_BY_INLINE") + override inline fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement { + val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } return BufferedNDFieldElement(this, array) } - override fun mapIndexed( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun mapIndexed( arg: NDBuffer, - transform: RealField.(index: IntArray, Double) -> Double + transform: RealField.(index: IntArray, Double) -> Double, ): RealNDElement { check(arg) - return BufferedNDFieldElement( this, - buildBuffer(arg.strides.linearSize) { offset -> + RealBuffer(arg.strides.linearSize) { offset -> elementContext.transform( arg.strides.index(offset), arg.buffer[offset] @@ -51,15 +56,17 @@ public class RealNDField(override val shape: IntArray) : }) } - override fun combine( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun combine( a: NDBuffer, b: NDBuffer, - transform: RealField.(Double, Double) -> Double + transform: RealField.(Double, Double) -> Double, ): RealNDElement { check(a, b) - return BufferedNDFieldElement( - this, - buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) + val buffer = RealBuffer(strides.linearSize) { offset -> + elementContext.transform(a.buffer[offset], b.buffer[offset]) + } + return BufferedNDFieldElement(this, buffer) } override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt index 99c1de9bf..d20e9e53b 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt @@ -1,12 +1,42 @@ package kscience.kmath.structures /** - * A structure that is guaranteed to be two-dimensional + * A structure that is guaranteed to be two-dimensional. + * + * @param T the type of items. */ public interface Structure2D : NDStructure { - public val rowNum: Int get() = shape[0] - public val colNum: Int get() = shape[1] + /** + * The number of rows in this structure. + */ + public val rowNum: Int + /** + * The number of columns in this structure. + */ + public val colNum: Int + + public override val shape: IntArray get() = intArrayOf(rowNum, colNum) + + /** + * The buffer of rows of this structure. It gets elements from the structure dynamically. + */ + public val rows: Buffer> + get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } } + + /** + * The buffer of columns of this structure. It gets elements from the structure dynamically. + */ + public val columns: Buffer> + get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } + + /** + * Retrieves an element from the structure by two indices. + * + * @param i the first index. + * @param j the second index. + * @return an element. + */ public operator fun get(i: Int, j: Int): T override operator fun get(index: IntArray): T { @@ -14,15 +44,9 @@ public interface Structure2D : NDStructure { return get(index[0], index[1]) } - public val rows: Buffer> - get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } } - - public val columns: Buffer> - get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } - override fun elements(): Sequence> = sequence { - for (i in (0 until rowNum)) - for (j in (0 until colNum)) yield(intArrayOf(i, j) to this@Structure2D[i, j]) + for (i in 0 until rowNum) + for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) } public companion object @@ -34,7 +58,11 @@ public interface Structure2D : NDStructure { private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { override val shape: IntArray get() = structure.shape + override val rowNum: Int get() = shape[0] + override val colNum: Int get() = shape[1] + override operator fun get(i: Int, j: Int): T = structure[i, j] + override fun elements(): Sequence> = structure.elements() } @@ -46,4 +74,9 @@ public fun NDStructure.as2D(): Structure2D = if (shape.size == 2) else error("Can't create 2d-structure from ${shape.size}d-structure") +/** + * Alias for [Structure2D] with more familiar name. + * + * @param T the type of items. + */ public typealias Matrix = Structure2D diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt index 0a582e339..d7755dcb5 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt @@ -7,6 +7,7 @@ import kscience.kmath.structures.as2D import kotlin.test.Test import kotlin.test.assertEquals +@Suppress("UNUSED_VARIABLE") class MatrixTest { @Test fun testTranspose() { diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/FieldVerifier.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/FieldVerifier.kt index 1ca09ab0c..89f31c75b 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/FieldVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/FieldVerifier.kt @@ -12,6 +12,8 @@ internal class FieldVerifier(override val algebra: Field, a: T, b: T, c: T super.verify() algebra { + assertEquals(a + b, b + a, "Addition in $algebra is not commutative.") + assertEquals(a * b, b * a, "Multiplication in $algebra is not commutative.") assertNotEquals(a / b, b / a, "Division in $algebra is not anti-commutative.") assertNotEquals((a / b) / c, a / (b / c), "Division in $algebra is associative.") assertEquals((a + b) / c, (a / c) + (b / c), "Division in $algebra is not right-distributive.") diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/RingVerifier.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/RingVerifier.kt index 863169b9b..359ba1701 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/RingVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/RingVerifier.kt @@ -10,7 +10,7 @@ internal open class RingVerifier(override val algebra: Ring, a: T, b: T, c super.verify() algebra { - assertEquals(a * b, a * b, "Multiplication in $algebra is not commutative.") + assertEquals(a + b, b + a, "Addition in $algebra is not commutative.") assertEquals(a * b * c, a * (b * c), "Multiplication in $algebra is not associative.") assertEquals(c * (a + b), (c * a) + (c * b), "Multiplication in $algebra is not distributive.") assertEquals(a * one, one * a, "$one in $algebra is not a neutral multiplication element.") diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/SpaceVerifier.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/SpaceVerifier.kt index 4dc855829..045abb71f 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/SpaceVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/SpaceVerifier.kt @@ -15,7 +15,6 @@ internal open class SpaceVerifier( AlgebraicVerifier> { override fun verify() { algebra { - assertEquals(a + b, b + a, "Addition in $algebra is not commutative.") assertEquals(a + b + c, a + (b + c), "Addition in $algebra is not associative.") assertEquals(x * (a + b), x * a + x * b, "Addition in $algebra is not distributive.") assertEquals((a + b) * x, a * x + b * x, "Addition in $algebra is not distributive.") diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt index 79b56ea4a..1129a8a36 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt @@ -1,10 +1,15 @@ package kscience.kmath.structures +import kscience.kmath.operations.internal.FieldVerifier import kotlin.test.Test import kotlin.test.assertEquals +internal class NDFieldTest { + @Test + fun verify() { + NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } + } -class NDFieldTest { @Test fun testStrides() { val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt index 4320821e6..a1d2992d6 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt @@ -8,6 +8,7 @@ import kotlin.math.pow import kotlin.test.Test import kotlin.test.assertEquals +@Suppress("UNUSED_VARIABLE") class NumberNDFieldTest { val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() } val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() } diff --git a/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt index 2f0978237..9bd6a9fc4 100644 --- a/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt @@ -7,7 +7,7 @@ import java.math.MathContext /** * A field over [BigInteger]. */ -public object JBigIntegerField : Field { +public object JBigIntegerField : Field, NumericAlgebra { public override val zero: BigInteger get() = BigInteger.ZERO @@ -28,9 +28,9 @@ public object JBigIntegerField : Field { * * @property mathContext the [MathContext] to use. */ -public abstract class JBigDecimalFieldBase internal constructor(public val mathContext: MathContext = MathContext.DECIMAL64) : - Field, - PowerOperations { +public abstract class JBigDecimalFieldBase internal constructor( + private val mathContext: MathContext = MathContext.DECIMAL64, +) : Field, PowerOperations, NumericAlgebra { public override val zero: BigDecimal get() = BigDecimal.ZERO diff --git a/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt index bb0d19c23..7aa746797 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt @@ -24,7 +24,7 @@ public class LazyNDStructure( } public override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false) } public override fun hashCode(): Int { diff --git a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt index 68a5dc262..0244eae7f 100644 --- a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt @@ -1,11 +1,6 @@ package kscience.kmath.dimensions -import kscience.kmath.linear.GenericMatrixContext -import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.Point -import kscience.kmath.linear.transpose -import kscience.kmath.operations.RealField -import kscience.kmath.operations.Ring +import kscience.kmath.linear.* import kscience.kmath.operations.invoke import kscience.kmath.structures.Matrix import kscience.kmath.structures.Structure2D @@ -42,9 +37,11 @@ public interface DMatrix : Structure2D { * An inline wrapper for a Matrix */ public inline class DMatrixWrapper( - private val structure: Structure2D + private val structure: Structure2D, ) : DMatrix { override val shape: IntArray get() = structure.shape + override val rowNum: Int get() = shape[0] + override val colNum: Int get() = shape[1] override operator fun get(i: Int, j: Int): T = structure[i, j] } @@ -81,7 +78,7 @@ public inline class DPointWrapper(public val point: Point) /** * Basic operations on dimension-safe matrices. Operates on [Matrix] */ -public inline class DMatrixContext>(public val context: GenericMatrixContext>) { +public inline class DMatrixContext(public val context: MatrixContext>) { public inline fun Matrix.coerce(): DMatrix { require(rowNum == Dimension.dim().toInt()) { "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" @@ -115,7 +112,7 @@ public inline class DMatrixContext>(public val context: Ge } public inline infix fun DMatrix.dot( - other: DMatrix + other: DMatrix, ): DMatrix = context { this@dot dot other }.coerce() public inline infix fun DMatrix.dot(vector: DPoint): DPoint = @@ -139,18 +136,20 @@ public inline class DMatrixContext>(public val context: Ge public inline fun DMatrix.transpose(): DMatrix = context { (this@transpose as Matrix).transpose() }.coerce() - /** - * A square unit matrix - */ - public inline fun one(): DMatrix = produce { i, j -> - if (i == j) context.elementContext.one else context.elementContext.zero - } - - public inline fun zero(): DMatrix = produce { _, _ -> - context.elementContext.zero - } - public companion object { - public val real: DMatrixContext = DMatrixContext(MatrixContext.real) + public val real: DMatrixContext = DMatrixContext(MatrixContext.real) } } + + +/** + * A square unit matrix + */ +public inline fun DMatrixContext.one(): DMatrix = produce { i, j -> + if (i == j) 1.0 else 0.0 +} + +public inline fun DMatrixContext.zero(): DMatrix = + produce { _, _ -> + 0.0 + } \ No newline at end of file diff --git a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt index f44b16753..b9193d4dd 100644 --- a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt @@ -3,8 +3,10 @@ package kscience.dimensions import kscience.kmath.dimensions.D2 import kscience.kmath.dimensions.D3 import kscience.kmath.dimensions.DMatrixContext +import kscience.kmath.dimensions.one import kotlin.test.Test +@Suppress("UNUSED_VARIABLE") internal class DMatrixContextTest { @Test fun testDimensionSafeMatrix() { diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt index ed6b1571e..82a5399fd 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt @@ -1,12 +1,14 @@ package kscience.kmath.ejml +import kscience.kmath.linear.* +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.NDStructure +import kscience.kmath.structures.RealBuffer import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix -import kscience.kmath.linear.DeterminantFeature -import kscience.kmath.linear.FeaturedMatrix -import kscience.kmath.linear.LUPDecompositionFeature -import kscience.kmath.linear.MatrixFeature -import kscience.kmath.structures.NDStructure +import kotlin.reflect.KClass +import kotlin.reflect.cast /** * Represents featured matrix over EJML [SimpleMatrix]. @@ -14,58 +16,75 @@ import kscience.kmath.structures.NDStructure * @property origin the underlying [SimpleMatrix]. * @author Iaroslav Postovalov */ -public class EjmlMatrix(public val origin: SimpleMatrix, features: Set? = null) : FeaturedMatrix { - public override val rowNum: Int - get() = origin.numRows() +public class EjmlMatrix( + public val origin: SimpleMatrix, +) : Matrix { + public override val rowNum: Int get() = origin.numRows() - public override val colNum: Int - get() = origin.numCols() + public override val colNum: Int get() = origin.numCols() - public override val shape: IntArray - get() = intArrayOf(origin.numRows(), origin.numCols()) - - public override val features: Set = setOf( - object : LUPDecompositionFeature, DeterminantFeature { - override val determinant: Double - get() = origin.determinant() - - private val lup by lazy { - val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()) - .also { it.decompose(origin.ddrm.copy()) } - - Triple( - EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), - EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), - EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), - ) + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = when (type) { + InverseMatrixFeature::class -> object : InverseMatrixFeature { + override val inverse: Matrix by lazy { EjmlMatrix(origin.invert()) } + } + DeterminantFeature::class -> object : DeterminantFeature { + override val determinant: Double by lazy(origin::determinant) + } + SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { + private val svd by lazy { + DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false) + .apply { decompose(origin.ddrm.copy()) } } - override val l: FeaturedMatrix - get() = lup.second - - override val u: FeaturedMatrix - get() = lup.third - - override val p: FeaturedMatrix - get() = lup.first + override val u: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) } + override val s: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) } + override val v: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) } + override val singularValues: Point by lazy { RealBuffer(svd.singularValues) } } - ) union features.orEmpty() + QRDecompositionFeature::class -> object : QRDecompositionFeature { + private val qr by lazy { + DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) } + } - public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix = - EjmlMatrix(origin, this.features + features) + override val q: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } + override val r: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } + } + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + override val l: Matrix by lazy { + val cholesky = + DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) } + + EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature + } + } + LupDecompositionFeature::class -> object : LupDecompositionFeature { + private val lup by lazy { + DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) } + } + + override val l: Matrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature + } + + override val u: Matrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature + } + + override val p: Matrix by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } + } + else -> null + }?.let { type.cast(it) } public override operator fun get(i: Int, j: Int): Double = origin[i, j] - public override fun equals(other: Any?): Boolean { - if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0) - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is Matrix<*>) return false + return NDStructure.contentEquals(this, other) } - public override fun hashCode(): Int { - var result = origin.hashCode() - result = 31 * result + features.hashCode() - return result - } + override fun hashCode(): Int = origin.hashCode() + - public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)" } diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index 31792e39c..8184d0110 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -1,16 +1,14 @@ package kscience.kmath.ejml +import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.Point +import kscience.kmath.linear.origin +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix +import kscience.kmath.structures.getFeature import org.ejml.simple.SimpleMatrix -/** - * Converts this matrix to EJML one. - */ -public fun Matrix.toEjml(): EjmlMatrix = - if (this is EjmlMatrix) this else EjmlMatrixContext.produce(rowNum, colNum) { i, j -> get(i, j) } - /** * Represents context of basic operations operating with [EjmlMatrix]. * @@ -18,6 +16,15 @@ public fun Matrix.toEjml(): EjmlMatrix = */ public object EjmlMatrixContext : MatrixContext { + /** + * Converts this matrix to EJML one. + */ + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toEjml(): EjmlMatrix = when (val matrix = origin) { + is EjmlMatrix -> matrix + else -> produce(rowNum, colNum) { i, j -> get(i, j) } + } + /** * Converts this vector to EJML one. */ @@ -33,6 +40,11 @@ public object EjmlMatrixContext : MatrixContext { } }) + override fun point(size: Int, initializer: (Int) -> Double): Point = + EjmlVector(SimpleMatrix(size, 1).also { + (0 until it.numRows()).forEach { row -> it[row, 0] = initializer(row) } + }) + public override fun Matrix.dot(other: Matrix): EjmlMatrix = EjmlMatrix(toEjml().origin.mult(other.toEjml().origin)) @@ -74,11 +86,7 @@ public fun EjmlMatrixContext.solve(a: Matrix, b: Matrix): EjmlMa public fun EjmlMatrixContext.solve(a: Matrix, b: Point): EjmlVector = EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) -/** - * Returns the inverse of given matrix: b = a^(-1). - * - * @param a the matrix. - * @return the inverse of this matrix. - * @author Iaroslav Postovalov - */ -public fun EjmlMatrixContext.inverse(a: Matrix): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert()) +@OptIn(UnstableKMathAPI::class) +public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature>()!!.inverse as EjmlMatrix + +public fun EjmlMatrixContext.inverse(matrix: Matrix): Matrix = matrix.toEjml().inverted() \ No newline at end of file diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt index e0f15be83..455b52d9d 100644 --- a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -1,9 +1,11 @@ package kscience.kmath.ejml import kscience.kmath.linear.DeterminantFeature -import kscience.kmath.linear.LUPDecompositionFeature +import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.MatrixFeature -import kscience.kmath.linear.getFeature +import kscience.kmath.linear.plus +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.getFeature import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix import kotlin.random.Random @@ -38,13 +40,14 @@ internal class EjmlMatrixTest { assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList()) } + @OptIn(UnstableKMathAPI::class) @Test fun features() { val m = randomMatrix val w = EjmlMatrix(m) val det = w.getFeature>() ?: fail() assertEquals(m.determinant(), det.determinant) - val lup = w.getFeature>() ?: fail() + val lup = w.getFeature>() ?: fail() val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols()) .also { it.decompose(m.ddrm.copy()) } @@ -56,9 +59,10 @@ internal class EjmlMatrixTest { private object SomeFeature : MatrixFeature {} + @OptIn(UnstableKMathAPI::class) @Test fun suggestFeature() { - assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature()) + assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature()) } @Test diff --git a/kmath-for-real/README.md b/kmath-for-real/README.md index 2ddf78e57..d6b66b7da 100644 --- a/kmath-for-real/README.md +++ b/kmath-for-real/README.md @@ -21,7 +21,7 @@ > maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/hotkeytlt/maven' } - +> > } > > dependencies { diff --git a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt index e8ad835e5..274030aff 100644 --- a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt @@ -1,14 +1,12 @@ package kscience.kmath.real -import kscience.kmath.linear.FeaturedMatrix import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.RealMatrixContext.elementContext import kscience.kmath.linear.VirtualMatrix import kscience.kmath.linear.inverseWithLUP +import kscience.kmath.linear.real import kscience.kmath.misc.UnstableKMathAPI -import kscience.kmath.operations.invoke -import kscience.kmath.operations.sum import kscience.kmath.structures.Buffer +import kscience.kmath.structures.Matrix import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.asIterable import kotlin.math.pow @@ -25,7 +23,7 @@ import kotlin.math.pow * Functions that help create a real (Double) matrix */ -public typealias RealMatrix = FeaturedMatrix +public typealias RealMatrix = Matrix public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) @@ -122,8 +120,7 @@ public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix = extractColumns(columnIndex..columnIndex) public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> - val column = columns[j] - elementContext { sum(column.asIterable()) } + columns[j].asIterable().sum() } public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt index 5c33b76a9..a89f99b3c 100644 --- a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt @@ -1,6 +1,5 @@ package kaceince.kmath.real -import kscience.kmath.linear.VirtualMatrix import kscience.kmath.linear.build import kscience.kmath.real.* import kscience.kmath.structures.Matrix @@ -42,7 +41,7 @@ internal class RealMatrixTest { 1.0, 0.0, 0.0, 0.0, 1.0, 2.0 ) - assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3)) + assertEquals(matrix2, matrix1.repeatStackVertical(3)) } @Test diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index 176dfc09d..ff4ff4542 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -23,7 +23,7 @@ This subproject implements the following features: > maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/hotkeytlt/maven' } - +> > } > > dependencies { diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt index a8c874fc3..db2a44861 100644 --- a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt @@ -1,5 +1,6 @@ package kscience.kmath.nd4j +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* import kscience.kmath.structures.NDAlgebra import kscience.kmath.structures.NDField @@ -35,7 +36,7 @@ public interface Nd4jArrayAlgebra : NDAlgebra> public override fun mapIndexed( arg: Nd4jArrayStructure, - transform: C.(index: IntArray, T) -> T + transform: C.(index: IntArray, T) -> T, ): Nd4jArrayStructure { check(arg) val new = Nd4j.create(*shape).wrap() @@ -46,7 +47,7 @@ public interface Nd4jArrayAlgebra : NDAlgebra> public override fun combine( a: Nd4jArrayStructure, b: Nd4jArrayStructure, - transform: C.(T, T) -> T + transform: C.(T, T) -> T, ): Nd4jArrayStructure { check(a, b) val new = Nd4j.create(*shape).wrap() @@ -61,8 +62,8 @@ public interface Nd4jArrayAlgebra : NDAlgebra> * @param T the type of the element contained in ND structure. * @param S the type of space of structure elements. */ -public interface Nd4jArraySpace : NDSpace>, - Nd4jArrayAlgebra where S : Space { +public interface Nd4jArraySpace> : NDSpace>, Nd4jArrayAlgebra { + public override val zero: Nd4jArrayStructure get() = Nd4j.zeros(*shape).wrap() @@ -103,7 +104,9 @@ public interface Nd4jArraySpace : NDSpace>, * @param T the type of the element contained in ND structure. * @param R the type of ring of structure elements. */ -public interface Nd4jArrayRing : NDRing>, Nd4jArraySpace where R : Ring { +@OptIn(UnstableKMathAPI::class) +public interface Nd4jArrayRing> : NDRing>, Nd4jArraySpace { + public override val one: Nd4jArrayStructure get() = Nd4j.ones(*shape).wrap() @@ -111,21 +114,21 @@ public interface Nd4jArrayRing : NDRing>, Nd4j check(a, b) return a.ndArray.mul(b.ndArray).wrap() } - - public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { - check(this) - return ndArray.sub(b).wrap() - } - - public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { - check(this) - return ndArray.add(b).wrap() - } - - public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { - check(b) - return b.ndArray.rsub(this).wrap() - } +// +// public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { +// check(this) +// return ndArray.sub(b).wrap() +// } +// +// public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { +// check(this) +// return ndArray.add(b).wrap() +// } +// +// public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { +// check(b) +// return b.ndArray.rsub(this).wrap() +// } public companion object { private val intNd4jArrayRingCache: ThreadLocal> = @@ -165,7 +168,8 @@ public interface Nd4jArrayRing : NDRing>, Nd4j * @param N the type of ND structure. * @param F the type field of structure elements. */ -public interface Nd4jArrayField : NDField>, Nd4jArrayRing where F : Field { +public interface Nd4jArrayField> : NDField>, Nd4jArrayRing { + public override fun divide(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { check(a, b) return a.ndArray.div(b.ndArray).wrap() diff --git a/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt index c2304070f..4e29e6105 100644 --- a/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt @@ -62,6 +62,7 @@ class MCScopeTest { } + @OptIn(ObsoleteCoroutinesApi::class) fun compareResult(test: ATest) { val res1 = runBlocking(Dispatchers.Default) { test() } val res2 = runBlocking(newSingleThreadContext("test")) { test() } diff --git a/settings.gradle.kts b/settings.gradle.kts index 10e4d9577..a1ea40148 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -8,8 +8,8 @@ pluginManagement { maven("https://dl.bintray.com/kotlin/kotlinx") } - val toolsVersion = "0.7.0" - val kotlinVersion = "1.4.20" + val toolsVersion = "0.7.3-1.4.30-RC" + val kotlinVersion = "1.4.30-RC" plugins { id("kotlinx.benchmark") version "0.2.0-dev-20"