diff --git a/.github/workflows/gradle.yml b/.github/workflows/gradle.yml index adc74adfe..467a867bc 100644 --- a/.github/workflows/gradle.yml +++ b/.github/workflows/gradle.yml @@ -1,17 +1,101 @@ name: Gradle build -on: [push] +on: [ push ] jobs: - build: - - runs-on: ubuntu-latest + build-ubuntu: + runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v1 - - name: Set up JDK 11 - uses: actions/setup-java@v1 - with: - java-version: 11 - - name: Build with Gradle - run: ./gradlew build + - uses: actions/checkout@v2 + - name: Set up JDK 11 + uses: actions/setup-java@v1 + with: + java-version: 11 + - name: Install Chrome + run: | + sudo apt install -y libappindicator1 fonts-liberation + wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb + sudo dpkg -i google-chrome*.deb + - name: Cache gradle + uses: actions/cache@v2 + with: + path: | + .gradle + build + ~/.gradle + key: gradle + restore-keys: gradle + + - name: Cache konan + uses: actions/cache@v2 + with: + path: | + ~/.konan/dependencies + ~/.konan/kotlin-native-prebuilt-linux-* + key: ${{ runner.os }}-konan + restore-keys: ${{ runner.os }}-konan + - name: Build with Gradle + run: ./gradlew -Dorg.gradle.daemon=false --build-cache build + + build-osx: + runs-on: macos-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK 11 + uses: actions/setup-java@v1 + with: + java-version: 11 + - name: Cache gradle + uses: actions/cache@v2 + with: + path: | + .gradle + build + ~/.gradle + key: gradle + restore-keys: gradle + + - name: Cache konan + uses: actions/cache@v2 + with: + path: | + ~/.konan/dependencies + ~/.konan/kotlin-native-prebuilt-macos-* + key: ${{ runner.os }}-konan + restore-keys: ${{ runner.os }}-konan + - name: Build with Gradle + run: sudo ./gradlew -Dorg.gradle.daemon=false --build-cache build + + build-windows: + runs-on: windows-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK 11 + uses: actions/setup-java@v1 + with: + java-version: 11 + - name: Add msys to path + run: SETX PATH "%PATH%;C:\msys64\mingw64\bin" + - name: Cache gradle + uses: actions/cache@v2 + with: + path: | + .gradle + build + ~/.gradle + key: ${{ runner.os }}-gradle + restore-keys: ${{ runner.os }}-gradle + + - name: Cache konan + uses: actions/cache@v2 + with: + path: | + ~/.konan/dependencies + ~/.konan/kotlin-native-prebuilt-mingw-* + key: ${{ runner.os }}-konan + restore-keys: ${{ runner.os }}-konan + - name: Build with Gradle + run: ./gradlew --build-cache build diff --git a/.gitignore b/.gitignore index a9294eff9..bade7f08c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,3 @@ out/ # Cache of project .gradletasknamecache - -gradle.properties \ No newline at end of file diff --git a/.space.kts b/.space.kts new file mode 100644 index 000000000..d70ad6d59 --- /dev/null +++ b/.space.kts @@ -0,0 +1,3 @@ +job("Build") { + gradlew("openjdk:11", "build") +} diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bcc57810..aa70e6116 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,16 +2,53 @@ ## [Unreleased] ### Added +- `fun` annotation for SAM interfaces in library +- Explicit `public` visibility for all public APIs +- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140) +- Automatic README generation for features (#139) +- Native support for `memory`, `core` and `dimensions` +- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136) +- A separate `Symbol` entity, which is used for global unbound symbol. +- A `Symbol` indexing scope. +- Basic optimization API for Commons-math. +- Chi squared optimization for array-like data in CM +- `Fitting` utility object in prob/stat +- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray` +- Coroutine-deterministic Monte-Carlo scope with a random number generator +- Some minor utilities to `kmath-for-real` +- Generic operation result parameter to `MatrixContext` +- New `MatrixFeature` interfaces for matrix decompositions ### Changed +- Package changed from `scientifik` to `kscience.kmath` +- Gradle version: 6.6 -> 6.8 +- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`) +- `Polynomial` secondary constructor made function +- Kotlin version: 1.3.72 -> 1.4.21 +- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library +- Full autodiff refactoring based on `Symbol` +- `kmath-prob` renamed to `kmath-stat` +- Grid generators moved to `kmath-for-real` +- Use `Point` instead of specialized type in `kmath-for-real` +- Optimized dot product for buffer matrices moved to `kmath-for-real` +- EjmlMatrix context is an object +- Matrix LUP `inverse` renamed to `inverseWithLUP` +- `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it). +- Features moved to NDStructure and became transparent. ### Deprecated ### Removed +- `kmath-koma` module because it doesn't support Kotlin 1.4. +- Support of `legacy` JS backend (we will support only IR) +- `toGrid` method. +- Public visibility of `BufferAccessor2D` ### Fixed +- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140) ### Security + ## [0.1.4] ### Added diff --git a/README.md b/README.md index 53de9f037..0899f77cc 100644 --- a/README.md +++ b/README.md @@ -3,46 +3,55 @@ ![Gradle build](https://github.com/mipt-npm/kmath/workflows/Gradle%20build/badge.svg) -Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion) +Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion) Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion) # KMath -Could be pronounced as `key-math`. -The Kotlin MATHematics library is intended as a Kotlin-based analog to Python's `numpy` library. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. -## Publications +Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to +Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture +designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could +be achieved with [kmath-for-real](/kmath-for-real) extension module. + +## Publications and talks + * [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2) * [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814) * [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103) # Goal -* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future). + +* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native). * Provide basic multiplatform implementations for those abstractions (without significant performance optimization). * Provide bindings and wrappers with those abstractions for popular optimized platform libraries. ## Non-goals -* Be like Numpy. It was the idea at the beginning, but we decided that we can do better in terms of API. -* Provide best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. + +* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API. +* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. * Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually. -* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better experience for those, who want to work with specific types. +* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like +for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better +experience for those, who want to work with specific types. ## Features -Actual feature list is [here](doc/features.md) +Current feature list is [here](/docs/features.md) * **Algebra** - * Algebraic structures like rings, spaces and field (**TODO** add example to wiki) + * Algebraic structures like rings, spaces and fields (**TODO** add example to wiki) * Basic linear algebra operations (sums, products, etc.), backed by the `Space` API. - * Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays). + * Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and + N-dimensional arrays). * Advanced linear algebra operations like matrix inversion and LU decomposition. * **Array-like structures** Full support of many-dimensional array-like structures including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking). -* **Expressions** By writing a single mathematical expression -once, users will be able to apply different types of objects to the expression by providing a context. Expressions -can be used for a wide variety of purposes from high performance calculations to code generation. +* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of +objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high +performance calculations to code generation. * **Histograms** Fast multi-dimensional histograms. @@ -50,13 +59,11 @@ can be used for a wide variety of purposes from high performance calculations to * **Type-safe dimensions** Type-safe dimensions for matrix operations. -* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/) - library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free - to submit a feature request if you want something to be done first. +* **Commons-math wrapper** It is planned to gradually wrap most parts of +[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some +parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to +submit a feature request if you want something to be implemented first. -* **Koma wrapper** [Koma](https://github.com/kyonifer/koma) is a well established numerics library in Kotlin, specifically linear algebra. -The plan is to have wrappers for koma implementations for compatibility with kmath API. - ## Planned features * **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks. @@ -69,41 +76,188 @@ The plan is to have wrappers for koma implementations for compatibility with kma * **Fitting** Non-linear curve fitting facilities +## Modules + +
+ +* ### [examples](examples) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-ast](kmath-ast) +> +> +> **Maturity**: PROTOTYPE +> +> **Features:** +> - [expression-language](kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt) : Expression language and its parser +> - [mst](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation +> - [mst-building](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt) : MST building algebraic structure +> - [mst-interpreter](kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST interpreter +> - [mst-jvm-codegen](kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler +> - [mst-js-codegen](kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler + +
+ +* ### [kmath-commons](kmath-commons) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-core](kmath-core) +> Core classes, algebra definitions, basic linear algebra +> +> **Maturity**: DEVELOPMENT +> +> **Features:** +> - [algebras](kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : Algebraic structures: contexts and elements +> - [nd](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Many-dimensional structures +> - [buffers](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure +> - [expressions](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions +> - [domains](kmath-core/src/commonMain/kotlin/kscience/kmath/domains) : Domains +> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation + +
+ +* ### [kmath-coroutines](kmath-coroutines) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-dimensions](kmath-dimensions) +> +> +> **Maturity**: PROTOTYPE +
+ +* ### [kmath-ejml](kmath-ejml) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-for-real](kmath-for-real) +> Extension module that should be used to achieve numpy-like behavior. +All operations are specialized to work with `Double` numbers without declaring algebraic contexts. +One can still use generic algebras though. +> +> **Maturity**: EXPERIMENTAL +> +> **Features:** +> - [RealVector](kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealVector.kt) : Numpy-like operations for Buffers/Points +> - [RealMatrix](kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt) : Numpy-like operations for 2d real structures +> - [grids](kmath-for-real/src/commonMain/kotlin/kscience/kmath/structures/grids.kt) : Uniform grid generators + +
+ +* ### [kmath-functions](kmath-functions) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-geometry](kmath-geometry) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-histograms](kmath-histograms) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-kotlingrad](kmath-kotlingrad) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-memory](kmath-memory) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-nd4j](kmath-nd4j) +> ND4J NDStructure implementation and according NDAlgebra classes +> +> **Maturity**: EXPERIMENTAL +> +> **Features:** +> - [nd4jarraystructure](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray +> - [nd4jarrayrings](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long +> - [nd4jarrayfields](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double + +
+ +* ### [kmath-stat](kmath-stat) +> +> +> **Maturity**: EXPERIMENTAL +
+ +* ### [kmath-viktor](kmath-viktor) +> +> +> **Maturity**: EXPERIMENTAL +
+ + ## Multi-platform support -KMath is developed as a multi-platform library, which means that most of interfaces are declared in the [common module](kmath-core/src/commonMain). Implementation is also done in the common module wherever possible. In some cases, features are delegated to platform-specific implementations even if they could be done in the common module for performance reasons. Currently, the JVM is the main focus of development, however Kotlin/Native and Kotlin/JS contributions are also welcome. +KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the +[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features +are delegated to platform-specific implementations even if they could be provided in the common module for performance +reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and +feedback are also welcome. ## Performance -Calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve both performance and flexibility. We expect to focus on creating convenient universal API first and then work on increasing performance for specific cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy. +Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve +both performance and flexibility. -### Dependency +We expect to focus on creating convenient universal API first and then work on increasing performance for specific +cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized +native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be +better than SciPy. -Release artifacts are accessible from bintray with following configuration (see documentation for [kotlin-multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) form more details): +### Repositories + +Release artifacts are accessible from bintray with following configuration (see documentation of +[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details): ```kotlin -repositories{ - maven("https://dl.bintray.com/mipt-npm/scientifik") +repositories { + maven("https://dl.bintray.com/mipt-npm/kscience") } -dependencies{ - api("kscience.kmath:kmath-core:${kmathVersion}") - //api("scientifik:kmath-core:${kmathVersion}") for 0.1.3 and earlier +dependencies { + api("kscience.kmath:kmath-core:0.2.0-dev-4") + // api("kscience.kmath:kmath-core-jvm:0.2.0-dev-4") for jvm-specific version } ``` Gradle `6.0+` is required for multiplatform artifacts. -### Development +#### Development + +Development builds are uploaded to the separate repository: -Development builds are accessible from the reposirtory ```kotlin -repositories{ +repositories { maven("https://dl.bintray.com/mipt-npm/dev") } ``` -with the same artifact names. ## Contributing -The project requires a lot of additional work. Please feel free to contribute in any way and propose new features. +The project requires a lot of additional work. The most important thing we need is a feedback about what features are +required the most. Feel free to create feature requests. We are also welcome to code contributions, +especially in issues marked with +[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label. diff --git a/build.gradle.kts b/build.gradle.kts index b24ecd15b..d171bd608 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,36 +1,44 @@ +import ru.mipt.npm.gradle.KSciencePublishPlugin + plugins { - id("scientifik.publish") apply false - id("org.jetbrains.changelog") version "0.4.0" + id("ru.mipt.npm.project") } -val kmathVersion by extra("0.1.4") - -val bintrayRepo by extra("scientifik") -val githubProject by extra("kmath") +internal val kmathVersion: String by extra("0.2.0-dev-5") +internal val bintrayRepo: String by extra("kscience") +internal val githubProject: String by extra("kmath") allprojects { repositories { jcenter() - maven("https://dl.bintray.com/kotlin/kotlinx") + maven("https://clojars.org/repo") + maven("https://dl.bintray.com/egor-bogomolov/astminer/") maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://dl.bintray.com/kotlin/kotlin-eap") + maven("https://dl.bintray.com/kotlin/kotlinx") + maven("https://dl.bintray.com/mipt-npm/dev") + maven("https://dl.bintray.com/mipt-npm/kscience") + maven("https://jitpack.io") + maven("http://logicrunch.research.it.uu.se/maven/") + mavenCentral() } group = "kscience.kmath" version = kmathVersion - - afterEvaluate { - extensions.findByType()?.run { - targets.all { - sourceSets.all { - languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") - } - } - } - } } subprojects { - if (name.startsWith("kmath")) { - apply(plugin = "scientifik.publish") - } -} \ No newline at end of file + if (name.startsWith("kmath")) apply() +} + +readme { + readmeTemplate = file("docs/templates/README-TEMPLATE.md") +} + +apiValidation { + validationDisabled = true +} + +ksciencePublish { + spaceRepo = "https://maven.pkg.jetbrains.space/mipt-npm/p/sci/maven" +} diff --git a/doc/features.md b/doc/features.md deleted file mode 100644 index e6a820c1e..000000000 --- a/doc/features.md +++ /dev/null @@ -1,17 +0,0 @@ -# Features - -* [Algebra](./algebra.md) - [Context-based](./contexts.md) operations on different primitives and structures. - -* [NDStructures](./nd-structure.md) - -* [Linear algebra](./linear.md) - Matrices, operations and linear equations solving. To be moved to separate module. Currently supports basic -api and multiple library back-ends. - -* [Histograms](./histograms.md) - Multidimensional histogram calculation and operations. - -* [Expressions](./expressions.md) - -* Commons math integration - -* Koma integration - diff --git a/doc/algebra.md b/docs/algebra.md similarity index 97% rename from doc/algebra.md rename to docs/algebra.md index b1b77a31f..c3227517f 100644 --- a/doc/algebra.md +++ b/docs/algebra.md @@ -5,7 +5,7 @@ operation, say `+`, one needs two objects of a type `T` and an algebra context, say `Space`. Next one needs to run the actual operation in the context: ```kotlin -import scientifik.kmath.operations.* +import kscience.kmath.operations.* val a: T = ... val b: T = ... @@ -47,7 +47,7 @@ but it also holds reference to the `ComplexField` singleton, which allows perfor numbers without explicit involving the context like: ```kotlin -import scientifik.kmath.operations.* +import kscience.kmath.operations.* // Using elements val c1 = Complex(1.0, 1.0) @@ -82,7 +82,7 @@ operations in all performance-critical places. The performance of element operat KMath submits both contexts and elements for builtin algebraic structures: ```kotlin -import scientifik.kmath.operations.* +import kscience.kmath.operations.* val c1 = Complex(1.0, 2.0) val c2 = ComplexField.i @@ -95,7 +95,7 @@ val c3 = ComplexField { c1 + c2 } Also, `ComplexField` features special operations to mix complex and real numbers, for example: ```kotlin -import scientifik.kmath.operations.* +import kscience.kmath.operations.* val c1 = Complex(1.0, 2.0) val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0) diff --git a/doc/buffers.md b/docs/buffers.md similarity index 100% rename from doc/buffers.md rename to docs/buffers.md diff --git a/doc/codestyle.md b/docs/codestyle.md similarity index 100% rename from doc/codestyle.md rename to docs/codestyle.md diff --git a/doc/contexts.md b/docs/contexts.md similarity index 100% rename from doc/contexts.md rename to docs/contexts.md diff --git a/doc/expressions.md b/docs/expressions.md similarity index 100% rename from doc/expressions.md rename to docs/expressions.md diff --git a/docs/features.md b/docs/features.md new file mode 100644 index 000000000..1068a4417 --- /dev/null +++ b/docs/features.md @@ -0,0 +1,14 @@ +# Features + +* [Algebra](algebra.md) - [Context-based](contexts.md) operations on different primitives and structures. + +* [NDStructures](nd-structure.md) + +* [Linear algebra](linear.md) - Matrices, operations and linear equations solving. To be moved to separate module. Currently supports basic +api and multiple library back-ends. + +* [Histograms](histograms.md) - Multidimensional histogram calculation and operations. + +* [Expressions](expressions.md) + +* Commons math integration diff --git a/doc/histograms.md b/docs/histograms.md similarity index 100% rename from doc/histograms.md rename to docs/histograms.md diff --git a/docs/images/KM.svg b/docs/images/KM.svg new file mode 100644 index 000000000..50126cbc5 --- /dev/null +++ b/docs/images/KM.svg @@ -0,0 +1,59 @@ + +image/svg+xml \ No newline at end of file diff --git a/docs/images/KM_mono.svg b/docs/images/KM_mono.svg new file mode 100644 index 000000000..3b6890b6b --- /dev/null +++ b/docs/images/KM_mono.svg @@ -0,0 +1,55 @@ + +image/svg+xml \ No newline at end of file diff --git a/docs/images/KMath.svg b/docs/images/KMath.svg new file mode 100644 index 000000000..d88cfe7b0 --- /dev/null +++ b/docs/images/KMath.svg @@ -0,0 +1,91 @@ + +image/svg+xml \ No newline at end of file diff --git a/docs/images/KMath_mono.svg b/docs/images/KMath_mono.svg new file mode 100644 index 000000000..3a62ac383 --- /dev/null +++ b/docs/images/KMath_mono.svg @@ -0,0 +1,371 @@ + +image/svg+xml \ No newline at end of file diff --git a/doc/linear.md b/docs/linear.md similarity index 79% rename from doc/linear.md rename to docs/linear.md index 883df275e..6ccc6caac 100644 --- a/doc/linear.md +++ b/docs/linear.md @@ -6,10 +6,10 @@ back-ends. The new operations added as extensions to contexts instead of being m Two major contexts used for linear algebra and hyper-geometry: -* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its typealias `Point` used for geometry). +* `VectorSpace` forms a mathematical space on top of array-like structure (`Buffer` and its type alias `Point` used for geometry). * `MatrixContext` forms a space-like context for 2d-structures. It does not store matrix size and therefore does not implement -`Space` interface (it is not possible to create zero element without knowing the matrix size). +`Space` interface (it is impossible to create zero element without knowing the matrix size). ## Vector spaces diff --git a/doc/nd-structure.md b/docs/nd-structure.md similarity index 100% rename from doc/nd-structure.md rename to docs/nd-structure.md diff --git a/docs/templates/ARTIFACT-TEMPLATE.md b/docs/templates/ARTIFACT-TEMPLATE.md new file mode 100644 index 000000000..d46a431bd --- /dev/null +++ b/docs/templates/ARTIFACT-TEMPLATE.md @@ -0,0 +1,37 @@ +> #### Artifact: +> +> This module artifact: `${group}:${name}:${version}`. +> +> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/${name}/images/download.svg) ](https://bintray.com/mipt-npm/kscience/${name}/_latestVersion) +> +> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/${name}/images/download.svg) ](https://bintray.com/mipt-npm/dev/${name}/_latestVersion) +> +> **Gradle:** +> +> ```gradle +> repositories { +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } +> maven { url 'https://dl.bintray.com/mipt-npm/kscience' } +> maven { url 'https://dl.bintray.com/mipt-npm/dev' } +> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } +> +> } +> +> dependencies { +> implementation '${group}:${name}:${version}' +> } +> ``` +> **Gradle Kotlin DSL:** +> +> ```kotlin +> repositories { +> maven("https://dl.bintray.com/kotlin/kotlin-eap") +> maven("https://dl.bintray.com/mipt-npm/kscience") +> maven("https://dl.bintray.com/mipt-npm/dev") +> maven("https://dl.bintray.com/hotkeytlt/maven") +> } +> +> dependencies { +> implementation("${group}:${name}:${version}") +> } +> ``` \ No newline at end of file diff --git a/docs/templates/README-TEMPLATE.md b/docs/templates/README-TEMPLATE.md new file mode 100644 index 000000000..ee1df818c --- /dev/null +++ b/docs/templates/README-TEMPLATE.md @@ -0,0 +1,134 @@ +[![JetBrains Research](https://jb.gg/badges/research.svg)](https://confluence.jetbrains.com/display/ALL/JetBrains+on+GitHub) +[![DOI](https://zenodo.org/badge/129486382.svg)](https://zenodo.org/badge/latestdoi/129486382) + +![Gradle build](https://github.com/mipt-npm/kmath/workflows/Gradle%20build/badge.svg) + +Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion) + +Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion) + +# KMath + +Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to +Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture +designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could +be achieved with [kmath-for-real](/kmath-for-real) extension module. + +## Publications and talks + +* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2) +* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814) +* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103) + +# Goal + +* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native). +* Provide basic multiplatform implementations for those abstractions (without significant performance optimization). +* Provide bindings and wrappers with those abstractions for popular optimized platform libraries. + +## Non-goals + +* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API. +* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. +* Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually. +* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like +for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better +experience for those, who want to work with specific types. + +## Features + +Current feature list is [here](/docs/features.md) + +* **Algebra** + * Algebraic structures like rings, spaces and fields (**TODO** add example to wiki) + * Basic linear algebra operations (sums, products, etc.), backed by the `Space` API. + * Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and + N-dimensional arrays). + * Advanced linear algebra operations like matrix inversion and LU decomposition. + +* **Array-like structures** Full support of many-dimensional array-like structures +including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking). + +* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of +objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high +performance calculations to code generation. + +* **Histograms** Fast multi-dimensional histograms. + +* **Streaming** Streaming operations on mathematical objects and objects buffers. + +* **Type-safe dimensions** Type-safe dimensions for matrix operations. + +* **Commons-math wrapper** It is planned to gradually wrap most parts of +[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some +parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to +submit a feature request if you want something to be implemented first. + +## Planned features + +* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks. + +* **Array statistics** + +* **Integration** Univariate and multivariate integration framework. + +* **Probability and distributions** + +* **Fitting** Non-linear curve fitting facilities + +## Modules + +$modules + +## Multi-platform support + +KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the +[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features +are delegated to platform-specific implementations even if they could be provided in the common module for performance +reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and +feedback are also welcome. + +## Performance + +Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve +both performance and flexibility. + +We expect to focus on creating convenient universal API first and then work on increasing performance for specific +cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized +native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be +better than SciPy. + +### Repositories + +Release artifacts are accessible from bintray with following configuration (see documentation of +[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details): + +```kotlin +repositories { + maven("https://dl.bintray.com/mipt-npm/kscience") +} + +dependencies { + api("kscience.kmath:kmath-core:$version") + // api("kscience.kmath:kmath-core-jvm:$version") for jvm-specific version +} +``` + +Gradle `6.0+` is required for multiplatform artifacts. + +#### Development + +Development builds are uploaded to the separate repository: + +```kotlin +repositories { + maven("https://dl.bintray.com/mipt-npm/dev") +} +``` + +## Contributing + +The project requires a lot of additional work. The most important thing we need is a feedback about what features are +required the most. Feel free to create feature requests. We are also welcome to code contributions, +especially in issues marked with +[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label. diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index f5a4d5831..c079eaa84 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -1,58 +1,78 @@ -import org.jetbrains.kotlin.allopen.gradle.AllOpenExtension import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { java kotlin("jvm") - kotlin("plugin.allopen") version "1.3.72" - id("kotlinx.benchmark") version "0.2.0-dev-8" + kotlin("plugin.allopen") + id("kotlinx.benchmark") } -configure { - annotation("org.openjdk.jmh.annotations.State") -} +allOpen.annotation("org.openjdk.jmh.annotations.State") +sourceSets.register("benchmarks") repositories { - maven("http://dl.bintray.com/kyonifer/maven") - maven("https://dl.bintray.com/mipt-npm/scientifik") + jcenter() + maven("https://clojars.org/repo") + maven("https://dl.bintray.com/egor-bogomolov/astminer/") + maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://dl.bintray.com/kotlin/kotlin-eap") + maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/mipt-npm/dev") + maven("https://dl.bintray.com/mipt-npm/kscience") + maven("https://jitpack.io") + maven("http://logicrunch.research.it.uu.se/maven/") mavenCentral() } -sourceSets { - register("benchmarks") -} - dependencies { implementation(project(":kmath-ast")) + implementation(project(":kmath-kotlingrad")) implementation(project(":kmath-core")) implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) - implementation(project(":kmath-prob")) - implementation(project(":kmath-koma")) + implementation(project(":kmath-stat")) implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) - implementation("com.kyonifer:koma-core-ejml:0.12") - implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") - implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8") - "benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath + implementation(project(":kmath-ejml")) + implementation(project(":kmath-nd4j")) + + implementation(project(":kmath-for-real")) + + implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7") + implementation("org.nd4j:nd4j-native:1.0.0-beta7") + +// uncomment if your system supports AVX2 +// val os = System.getProperty("os.name") +// +// if (System.getProperty("os.arch") in arrayOf("x86_64", "amd64")) when { +// os.startsWith("Windows") -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:windows-x86_64-avx2") +// os == "Linux" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:linux-x86_64-avx2") +// os == "Mac OS X" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:macosx-x86_64-avx2") +// } else + implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") + + implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11") + implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") + implementation("org.slf4j:slf4j-simple:1.7.30") + + // plotting + implementation("kscience.plotlykt:plotlykt-server:0.3.1-dev") + + "benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-20") + "benchmarksImplementation"(sourceSets.main.get().output + sourceSets.main.get().runtimeClasspath) } // Configure benchmark benchmark { // Setup configurations - targets { - // This one matches sourceSet name above - register("benchmarks") - } + targets.register("benchmarks") + // This one matches sourceSet name above - configurations { - register("fast") { - warmups = 5 // number of warmup iterations - iterations = 3 // number of iterations - iterationTime = 500 // time in seconds per iteration - iterationTimeUnit = "ms" // time unity for iterationTime, default is seconds - } + configurations.register("fast") { + warmups = 1 // number of warmup iterations + iterations = 3 // number of iterations + iterationTime = 500 // time in seconds per iteration + iterationTimeUnit = "ms" // time unity for iterationTime, default is seconds } } @@ -64,8 +84,5 @@ kotlin.sourceSets.all { } tasks.withType { - kotlinOptions { - jvmTarget = Scientifik.JVM_TARGET.toString() - freeCompilerArgs = freeCompilerArgs + "-Xopt-in=kotlin.RequiresOptIn" - } + kotlinOptions.jvmTarget = "11" } diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt new file mode 100644 index 000000000..c5edcdedf --- /dev/null +++ b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -0,0 +1,63 @@ +package kscience.kmath.ast + +import kscience.kmath.asm.compile +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.expressionInField +import kscience.kmath.expressions.invoke +import kscience.kmath.expressions.symbol +import kscience.kmath.operations.Field +import kscience.kmath.operations.RealField +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import kotlin.random.Random + +@State(Scope.Benchmark) +internal class ExpressionsInterpretersBenchmark { + private val algebra: Field = RealField + + @Benchmark + fun functionalExpression() { + val expr = algebra.expressionInField { + symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0) + } + + invokeAndSum(expr) + } + + @Benchmark + fun mstExpression() { + val expr = algebra.mstInField { + symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0 + } + + invokeAndSum(expr) + } + + @Benchmark + fun asmExpression() { + val expr = algebra.mstInField { + symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0 + }.compile() + + invokeAndSum(expr) + } + + @Benchmark + fun rawExpression() { + val x by symbol + val expr = Expression { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 } + invokeAndSum(expr) + } + + private fun invokeAndSum(expr: Expression) { + val random = Random(0) + var sum = 0.0 + + repeat(1000000) { + sum += expr("x" to random.nextDouble()) + } + + println(sum) + } +} diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ArrayBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ArrayBenchmark.kt new file mode 100644 index 000000000..ebf31a590 --- /dev/null +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ArrayBenchmark.kt @@ -0,0 +1,34 @@ +package kscience.kmath.benchmarks + +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import java.nio.IntBuffer + +@State(Scope.Benchmark) +internal class ArrayBenchmark { + @Benchmark + fun benchmarkArrayRead() { + var res = 0 + for (i in 1..size) res += array[size - i] + } + + @Benchmark + fun benchmarkBufferRead() { + var res = 0 + for (i in 1..size) res += arrayBuffer[size - i] + } + + @Benchmark + fun nativeBufferRead() { + var res = 0 + for (i in 1..size) res += nativeBuffer[size - i] + } + + companion object { + const val size: Int = 1000 + val array: IntArray = IntArray(size) { it } + val arrayBuffer: IntBuffer = IntBuffer.wrap(array) + val nativeBuffer: IntBuffer = IntBuffer.allocate(size).also { for (i in 0 until size) it.put(i, i) } + } +} diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/BufferBenchmark.kt similarity index 50% rename from examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/BufferBenchmark.kt index e40b0c4b7..4c64517f1 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/BufferBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/BufferBenchmark.kt @@ -1,17 +1,18 @@ -package scientifik.kmath.structures +package kscience.kmath.benchmarks +import kscience.kmath.operations.Complex +import kscience.kmath.operations.complex +import kscience.kmath.structures.MutableBuffer +import kscience.kmath.structures.RealBuffer import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.complex @State(Scope.Benchmark) -class BufferBenchmark { - +internal class BufferBenchmark { @Benchmark fun genericRealBufferReadWrite() { - val buffer = RealBuffer(size){it.toDouble()} + val buffer = RealBuffer(size) { it.toDouble() } (0 until size).forEach { buffer[it] @@ -20,7 +21,7 @@ class BufferBenchmark { @Benchmark fun complexBufferReadWrite() { - val buffer = MutableBuffer.complex(size / 2){Complex(it.toDouble(), -it.toDouble())} + val buffer = MutableBuffer.complex(size / 2) { Complex(it.toDouble(), -it.toDouble()) } (0 until size / 2).forEach { buffer[it] @@ -28,6 +29,6 @@ class BufferBenchmark { } companion object { - const val size = 100 + const val size: Int = 100 } } \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt new file mode 100644 index 000000000..5c59afaee --- /dev/null +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt @@ -0,0 +1,72 @@ +package kscience.kmath.benchmarks + +import kotlinx.benchmark.Benchmark +import kscience.kmath.commons.linear.CMMatrixContext +import kscience.kmath.ejml.EjmlMatrixContext + +import kscience.kmath.linear.BufferMatrixContext +import kscience.kmath.linear.RealMatrixContext +import kscience.kmath.linear.real +import kscience.kmath.operations.RealField +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.Matrix +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import kotlin.random.Random + +@State(Scope.Benchmark) +class DotBenchmark { + companion object { + val random = Random(12224) + val dim = 1000 + + //creating invertible matrix + val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + + val cmMatrix1 = CMMatrixContext { matrix1.toCM() } + val cmMatrix2 = CMMatrixContext { matrix2.toCM() } + + val ejmlMatrix1 = EjmlMatrixContext { matrix1.toEjml() } + val ejmlMatrix2 = EjmlMatrixContext { matrix2.toEjml() } + } + + @Benchmark + fun commonsMathMultiplication() { + CMMatrixContext { + cmMatrix1 dot cmMatrix2 + } + } + + @Benchmark + fun ejmlMultiplication() { + EjmlMatrixContext { + ejmlMatrix1 dot ejmlMatrix2 + } + } + + @Benchmark + fun ejmlMultiplicationwithConversion() { + EjmlMatrixContext { + val ejmlMatrix1 = matrix1.toEjml() + val ejmlMatrix2 = matrix2.toEjml() + + ejmlMatrix1 dot ejmlMatrix2 + } + } + + @Benchmark + fun bufferedMultiplication() { + BufferMatrixContext(RealField, Buffer.Companion::real).invoke { + matrix1 dot matrix2 + } + } + + @Benchmark + fun realMultiplication() { + RealMatrixContext { + matrix1 dot matrix2 + } + } +} \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt new file mode 100644 index 000000000..5ff43ef80 --- /dev/null +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt @@ -0,0 +1,46 @@ +package kscience.kmath.linear + + +import kotlinx.benchmark.Benchmark +import kscience.kmath.commons.linear.CMMatrixContext +import kscience.kmath.commons.linear.CMMatrixContext.dot +import kscience.kmath.commons.linear.inverse +import kscience.kmath.ejml.EjmlMatrixContext +import kscience.kmath.ejml.inverse +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Matrix +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import kotlin.random.Random + +@State(Scope.Benchmark) +class LinearAlgebraBenchmark { + companion object { + val random = Random(1224) + val dim = 100 + + //creating invertible matrix + val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 } + val matrix = l dot u + } + + @Benchmark + fun kmathLUPInversion() { + MatrixContext.real.inverseWithLUP(matrix) + } + + @Benchmark + fun cmLUPInversion() { + CMMatrixContext { + inverse(matrix) + } + } + + @Benchmark + fun ejmlInverse() { + EjmlMatrixContext { + inverse(matrix) + } + } +} \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/NDFieldBenchmark.kt similarity index 60% rename from examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/NDFieldBenchmark.kt index 46da6c6d8..1be8e7236 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/NDFieldBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/NDFieldBenchmark.kt @@ -1,13 +1,14 @@ -package scientifik.kmath.structures +package kscience.kmath.benchmarks +import kscience.kmath.operations.RealField +import kscience.kmath.operations.invoke +import kscience.kmath.structures.* import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.invoke @State(Scope.Benchmark) -class NDFieldBenchmark { +internal class NDFieldBenchmark { @Benchmark fun autoFieldAdd() { bufferedField { @@ -40,11 +41,10 @@ class NDFieldBenchmark { } companion object { - val dim = 1000 - val n = 100 - - val bufferedField = NDField.auto(RealField, dim, dim) - val specializedField = NDField.real(dim, dim) - val genericField = NDField.boxing(RealField, dim, dim) + const val dim: Int = 1000 + const val n: Int = 100 + val bufferedField: BufferedNDField = NDField.auto(RealField, dim, dim) + val specializedField: RealNDField = NDField.real(dim, dim) + val genericField: BoxingNDField = NDField.boxing(RealField, dim, dim) } } \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ViktorBenchmark.kt similarity index 72% rename from examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ViktorBenchmark.kt index 9627743c9..8663e353c 100644 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/ViktorBenchmark.kt @@ -1,17 +1,20 @@ -package scientifik.kmath.structures +package kscience.kmath.benchmarks +import kscience.kmath.operations.RealField +import kscience.kmath.operations.invoke +import kscience.kmath.structures.BufferedNDField +import kscience.kmath.structures.NDField +import kscience.kmath.structures.RealNDField +import kscience.kmath.viktor.ViktorNDField import org.jetbrains.bio.viktor.F64Array import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.invoke -import scientifik.kmath.viktor.ViktorNDField @State(Scope.Benchmark) -class ViktorBenchmark { - final val dim = 1000 - final val n = 100 +internal class ViktorBenchmark { + final val dim: Int = 1000 + final val n: Int = 100 // automatically build context most suited for given type. final val autoField: BufferedNDField = NDField.auto(RealField, dim, dim) @@ -36,13 +39,13 @@ class ViktorBenchmark { @Benchmark fun rawViktor() { - val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) + val one = F64Array.full(init = 1.0, shape = intArrayOf(dim, dim)) var res = one repeat(n) { res = res + one } } @Benchmark - fun realdFieldLog() { + fun realFieldLog() { realField { val fortyTwo = produce { 42.0 } var res = one diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ArrayBenchmark.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ArrayBenchmark.kt deleted file mode 100644 index d605e1b9c..000000000 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/structures/ArrayBenchmark.kt +++ /dev/null @@ -1,48 +0,0 @@ -package scientifik.kmath.structures - -import org.openjdk.jmh.annotations.Benchmark -import org.openjdk.jmh.annotations.Scope -import org.openjdk.jmh.annotations.State -import java.nio.IntBuffer - - -@State(Scope.Benchmark) -class ArrayBenchmark { - - @Benchmark - fun benchmarkArrayRead() { - var res = 0 - for (i in 1..size) { - res += array[size - i] - } - } - - @Benchmark - fun benchmarkBufferRead() { - var res = 0 - for (i in 1..size) { - res += arrayBuffer.get(size - i) - } - } - - @Benchmark - fun nativeBufferRead() { - var res = 0 - for (i in 1..size) { - res += nativeBuffer.get(size - i) - } - } - - companion object { - val size = 1000 - - val array = IntArray(size) { it } - val arrayBuffer = IntBuffer.wrap(array) - val nativeBuffer = IntBuffer.allocate(size).also { - for (i in 0 until size) { - it.put(i, i) - } - - } - } -} \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt b/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt deleted file mode 100644 index 3b0d56291..000000000 --- a/examples/src/benchmarks/kotlin/scientifik/kmath/utils/utils.kt +++ /dev/null @@ -1,11 +0,0 @@ -package scientifik.kmath.utils - -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract -import kotlin.system.measureTimeMillis - -internal inline fun measureAndPrint(title: String, block: () -> Unit) { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - val time = measureTimeMillis(block) - println("$title completed in $time millis") -} diff --git a/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt new file mode 100644 index 000000000..b3c827503 --- /dev/null +++ b/examples/src/main/kotlin/kscience/kmath/ast/KotlingradSupport.kt @@ -0,0 +1,24 @@ +package kscience.kmath.ast + +import kscience.kmath.asm.compile +import kscience.kmath.expressions.derivative +import kscience.kmath.expressions.invoke +import kscience.kmath.expressions.symbol +import kscience.kmath.kotlingrad.differentiable +import kscience.kmath.operations.RealField + +/** + * In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with + * valid derivative. + */ +fun main() { + val x by symbol + + val actualDerivative = MstExpression(RealField, "x^2-4*x-44".parseMath()) + .differentiable() + .derivative(x) + .compile() + + val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() + assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) +} diff --git a/examples/src/main/kotlin/kscience/kmath/commons/fit/fitWithAutoDiff.kt b/examples/src/main/kotlin/kscience/kmath/commons/fit/fitWithAutoDiff.kt new file mode 100644 index 000000000..c0cd9dc5c --- /dev/null +++ b/examples/src/main/kotlin/kscience/kmath/commons/fit/fitWithAutoDiff.kt @@ -0,0 +1,102 @@ +package kscience.kmath.commons.fit + +import kotlinx.html.br +import kotlinx.html.h3 +import kscience.kmath.commons.optimization.chiSquared +import kscience.kmath.commons.optimization.minimize +import kscience.kmath.expressions.symbol +import kscience.kmath.real.RealVector +import kscience.kmath.real.map +import kscience.kmath.real.step +import kscience.kmath.stat.* +import kscience.kmath.structures.asIterable +import kscience.kmath.structures.toList +import kscience.plotly.* +import kscience.plotly.models.ScatterMode +import kscience.plotly.models.TraceValues +import kotlin.math.pow +import kotlin.math.sqrt + +//Forward declaration of symbols that will be used in expressions. +// This declaration is required for +private val a by symbol +private val b by symbol +private val c by symbol + +/** + * Shortcut to use buffers in plotly + */ +operator fun TraceValues.invoke(vector: RealVector) { + numbers = vector.asIterable() +} + +/** + * Least squares fie with auto-differentiation. Uses `kmath-commons` and `kmath-for-real` modules. + */ +fun main() { + + //A generator for a normally distributed values + val generator = Distribution.normal() + + //A chain/flow of random values with the given seed + val chain = generator.sample(RandomGenerator.default(112667)) + + + //Create a uniformly distributed x values like numpy.arrange + val x = 1.0..100.0 step 1.0 + + + //Perform an operation on each x value (much more effective, than numpy) + val y = x.map { + val value = it.pow(2) + it + 1 + value + chain.nextDouble() * sqrt(value) + } + // this will also work, but less effective: + // val y = x.pow(2)+ x + 1 + chain.nextDouble() + + // create same errors for all xs + val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma) + + // compute differentiable chi^2 sum for given model ax^2 + bx + c + val chi2 = Fitting.chiSquared(x, y, yErr) { x1 -> + //bind variables to autodiff context + val a = bind(a) + val b = bind(b) + //Include default value for c if it is not provided as a parameter + val c = bindOrNull(c) ?: one + a * x1.pow(2) + b * x1 + c + } + + //minimize the chi^2 in given starting point. Derivatives are not required, they are already included. + val result: OptimizationResult = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) + + //display a page with plot and numerical results + val page = Plotly.page { + plot { + scatter { + mode = ScatterMode.markers + x(x) + y(y) + error_y { + array = yErr.toList() + } + name = "data" + } + scatter { + mode = ScatterMode.lines + x(x) + y(x.map { result.point[a]!! * it.pow(2) + result.point[b]!! * it + 1 }) + name = "fit" + } + } + br() + h3{ + +"Fit result: $result" + } + h3{ + +"Chi2/dof = ${result.value / (x.size - 3)}" + } + } + + page.makeFile() +} \ No newline at end of file diff --git a/examples/src/main/kotlin/kscience/kmath/operations/BigIntDemo.kt b/examples/src/main/kotlin/kscience/kmath/operations/BigIntDemo.kt new file mode 100644 index 000000000..0e9811ff8 --- /dev/null +++ b/examples/src/main/kotlin/kscience/kmath/operations/BigIntDemo.kt @@ -0,0 +1,6 @@ +package kscience.kmath.operations + +fun main() { + val res = BigIntField { number(1) * 2 } + println("bigint:$res") +} \ No newline at end of file diff --git a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt new file mode 100644 index 000000000..e84fd8df3 --- /dev/null +++ b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt @@ -0,0 +1,23 @@ +package kscience.kmath.operations + +import kscience.kmath.structures.NDElement +import kscience.kmath.structures.NDField +import kscience.kmath.structures.complex + +fun main() { + // 2d element + val element = NDElement.complex(2, 2) { (i,j) -> + Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble()) + } + println(element) + + // 1d element operation + val result = with(NDField.complex(8)) { + val a = produce { (it) -> i * it - it.toDouble() } + val b = 3 + val c = Complex(1.0, 1.0) + + (a pow b) + c + } + println(result) +} diff --git a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt similarity index 82% rename from examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt rename to examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt index b060cddb6..99d3cd504 100644 --- a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt @@ -1,26 +1,22 @@ -package scientifik.kmath.commons.prob +package kscience.kmath.commons.prob import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.runBlocking +import kscience.kmath.stat.* import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler import org.apache.commons.rng.simple.RandomSource -import scientifik.kmath.chains.BlockingRealChain -import scientifik.kmath.prob.* import java.time.Duration import java.time.Instant - -private suspend fun runChain(): Duration { +private fun runChain(): Duration { val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) - val normal = Distribution.normal(NormalSamplerMethod.Ziggurat) - val chain = normal.sample(generator) as BlockingRealChain - + val chain = normal.sample(generator) val startTime = Instant.now() var sum = 0.0 - repeat(10000001) { counter -> + repeat(10000001) { counter -> sum += chain.nextDouble() if (counter % 100000 == 0) { @@ -29,6 +25,7 @@ private suspend fun runChain(): Duration { println("Chain sampler completed $counter elements in $duration: $meanValue") } } + return Duration.between(startTime, Instant.now()) } @@ -36,10 +33,9 @@ private fun runDirect(): Duration { val provider = RandomSource.create(RandomSource.MT, 123L) val sampler = ZigguratNormalizedGaussianSampler(provider) val startTime = Instant.now() - var sum = 0.0 - repeat(10000001) { counter -> + repeat(10000001) { counter -> sum += sampler.sample() if (counter % 100000 == 0) { @@ -48,6 +44,7 @@ private fun runDirect(): Duration { println("Direct sampler completed $counter elements in $duration: $meanValue") } } + return Duration.between(startTime, Instant.now()) } @@ -56,16 +53,9 @@ private fun runDirect(): Duration { */ fun main() { runBlocking(Dispatchers.Default) { - val chainJob = async { - runChain() - } - - val directJob = async { - runDirect() - } - + val chainJob = async { runChain() } + val directJob = async { runDirect() } println("Chain: ${chainJob.await()}") println("Direct: ${directJob.await()}") } - -} \ No newline at end of file +} diff --git a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt b/examples/src/main/kotlin/kscience/kmath/stat/DistributionDemo.kt similarity index 52% rename from examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt rename to examples/src/main/kotlin/kscience/kmath/stat/DistributionDemo.kt index e059415dc..24a4cb1a7 100644 --- a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt +++ b/examples/src/main/kotlin/kscience/kmath/stat/DistributionDemo.kt @@ -1,15 +1,18 @@ -package scientifik.kmath.commons.prob +package kscience.kmath.stat import kotlinx.coroutines.runBlocking -import scientifik.kmath.chains.Chain -import scientifik.kmath.chains.collectWithState -import scientifik.kmath.prob.Distribution -import scientifik.kmath.prob.RandomGenerator -import scientifik.kmath.prob.normal +import kscience.kmath.chains.Chain +import kscience.kmath.chains.collectWithState -data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) +/** + * The state of distribution averager + */ +private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) -fun Chain.mean(): Chain = collectWithState(AveragingChainState(), { it.copy() }) { chain -> +/** + * Averaging + */ +private fun Chain.mean(): Chain = collectWithState(AveragingChainState(), { it.copy() }) { chain -> val next = chain.next() num++ value += next diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt similarity index 82% rename from examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt rename to examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt index 2329f3fc3..b69590473 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt @@ -1,9 +1,9 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.linear.transpose -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.ComplexField -import scientifik.kmath.operations.invoke +import kscience.kmath.linear.transpose +import kscience.kmath.operations.Complex +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.invoke import kotlin.system.measureTimeMillis fun main() { @@ -11,7 +11,7 @@ fun main() { val n = 1000 val realField = NDField.real(dim, dim) - val complexField = NDField.complex(dim, dim) + val complexField: ComplexNDField = NDField.complex(dim, dim) val realTime = measureTimeMillis { realField { diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt similarity index 75% rename from examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt rename to examples/src/main/kotlin/kscience/kmath/structures/NDField.kt index 1bc0ed7c8..b5130c92b 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt @@ -1,9 +1,10 @@ -package scientifik.kmath.structures +package kscience.kmath.structures import kotlinx.coroutines.GlobalScope -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.invoke -import kotlin.contracts.ExperimentalContracts +import kscience.kmath.nd4j.Nd4jArrayField +import kscience.kmath.operations.RealField +import kscience.kmath.operations.invoke +import org.nd4j.linalg.factory.Nd4j import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.system.measureTimeMillis @@ -15,6 +16,8 @@ internal inline fun measureAndPrint(title: String, block: () -> Unit) { } fun main() { + // initializing Nd4j + Nd4j.zeros(0) val dim = 1000 val n = 1000 @@ -24,11 +27,13 @@ fun main() { val specializedField = NDField.real(dim, dim) //A generic boxing field. It should be used for objects, not primitives. val genericField = NDField.boxing(RealField, dim, dim) + // Nd4j specialized field. + val nd4jField = Nd4jArrayField.real(dim, dim) measureAndPrint("Automatic field addition") { autoField { var res: NDBuffer = one - repeat(n) { res += number(1.0) } + repeat(n) { res += 1.0 } } } @@ -44,6 +49,13 @@ fun main() { } } + measureAndPrint("Nd4j specialized addition") { + nd4jField { + var res = one + repeat(n) { res += 1.0 } + } + } + measureAndPrint("Lazy addition") { val res = specializedField.one.mapAsync(GlobalScope) { var c = 0.0 @@ -61,7 +73,7 @@ fun main() { genericField { var res: NDBuffer = one repeat(n) { - res += one // couldn't avoid using `one` due to resolution ambiguity } + res += 1.0 // couldn't avoid using `one` due to resolution ambiguity } } } } diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/structures/StructureReadBenchmark.kt similarity index 81% rename from examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt rename to examples/src/main/kotlin/kscience/kmath/structures/StructureReadBenchmark.kt index a33fdb2c4..51fd4f956 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/StructureReadBenchmark.kt @@ -1,35 +1,33 @@ -package scientifik.kmath.structures +package kscience.kmath.structures import kotlin.system.measureTimeMillis -fun main(args: Array) { +fun main() { val n = 6000 - val array = DoubleArray(n * n) { 1.0 } val buffer = RealBuffer(array) val strides = DefaultStrides(intArrayOf(n, n)) - val structure = BufferNDStructure(strides, buffer) measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = structure[it] } } // warmup val time1 = measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = structure[it] } } println("Structure reading finished in $time1 millis") val time2 = measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = buffer[strides.offset(it)] } } println("Buffer reading finished in $time2 millis") val time3 = measureTimeMillis { - var res: Double = 0.0 + var res = 0.0 strides.indices().forEach { res = array[strides.offset(it)] } } println("Array reading finished in $time3 millis") diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/structures/StructureWriteBenchmark.kt similarity index 73% rename from examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt rename to examples/src/main/kotlin/kscience/kmath/structures/StructureWriteBenchmark.kt index 0241f12ad..db55b454f 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/StructureWriteBenchmark.kt @@ -1,29 +1,20 @@ -package scientifik.kmath.structures +package kscience.kmath.structures import kotlin.system.measureTimeMillis - -fun main(args: Array) { - +fun main() { val n = 6000 - val structure = NDStructure.build(intArrayOf(n, n), Buffer.Companion::auto) { 1.0 } - structure.mapToBuffer { it + 1 } // warm-up - - val time1 = measureTimeMillis { - val res = structure.mapToBuffer { it + 1 } - } + val time1 = measureTimeMillis { val res = structure.mapToBuffer { it + 1 } } println("Structure mapping finished in $time1 millis") - val array = DoubleArray(n * n) { 1.0 } val time2 = measureTimeMillis { val target = DoubleArray(n * n) - val res = array.forEachIndexed { index, value -> - target[index] = value + 1 - } + val res = array.forEachIndexed { index, value -> target[index] = value + 1 } } + println("Array mapping finished in $time2 millis") val buffer = RealBuffer(DoubleArray(n * n) { 1.0 }) diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt similarity index 56% rename from examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt rename to examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt index 5d323823a..96684f7dc 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt @@ -1,12 +1,11 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.dimensions.D2 -import scientifik.kmath.dimensions.D3 -import scientifik.kmath.dimensions.DMatrixContext -import scientifik.kmath.dimensions.Dimension -import scientifik.kmath.operations.RealField +import kscience.kmath.dimensions.D2 +import kscience.kmath.dimensions.D3 +import kscience.kmath.dimensions.DMatrixContext +import kscience.kmath.dimensions.Dimension -fun DMatrixContext.simple() { +private fun DMatrixContext.simple() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i + j).toDouble() } @@ -14,12 +13,11 @@ fun DMatrixContext.simple() { m1.transpose() + m2 } - -object D5 : Dimension { +private object D5 : Dimension { override val dim: UInt = 5u } -fun DMatrixContext.custom() { +private fun DMatrixContext.custom() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i - j).toDouble() } val m3 = produce { i, j -> (i - j).toDouble() } diff --git a/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt deleted file mode 100644 index 17a70a4aa..000000000 --- a/examples/src/main/kotlin/scientifik/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ /dev/null @@ -1,70 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.asm.compile -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.expressionInField -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.RealField -import kotlin.random.Random -import kotlin.system.measureTimeMillis - -class ExpressionsInterpretersBenchmark { - private val algebra: Field = RealField - fun functionalExpression() { - val expr = algebra.expressionInField { - variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) - } - - invokeAndSum(expr) - } - - fun mstExpression() { - val expr = algebra.mstInField { - symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) - } - - invokeAndSum(expr) - } - - fun asmExpression() { - val expr = algebra.mstInField { - symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) - }.compile() - - invokeAndSum(expr) - } - - private fun invokeAndSum(expr: Expression) { - val random = Random(0) - var sum = 0.0 - - repeat(1000000) { - sum += expr("x" to random.nextDouble()) - } - - println(sum) - } -} - -fun main() { - val benchmark = ExpressionsInterpretersBenchmark() - - val fe = measureTimeMillis { - benchmark.functionalExpression() - } - - println("fe=$fe") - - val mst = measureTimeMillis { - benchmark.mstExpression() - } - - println("mst=$mst") - - val asm = measureTimeMillis { - benchmark.asmExpression() - } - - println("asm=$asm") -} diff --git a/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt deleted file mode 100644 index 6cc5411b8..000000000 --- a/examples/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.linear - -import koma.matrix.ejml.EJMLMatrixFactory -import scientifik.kmath.commons.linear.CMMatrixContext -import scientifik.kmath.commons.linear.inverse -import scientifik.kmath.commons.linear.toCM -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.Matrix -import kotlin.contracts.ExperimentalContracts -import kotlin.random.Random -import kotlin.system.measureTimeMillis - -@ExperimentalContracts -fun main() { - val random = Random(1224) - val dim = 100 - //creating invertible matrix - val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } - val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 } - val matrix = l dot u - - val n = 5000 // iterations - - MatrixContext.real { - repeat(50) { val res = inverse(matrix) } - val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } } - println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis") - } - - //commons-math - - val commonsTime = measureTimeMillis { - CMMatrixContext { - val cm = matrix.toCM() //avoid overhead on conversion - repeat(n) { val res = inverse(cm) } - } - } - - - println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis") - - //koma-ejml - - val komaTime = measureTimeMillis { - (KomaMatrixContext(EJMLMatrixFactory(), RealField)) { - val km = matrix.toKoma() //avoid overhead on conversion - repeat(n) { - val res = inverse(km) - } - } - } - - println("[koma-ejml] Inversion of $n matrices $dim x $dim finished in $komaTime millis") -} \ No newline at end of file diff --git a/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt deleted file mode 100644 index 3ae550682..000000000 --- a/examples/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt +++ /dev/null @@ -1,49 +0,0 @@ -package scientifik.kmath.linear - -import koma.matrix.ejml.EJMLMatrixFactory -import scientifik.kmath.commons.linear.CMMatrixContext -import scientifik.kmath.commons.linear.toCM -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.Matrix -import kotlin.random.Random -import kotlin.system.measureTimeMillis - -fun main() { - val random = Random(12224) - val dim = 1000 - //creating invertible matrix - val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } - val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } - -// //warmup -// matrix1 dot matrix2 - - CMMatrixContext { - val cmMatrix1 = matrix1.toCM() - val cmMatrix2 = matrix2.toCM() - - val cmTime = measureTimeMillis { - cmMatrix1 dot cmMatrix2 - } - - println("CM implementation time: $cmTime") - } - - (KomaMatrixContext(EJMLMatrixFactory(), RealField)) { - val komaMatrix1 = matrix1.toKoma() - val komaMatrix2 = matrix2.toKoma() - - val komaTime = measureTimeMillis { - komaMatrix1 dot komaMatrix2 - } - - println("Koma-ejml implementation time: $komaTime") - } - - val genericTime = measureTimeMillis { - val res = matrix1 dot matrix2 - } - - println("Generic implementation time: $genericTime") -} \ No newline at end of file diff --git a/examples/src/main/kotlin/scientifik/kmath/operations/BigIntDemo.kt b/examples/src/main/kotlin/scientifik/kmath/operations/BigIntDemo.kt deleted file mode 100644 index 10b038943..000000000 --- a/examples/src/main/kotlin/scientifik/kmath/operations/BigIntDemo.kt +++ /dev/null @@ -1,8 +0,0 @@ -package scientifik.kmath.operations - -fun main() { - val res = BigIntField { - number(1) * 2 - } - println("bigint:$res") -} \ No newline at end of file diff --git a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt deleted file mode 100644 index 6dbfebce1..000000000 --- a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt +++ /dev/null @@ -1,19 +0,0 @@ -package scientifik.kmath.operations - -import scientifik.kmath.structures.NDElement -import scientifik.kmath.structures.NDField -import scientifik.kmath.structures.complex - -fun main() { - val element = NDElement.complex(2, 2) { index: IntArray -> - Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) - } - - val compute = (NDField.complex(8)) { - val a = produce { (it) -> i * it - it.toDouble() } - val b = 3 - val c = Complex(1.0, 1.0) - - (a pow b) + c - } -} diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 000000000..930bba550 --- /dev/null +++ b/gradle.properties @@ -0,0 +1,9 @@ +kotlin.code.style=official +kotlin.parallel.tasks.in.project=true +kotlin.mpp.enableGranularSourceSetsMetadata=true +kotlin.native.enableDependencyPropagation=false +kotlin.mpp.stability.nowarn=true + +org.gradle.jvmargs=-XX:MaxMetaspaceSize=512m +org.gradle.parallel=true +systemProp.org.gradle.internal.publish.checksums.insecure=true \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 62d4c0535..e708b1c02 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index bb8b2fc26..da9702f9e 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.8-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index fbd7c5158..4f906e0c8 100755 --- a/gradlew +++ b/gradlew @@ -130,7 +130,7 @@ fi if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - + JAVACMD=`cygpath --unix "$JAVACMD"` # We build the pattern for arguments to be converted via cygpath diff --git a/gradlew.bat b/gradlew.bat index 5093609d5..107acd32c 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init +if "%ERRORLEVEL%" == "0" goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -54,7 +54,7 @@ goto fail set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe -if exist "%JAVA_EXE%" goto init +if exist "%JAVA_EXE%" goto execute echo. echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% @@ -64,21 +64,6 @@ echo location of your Java installation. goto fail -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - :execute @rem Setup the command line @@ -86,7 +71,7 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* :end @rem End local scope for the variables with windows NT shell diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 2339d0426..19e9ee4a9 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -2,72 +2,85 @@ This subproject implements the following features: -- Expression Language and its parser. -- MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation. -- Type-safe builder for MST. -- Evaluating expressions by traversing MST. + - [expression-language](src/jvmMain/kotlin/kscience/kmath/ast/parser.kt) : Expression language and its parser + - [mst](src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST (Mathematical Syntax Tree) as expression language's syntax intermediate representation + - [mst-building](src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt) : MST building algebraic structure + - [mst-interpreter](src/commonMain/kotlin/kscience/kmath/ast/MST.kt) : MST interpreter + - [mst-jvm-codegen](src/jvmMain/kotlin/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler + - [mst-js-codegen](src/jsMain/kotlin/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler + > #### Artifact: -> This module is distributed in the artifact `scientifik:kmath-ast:0.1.4-dev-8`. -> +> +> This module artifact: `kscience.kmath:kmath-ast:0.2.0-dev-4`. +> +> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-ast/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-ast/_latestVersion) +> +> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-ast/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-ast/_latestVersion) +> > **Gradle:** > > ```gradle > repositories { -> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } +> maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } -> maven { url https://dl.bintray.com/hotkeytlt/maven' } +> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } +> > } > > dependencies { -> implementation 'scientifik:kmath-ast:0.1.4-dev-8' +> implementation 'kscience.kmath:kmath-ast:0.2.0-dev-4' > } > ``` > **Gradle Kotlin DSL:** > > ```kotlin > repositories { -> maven("https://dl.bintray.com/mipt-npm/scientifik") +> maven("https://dl.bintray.com/kotlin/kotlin-eap") +> maven("https://dl.bintray.com/mipt-npm/kscience") > maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/hotkeytlt/maven") > } > > dependencies { -> implementation("scientifik:kmath-ast:0.1.4-dev-8") +> implementation("kscience.kmath:kmath-ast:0.2.0-dev-4") > } > ``` -> -## Dynamic Expression Code Generation with ObjectWeb ASM +## Dynamic expression code generation -`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds -a special implementation of `Expression` with implemented `invoke` function. +### On JVM -For example, the following builder: +`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds +a special implementation of `Expression` with implemented `invoke` function. + +For example, the following builder: ```kotlin RealField.mstInField { symbol("x") + 2 }.compile() ``` -… leads to generation of bytecode, which can be decompiled to the following Java class: +… leads to generation of bytecode, which can be decompiled to the following Java class: ```java -package scientifik.kmath.asm.generated; +package kscience.kmath.asm.generated; import java.util.Map; -import scientifik.kmath.asm.internal.MapIntrinsics; -import scientifik.kmath.expressions.Expression; -import scientifik.kmath.operations.RealField; +import kotlin.jvm.functions.Function2; +import kscience.kmath.asm.internal.MapIntrinsics; +import kscience.kmath.expressions.Expression; +import kscience.kmath.expressions.Symbol; -public final class AsmCompiledExpression_1073786867_0 implements Expression { - private final RealField algebra; +public final class AsmCompiledExpression_45045_0 implements Expression { + private final Object[] constants; - public final Double invoke(Map arguments) { - return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D); + public final Double invoke(Map arguments) { + return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2); } - public AsmCompiledExpression_1073786867_0(RealField algebra) { - this.algebra = algebra; + public AsmCompiledExpression_45045_0(Object[] constants) { + this.constants = constants; } } @@ -75,17 +88,35 @@ public final class AsmCompiledExpression_1073786867_0 implements Expression` with implemented `invoke` function. + +For example, the following builder: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +``` + +… leads to generation of bytecode, which can be decompiled to the following Java class: + +```java +package kscience.kmath.asm.generated; + +import java.util.Map; +import kotlin.jvm.functions.Function2; +import kscience.kmath.asm.internal.MapIntrinsics; +import kscience.kmath.expressions.Expression; +import kscience.kmath.expressions.Symbol; + +public final class AsmCompiledExpression_45045_0 implements Expression { + private final Object[] constants; + + public final Double invoke(Map arguments) { + return (Double)((Function2)this.constants[0]).invoke((Double)MapIntrinsics.getOrFail(arguments, "x"), 2); + } + + public AsmCompiledExpression_45045_0(Object[] constants) { + this.constants = constants; + } +} + +``` + +### Example Usage + +This API extends MST and MstExpression, so you may optimize as both of them: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +RealField.expression("x+2".parseMath()) +``` + +#### Known issues + +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid + class loading overhead. +- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. + +### On JS + +A similar feature is also available on JS. + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +``` + +The code above returns expression implemented with such a JS function: + +```js +var executable = function (constants, arguments) { + return constants[1](constants[0](arguments, "x"), 2); +}; +``` + +#### Known issues + +- This feature uses `eval` which can be unavailable in several environments. diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt new file mode 100644 index 000000000..212fd0d0b --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt @@ -0,0 +1,86 @@ +package kscience.kmath.ast + +import kscience.kmath.operations.Algebra +import kscience.kmath.operations.NumericAlgebra + +/** + * A Mathematical Syntax Tree (MST) node for mathematical expressions. + * + * @author Alexander Nozik + */ +public sealed class MST { + /** + * A node containing raw string. + * + * @property value the value of this node. + */ + public data class Symbolic(val value: String) : MST() + + /** + * A node containing a numeric value or scalar. + * + * @property value the value of this number. + */ + public data class Numeric(val value: Number) : MST() + + /** + * A node containing an unary operation. + * + * @property operation the identifier of operation. + * @property value the argument of this operation. + */ + public data class Unary(val operation: String, val value: MST) : MST() + + /** + * A node containing binary operation. + * + * @property operation the identifier operation. + * @property left the left operand. + * @property right the right operand. + */ + public data class Binary(val operation: String, val left: MST, val right: MST) : MST() +} + +// TODO add a function with named arguments + +/** + * Interprets the [MST] node with this [Algebra]. + * + * @receiver the algebra that provides operations. + * @param node the node to evaluate. + * @return the value of expression. + * @author Alexander Nozik + */ +public fun Algebra.evaluate(node: MST): T = when (node) { + is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) + ?: error("Numeric nodes are not supported by $this") + + is MST.Symbolic -> symbol(node.value) + + is MST.Unary -> when { + this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value)) + else -> unaryOperationFunction(node.operation)(evaluate(node.value)) + } + + is MST.Binary -> when { + this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> + binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value)) + + this is NumericAlgebra && node.left is MST.Numeric -> + leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right)) + + this is NumericAlgebra && node.right is MST.Numeric -> + rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value) + + else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) + } +} + +/** + * Interprets the [MST] node with this [Algebra]. + * + * @receiver the node to evaluate. + * @param algebra the algebra that provides operations. + * @return the value of expression. + */ +public fun MST.interpret(algebra: Algebra): T = algebra.evaluate(this) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt new file mode 100644 index 000000000..eadbc85ee --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -0,0 +1,149 @@ +package kscience.kmath.ast + +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.* + +/** + * [Algebra] over [MST] nodes. + */ +public object MstAlgebra : NumericAlgebra { + public override fun number(value: Number): MST.Numeric = MST.Numeric(value) + public override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + { arg -> MST.Unary(operation, arg) } + + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + { left, right -> MST.Binary(operation, left, right) } +} + +/** + * [Space] over [MST] nodes. + */ +public object MstSpace : Space, NumericAlgebra { + public override val zero: MST.Numeric by lazy { number(0.0) } + + public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) + public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) + public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b) + public override operator fun MST.unaryPlus(): MST.Unary = + unaryOperationFunction(SpaceOperations.PLUS_OPERATION)(this) + + public override operator fun MST.unaryMinus(): MST.Unary = + unaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this) + + public override operator fun MST.minus(b: MST): MST.Binary = + binaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this, b) + + public override fun multiply(a: MST, k: Number): MST.Binary = + binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(k)) + + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstAlgebra.binaryOperationFunction(operation) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperationFunction(operation) +} + +/** + * [Ring] over [MST] nodes. + */ +@OptIn(UnstableKMathAPI::class) +public object MstRing : Ring, RingWithNumbers { + public override val zero: MST.Numeric + get() = MstSpace.zero + + public override val one: MST.Numeric by lazy { number(1.0) } + + public override fun number(value: Number): MST.Numeric = MstSpace.number(value) + public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) + public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = + binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + + public override operator fun MST.unaryPlus(): MST.Unary = MstSpace { +this@unaryPlus } + public override operator fun MST.unaryMinus(): MST.Unary = MstSpace { -this@unaryMinus } + public override operator fun MST.minus(b: MST): MST.Binary = MstSpace { this@minus - b } + + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstSpace.binaryOperationFunction(operation) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperationFunction(operation) +} + +/** + * [Field] over [MST] nodes. + */ +@OptIn(UnstableKMathAPI::class) +public object MstField : Field, RingWithNumbers { + public override val zero: MST.Numeric + get() = MstRing.zero + + public override val one: MST.Numeric + get() = MstRing.one + + public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) + public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = + binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + + public override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } + public override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } + public override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b } + + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstRing.binaryOperationFunction(operation) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstRing.unaryOperationFunction(operation) +} + +/** + * [ExtendedField] over [MST] nodes. + */ +public object MstExtendedField : ExtendedField, NumericAlgebra { + public override val zero: MST.Numeric + get() = MstField.zero + + public override val one: MST.Numeric + get() = MstField.one + + public override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) + public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) + public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) + public override fun tan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.TAN_OPERATION)(arg) + public override fun asin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg) + public override fun acos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg) + public override fun atan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg) + public override fun sinh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.SINH_OPERATION)(arg) + public override fun cosh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.COSH_OPERATION)(arg) + public override fun tanh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.TANH_OPERATION)(arg) + public override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ASINH_OPERATION)(arg) + public override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ACOSH_OPERATION)(arg) + public override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(HyperbolicOperations.ATANH_OPERATION)(arg) + public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) + public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) + public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + public override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } + public override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } + public override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b } + + public override fun power(arg: MST, pow: Number): MST.Binary = + binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) + + public override fun exp(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) + public override fun ln(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) + + public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + MstField.binaryOperationFunction(operation) + + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstField.unaryOperationFunction(operation) +} diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt new file mode 100644 index 000000000..03d33aa2b --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -0,0 +1,127 @@ +package kscience.kmath.ast + +import kscience.kmath.expressions.* +import kscience.kmath.operations.* +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than + * ASM-generated expressions. + * + * @property algebra the algebra that provides operations. + * @property mst the [MST] node. + * @author Alexander Nozik + */ +public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { + private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { + override fun symbol(value: String): T = try { + algebra.symbol(value) + } catch (ignored: IllegalStateException) { + null + } ?: arguments.getValue(StringSymbol(value)) + + override fun unaryOperationFunction(operation: String): (arg: T) -> T = algebra.unaryOperationFunction(operation) + override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = algebra.binaryOperationFunction(operation) + + @Suppress("UNCHECKED_CAST") + override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) + (algebra as NumericAlgebra).number(value) + else + error("Numeric nodes are not supported by $this") + } + + override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) +} + +/** + * Builds [MstExpression] over [Algebra]. + * + * @author Alexander Nozik + */ +public inline fun , E : Algebra> A.mst( + mstAlgebra: E, + block: E.() -> MST, +): MstExpression = MstExpression(this, mstAlgebra.block()) + +/** + * Builds [MstExpression] over [Space]. + * + * @author Alexander Nozik + */ +public inline fun > A.mstInSpace(block: MstSpace.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstSpace.block()) +} + +/** + * Builds [MstExpression] over [Ring]. + * + * @author Alexander Nozik + */ +public inline fun > A.mstInRing(block: MstRing.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstRing.block()) +} + +/** + * Builds [MstExpression] over [Field]. + * + * @author Alexander Nozik + */ +public inline fun > A.mstInField(block: MstField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstField.block()) +} + +/** + * Builds [MstExpression] over [ExtendedField]. + * + * @author Iaroslav Postovalov + */ +public inline fun > A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MstExpression(this, MstExtendedField.block()) +} + +/** + * Builds [MstExpression] over [FunctionalExpressionSpace]. + * + * @author Alexander Nozik + */ +public inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInSpace(block) +} + +/** + * Builds [MstExpression] over [FunctionalExpressionRing]. + * + * @author Alexander Nozik + */ +public inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInRing(block) +} + +/** + * Builds [MstExpression] over [FunctionalExpressionField]. + * + * @author Alexander Nozik + */ +public inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInField(block) +} + +/** + * Builds [MstExpression] over [FunctionalExpressionExtendedField]. + * + * @author Iaroslav Postovalov + */ +public inline fun > FunctionalExpressionExtendedField.mstInExtendedField( + block: MstExtendedField.() -> MST, +): MstExpression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return algebra.mstInExtendedField(block) +} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt deleted file mode 100644 index 0e8151c04..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt +++ /dev/null @@ -1,87 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.NumericAlgebra -import scientifik.kmath.operations.RealField - -/** - * A Mathematical Syntax Tree node for mathematical expressions. - */ -sealed class MST { - /** - * A node containing raw string. - * - * @property value the value of this node. - */ - data class Symbolic(val value: String) : MST() - - /** - * A node containing a numeric value or scalar. - * - * @property value the value of this number. - */ - data class Numeric(val value: Number) : MST() - - /** - * A node containing an unary operation. - * - * @property operation the identifier of operation. - * @property value the argument of this operation. - */ - data class Unary(val operation: String, val value: MST) : MST() { - companion object - } - - /** - * A node containing binary operation. - * - * @property operation the identifier operation. - * @property left the left operand. - * @property right the right operand. - */ - data class Binary(val operation: String, val left: MST, val right: MST) : MST() { - companion object - } -} - -// TODO add a function with named arguments - -/** - * Interprets the [MST] node with this [Algebra]. - * - * @receiver the algebra that provides operations. - * @param node the node to evaluate. - * @return the value of expression. - */ -fun Algebra.evaluate(node: MST): T = when (node) { - is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) - ?: error("Numeric nodes are not supported by $this") - is MST.Symbolic -> symbol(node.value) - is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) - is MST.Binary -> when { - this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) - - node.left is MST.Numeric && node.right is MST.Numeric -> { - val number = RealField.binaryOperation( - node.operation, - node.left.value.toDouble(), - node.right.value.toDouble() - ) - - number(number) - } - - node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right)) - node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value) - else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right)) - } -} - -/** - * Interprets the [MST] node with this [Algebra]. - * - * @receiver the node to evaluate. - * @param algebra the algebra that provides operations. - * @return the value of expression. - */ -fun MST.interpret(algebra: Algebra): T = algebra.evaluate(this) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt deleted file mode 100644 index 23deae24b..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstAlgebra.kt +++ /dev/null @@ -1,102 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.operations.* - -/** - * [Algebra] over [MST] nodes. - */ -object MstAlgebra : NumericAlgebra { - override fun number(value: Number): MST = MST.Numeric(value) - - override fun symbol(value: String): MST = MST.Symbolic(value) - - override fun unaryOperation(operation: String, arg: MST): MST = - MST.Unary(operation, arg) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MST.Binary(operation, left, right) -} - -/** - * [Space] over [MST] nodes. - */ -object MstSpace : Space, NumericAlgebra { - override val zero: MST = number(0.0) - - override fun number(value: Number): MST = MstAlgebra.number(value) - override fun symbol(value: String): MST = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MstAlgebra.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) -} - -/** - * [Ring] over [MST] nodes. - */ -object MstRing : Ring, NumericAlgebra { - override val zero: MST = number(0.0) - override val one: MST = number(1.0) - - override fun number(value: Number): MST = MstSpace.number(value) - override fun symbol(value: String): MST = MstSpace.symbol(value) - override fun add(a: MST, b: MST): MST = MstSpace.add(a, b) - - override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) - - override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MstSpace.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) -} - -/** - * [Field] over [MST] nodes. - */ -object MstField : Field { - override val zero: MST = number(0.0) - override val one: MST = number(1.0) - - override fun symbol(value: String): MST = MstRing.symbol(value) - override fun number(value: Number): MST = MstRing.number(value) - override fun add(a: MST, b: MST): MST = MstRing.add(a, b) - override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) - override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) - override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MstRing.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) -} - -/** - * [ExtendedField] over [MST] nodes. - */ -object MstExtendedField : ExtendedField { - override val zero: MST = number(0.0) - override val one: MST = number(1.0) - - override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) - override fun add(a: MST, b: MST): MST = MstField.add(a, b) - override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) - override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) - override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) - override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) - override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) - override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg) - - override fun binaryOperation(operation: String, left: MST, right: MST): MST = - MstField.binaryOperation(operation, left, right) - - override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg) -} diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt deleted file mode 100644 index 3cee33956..000000000 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MstExpression.kt +++ /dev/null @@ -1,103 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.expressions.* -import scientifik.kmath.operations.* -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract - -/** - * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than - * ASM-generated expressions. - * - * @property algebra the algebra that provides operations. - * @property mst the [MST] node. - */ -class MstExpression(val algebra: Algebra, val mst: MST) : Expression { - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { - override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) - override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T = - algebra.binaryOperation(operation, left, right) - - override fun number(value: Number): T = if (algebra is NumericAlgebra) - algebra.number(value) - else - error("Numeric nodes are not supported by $this") - } - - override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} - -/** - * Builds [MstExpression] over [Algebra]. - */ -inline fun , E : Algebra> A.mst( - mstAlgebra: E, - block: E.() -> MST -): MstExpression = MstExpression(this, mstAlgebra.block()) - -/** - * Builds [MstExpression] over [Space]. - */ -inline fun Space.mstInSpace(block: MstSpace.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstSpace.block()) -} - -/** - * Builds [MstExpression] over [Ring]. - */ -inline fun Ring.mstInRing(block: MstRing.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstRing.block()) -} - -/** - * Builds [MstExpression] over [Field]. - */ -inline fun Field.mstInField(block: MstField.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstField.block()) -} - -/** - * Builds [MstExpression] over [ExtendedField]. - */ -inline fun Field.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstExtendedField.block()) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionSpace]. - */ -inline fun > FunctionalExpressionSpace.mstInSpace(block: MstSpace.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInSpace(block) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionRing]. - */ -inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInRing(block) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionField]. - */ -inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInField(block) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionExtendedField]. - */ -inline fun > FunctionalExpressionExtendedField.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInExtendedField(block) -} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt new file mode 100644 index 000000000..5c08ada31 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt @@ -0,0 +1,82 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.MST +import kscience.kmath.ast.MST.* +import kscience.kmath.ast.MstExpression +import kscience.kmath.estree.internal.ESTreeBuilder +import kscience.kmath.estree.internal.estree.BaseExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.operations.Algebra +import kscience.kmath.operations.NumericAlgebra + +@PublishedApi +internal fun MST.compileWith(algebra: Algebra): Expression { + fun ESTreeBuilder.visit(node: MST): BaseExpression = when (node) { + is Symbolic -> { + val symbol = try { + algebra.symbol(node.value) + } catch (ignored: IllegalStateException) { + null + } + + if (symbol != null) + constant(symbol) + else + variable(node.value) + } + + is Numeric -> constant(node.value) + + is Unary -> when { + algebra is NumericAlgebra && node.value is Numeric -> constant( + algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value))) + + else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value)) + } + + is Binary -> when { + algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant( + algebra + .binaryOperationFunction(node.operation) + .invoke(algebra.number(node.left.value), algebra.number(node.right.value)) + ) + + algebra is NumericAlgebra && node.left is Numeric -> call( + algebra.leftSideNumberOperationFunction(node.operation), + visit(node.left), + visit(node.right), + ) + + algebra is NumericAlgebra && node.right is Numeric -> call( + algebra.rightSideNumberOperationFunction(node.operation), + visit(node.left), + visit(node.right), + ) + + else -> call( + algebra.binaryOperationFunction(node.operation), + visit(node.left), + visit(node.right), + ) + } + } + + return ESTreeBuilder { visit(this@compileWith) }.instance +} + + +/** + * Compiles an [MST] to ESTree generated expression using given algebra. + * + * @author Alexander Nozik. + */ +public fun Algebra.expression(mst: MST): Expression = + mst.compileWith(this) + +/** + * Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression. + * + * @author Alexander Nozik. + */ +public fun MstExpression>.compile(): Expression = + mst.compileWith(algebra) diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/ESTreeBuilder.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/ESTreeBuilder.kt new file mode 100644 index 000000000..e1823813a --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/ESTreeBuilder.kt @@ -0,0 +1,79 @@ +package kscience.kmath.estree.internal + +import kscience.kmath.estree.internal.astring.generate +import kscience.kmath.estree.internal.estree.* +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol + +internal class ESTreeBuilder(val bodyCallback: ESTreeBuilder.() -> BaseExpression) { + private class GeneratedExpression(val executable: dynamic, val constants: Array) : Expression { + @Suppress("UNUSED_VARIABLE") + override fun invoke(arguments: Map): T { + val e = executable + val c = constants + val a = js("{}") + arguments.forEach { (key, value) -> a[key.identity] = value } + return js("e(c, a)").unsafeCast() + } + } + + val instance: Expression by lazy { + val node = Program( + sourceType = "script", + VariableDeclaration( + kind = "var", + VariableDeclarator( + id = Identifier("executable"), + init = FunctionExpression( + params = arrayOf(Identifier("constants"), Identifier("arguments")), + body = BlockStatement(ReturnStatement(bodyCallback())), + ), + ), + ), + ) + + eval(generate(node)) + GeneratedExpression(js("executable"), constants.toTypedArray()) + } + + private val constants = mutableListOf() + + fun constant(value: Any?) = when { + value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" -> + SimpleLiteral(value) + + jsTypeOf(value) == "undefined" -> Identifier("undefined") + + else -> { + val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex + + MemberExpression( + computed = true, + optional = false, + `object` = Identifier("constants"), + property = SimpleLiteral(idx), + ) + } + } + + fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name)) + + fun call(function: Function, vararg args: BaseExpression): BaseExpression = SimpleCallExpression( + optional = false, + callee = constant(function), + *args, + ) + + private companion object { + @Suppress("UNUSED_VARIABLE") + val getOrFail: (`object`: dynamic, key: String) -> dynamic = { `object`, key -> + val k = key + val o = `object` + + if (!(js("k in o") as Boolean)) + throw NoSuchElementException("Key $key is missing in the map.") + + js("o[k]") + } + } +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.kt new file mode 100644 index 000000000..cf0a8de25 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.kt @@ -0,0 +1,33 @@ +@file:JsModule("astring") +@file:JsNonModule + +package kscience.kmath.estree.internal.astring + +import kscience.kmath.estree.internal.estree.BaseNode + +internal external interface Options { + var indent: String? + get() = definedExternally + set(value) = definedExternally + var lineEnd: String? + get() = definedExternally + set(value) = definedExternally + var startingIndentLevel: Number? + get() = definedExternally + set(value) = definedExternally + var comments: Boolean? + get() = definedExternally + set(value) = definedExternally + var generator: Any? + get() = definedExternally + set(value) = definedExternally + var sourceMap: Any? + get() = definedExternally + set(value) = definedExternally +} + +internal external fun generate(node: BaseNode, options: Options /* Options & `T$0` */ = definedExternally): String + +internal external fun generate(node: BaseNode): String + +internal external var baseGenerator: Generator diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.typealises.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.typealises.kt new file mode 100644 index 000000000..5a7fe4f16 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/astring/astring.typealises.kt @@ -0,0 +1,3 @@ +package kscience.kmath.estree.internal.astring + +internal typealias Generator = Any diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/emitter/emitter.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/emitter/emitter.kt new file mode 100644 index 000000000..1e0a95a16 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/emitter/emitter.kt @@ -0,0 +1,13 @@ +package kscience.kmath.estree.internal.emitter + +internal open external class Emitter { + constructor(obj: Any) + constructor() + + open fun on(event: String, fn: () -> Unit) + open fun off(event: String, fn: () -> Unit) + open fun once(event: String, fn: () -> Unit) + open fun emit(event: String, vararg any: Any) + open fun listeners(event: String): Array<() -> Unit> + open fun hasListeners(event: String): Boolean +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.extensions.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.extensions.kt new file mode 100644 index 000000000..5bc197d0c --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.extensions.kt @@ -0,0 +1,62 @@ +package kscience.kmath.estree.internal.estree + +internal fun Program(sourceType: String, vararg body: dynamic) = object : Program { + override var type = "Program" + override var sourceType = sourceType + override var body = body +} + +internal fun VariableDeclaration(kind: String, vararg declarations: VariableDeclarator) = object : VariableDeclaration { + override var type = "VariableDeclaration" + override var declarations = declarations.toList().toTypedArray() + override var kind = kind +} + +internal fun VariableDeclarator(id: dynamic, init: dynamic) = object : VariableDeclarator { + override var type = "VariableDeclarator" + override var id = id + override var init = init +} + +internal fun Identifier(name: String) = object : Identifier { + override var type = "Identifier" + override var name = name +} + +internal fun FunctionExpression(params: Array, body: BlockStatement) = object : FunctionExpression { + override var params = params + override var type = "FunctionExpression" + override var body = body +} + +internal fun BlockStatement(vararg body: dynamic) = object : BlockStatement { + override var type = "BlockStatement" + override var body = body +} + +internal fun ReturnStatement(argument: dynamic) = object : ReturnStatement { + override var type = "ReturnStatement" + override var argument = argument +} + +internal fun SimpleLiteral(value: dynamic) = object : SimpleLiteral { + override var type = "Literal" + override var value = value +} + +internal fun MemberExpression(computed: Boolean, optional: Boolean, `object`: dynamic, property: dynamic) = + object : MemberExpression { + override var type = "MemberExpression" + override var computed = computed + override var optional = optional + override var `object` = `object` + override var property = property + } + +internal fun SimpleCallExpression(optional: Boolean, callee: dynamic, vararg arguments: dynamic) = + object : SimpleCallExpression { + override var type = "CallExpression" + override var optional = optional + override var callee = callee + override var arguments = arguments + } diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.kt new file mode 100644 index 000000000..a5385d1ee --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/estree/estree.kt @@ -0,0 +1,644 @@ +package kscience.kmath.estree.internal.estree + +import kotlin.js.RegExp + +internal external interface BaseNodeWithoutComments { + var type: String + var loc: SourceLocation? + get() = definedExternally + set(value) = definedExternally + var range: dynamic /* JsTuple */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseNode : BaseNodeWithoutComments { + var leadingComments: Array? + get() = definedExternally + set(value) = definedExternally + var trailingComments: Array? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface Comment : BaseNodeWithoutComments { + override var type: String /* "Line" | "Block" */ + var value: String +} + +internal external interface SourceLocation { + var source: String? + get() = definedExternally + set(value) = definedExternally + var start: Position + var end: Position +} + +internal external interface Position { + var line: Number + var column: Number +} + +internal external interface Program : BaseNode { + override var type: String /* "Program" */ + var sourceType: String /* "script" | "module" */ + var body: Array + var comments: Array? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface Directive : BaseNode { + override var type: String /* "ExpressionStatement" */ + var expression: dynamic /* SimpleLiteral | RegExpLiteral */ + get() = definedExternally + set(value) = definedExternally + var directive: String +} + +internal external interface BaseFunction : BaseNode { + var params: Array + var generator: Boolean? + get() = definedExternally + set(value) = definedExternally + var async: Boolean? + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* BlockStatement | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseStatement : BaseNode + +internal external interface EmptyStatement : BaseStatement { + override var type: String /* "EmptyStatement" */ +} + +internal external interface BlockStatement : BaseStatement { + override var type: String /* "BlockStatement" */ + var body: Array + var innerComments: Array? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ExpressionStatement : BaseStatement { + override var type: String /* "ExpressionStatement" */ + var expression: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface IfStatement : BaseStatement { + override var type: String /* "IfStatement" */ + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var consequent: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally + var alternate: dynamic /* ExpressionStatement? | BlockStatement? | EmptyStatement? | DebuggerStatement? | WithStatement? | ReturnStatement? | LabeledStatement? | BreakStatement? | ContinueStatement? | IfStatement? | SwitchStatement? | ThrowStatement? | TryStatement? | WhileStatement? | DoWhileStatement? | ForStatement? | ForInStatement? | ForOfStatement? | FunctionDeclaration? | VariableDeclaration? | ClassDeclaration? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface LabeledStatement : BaseStatement { + override var type: String /* "LabeledStatement" */ + var label: Identifier + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BreakStatement : BaseStatement { + override var type: String /* "BreakStatement" */ + var label: Identifier? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ContinueStatement : BaseStatement { + override var type: String /* "ContinueStatement" */ + var label: Identifier? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface WithStatement : BaseStatement { + override var type: String /* "WithStatement" */ + var `object`: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface SwitchStatement : BaseStatement { + override var type: String /* "SwitchStatement" */ + var discriminant: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var cases: Array +} + +internal external interface ReturnStatement : BaseStatement { + override var type: String /* "ReturnStatement" */ + var argument: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ThrowStatement : BaseStatement { + override var type: String /* "ThrowStatement" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface TryStatement : BaseStatement { + override var type: String /* "TryStatement" */ + var block: BlockStatement + var handler: CatchClause? + get() = definedExternally + set(value) = definedExternally + var finalizer: BlockStatement? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface WhileStatement : BaseStatement { + override var type: String /* "WhileStatement" */ + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface DoWhileStatement : BaseStatement { + override var type: String /* "DoWhileStatement" */ + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ForStatement : BaseStatement { + override var type: String /* "ForStatement" */ + var init: dynamic /* VariableDeclaration? | ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var test: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var update: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseForXStatement : BaseStatement { + var left: dynamic /* VariableDeclaration | Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var body: dynamic /* ExpressionStatement | BlockStatement | EmptyStatement | DebuggerStatement | WithStatement | ReturnStatement | LabeledStatement | BreakStatement | ContinueStatement | IfStatement | SwitchStatement | ThrowStatement | TryStatement | WhileStatement | DoWhileStatement | ForStatement | ForInStatement | ForOfStatement | FunctionDeclaration | VariableDeclaration | ClassDeclaration */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ForInStatement : BaseForXStatement { + override var type: String /* "ForInStatement" */ +} + +internal external interface DebuggerStatement : BaseStatement { + override var type: String /* "DebuggerStatement" */ +} + +internal external interface BaseDeclaration : BaseStatement + +internal external interface FunctionDeclaration : BaseFunction, BaseDeclaration { + override var type: String /* "FunctionDeclaration" */ + var id: Identifier? + override var body: BlockStatement +} + +internal external interface VariableDeclaration : BaseDeclaration { + override var type: String /* "VariableDeclaration" */ + var declarations: Array + var kind: String /* "var" | "let" | "const" */ +} + +internal external interface VariableDeclarator : BaseNode { + override var type: String /* "VariableDeclarator" */ + var id: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var init: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseExpression : BaseNode + +internal external interface ChainExpression : BaseExpression { + override var type: String /* "ChainExpression" */ + var expression: dynamic /* SimpleCallExpression | MemberExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ThisExpression : BaseExpression { + override var type: String /* "ThisExpression" */ +} + +internal external interface ArrayExpression : BaseExpression { + override var type: String /* "ArrayExpression" */ + var elements: Array +} + +internal external interface ObjectExpression : BaseExpression { + override var type: String /* "ObjectExpression" */ + var properties: Array +} + +internal external interface Property : BaseNode { + override var type: String /* "Property" */ + var key: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var value: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern */ + get() = definedExternally + set(value) = definedExternally + var kind: String /* "init" | "get" | "set" */ + var method: Boolean + var shorthand: Boolean + var computed: Boolean +} + +internal external interface FunctionExpression : BaseFunction, BaseExpression { + var id: Identifier? + get() = definedExternally + set(value) = definedExternally + override var type: String /* "FunctionExpression" */ + override var body: BlockStatement +} + +internal external interface SequenceExpression : BaseExpression { + override var type: String /* "SequenceExpression" */ + var expressions: Array +} + +internal external interface UnaryExpression : BaseExpression { + override var type: String /* "UnaryExpression" */ + var operator: String /* "-" | "+" | "!" | "~" | "typeof" | "void" | "delete" */ + var prefix: Boolean + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BinaryExpression : BaseExpression { + override var type: String /* "BinaryExpression" */ + var operator: String /* "==" | "!=" | "===" | "!==" | "<" | "<=" | ">" | ">=" | "<<" | ">>" | ">>>" | "+" | "-" | "*" | "/" | "%" | "**" | "|" | "^" | "&" | "in" | "instanceof" */ + var left: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface AssignmentExpression : BaseExpression { + override var type: String /* "AssignmentExpression" */ + var operator: String /* "=" | "+=" | "-=" | "*=" | "/=" | "%=" | "**=" | "<<=" | ">>=" | ">>>=" | "|=" | "^=" | "&=" */ + var left: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface UpdateExpression : BaseExpression { + override var type: String /* "UpdateExpression" */ + var operator: String /* "++" | "--" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var prefix: Boolean +} + +internal external interface LogicalExpression : BaseExpression { + override var type: String /* "LogicalExpression" */ + var operator: String /* "||" | "&&" | "??" */ + var left: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ConditionalExpression : BaseExpression { + override var type: String /* "ConditionalExpression" */ + var test: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var alternate: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var consequent: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseCallExpression : BaseExpression { + var callee: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | Super */ + get() = definedExternally + set(value) = definedExternally + var arguments: Array +} + +internal external interface SimpleCallExpression : BaseCallExpression { + override var type: String /* "CallExpression" */ + var optional: Boolean +} + +internal external interface NewExpression : BaseCallExpression { + override var type: String /* "NewExpression" */ +} + +internal external interface MemberExpression : BaseExpression, BasePattern { + override var type: String /* "MemberExpression" */ + var `object`: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression | Super */ + get() = definedExternally + set(value) = definedExternally + var property: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var computed: Boolean + var optional: Boolean +} + +internal external interface BasePattern : BaseNode + +internal external interface SwitchCase : BaseNode { + override var type: String /* "SwitchCase" */ + var test: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var consequent: Array +} + +internal external interface CatchClause : BaseNode { + override var type: String /* "CatchClause" */ + var param: dynamic /* Identifier? | ObjectPattern? | ArrayPattern? | RestElement? | AssignmentPattern? | MemberExpression? */ + get() = definedExternally + set(value) = definedExternally + var body: BlockStatement +} + +internal external interface Identifier : BaseNode, BaseExpression, BasePattern { + override var type: String /* "Identifier" */ + var name: String +} + +internal external interface SimpleLiteral : BaseNode, BaseExpression { + override var type: String /* "Literal" */ + var value: dynamic /* String? | Boolean? | Number? */ + get() = definedExternally + set(value) = definedExternally + var raw: String? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface `T$1` { + var pattern: String + var flags: String +} + +internal external interface RegExpLiteral : BaseNode, BaseExpression { + override var type: String /* "Literal" */ + var value: RegExp? + get() = definedExternally + set(value) = definedExternally + var regex: `T$1` + var raw: String? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ForOfStatement : BaseForXStatement { + override var type: String /* "ForOfStatement" */ + var await: Boolean +} + +internal external interface Super : BaseNode { + override var type: String /* "Super" */ +} + +internal external interface SpreadElement : BaseNode { + override var type: String /* "SpreadElement" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ArrowFunctionExpression : BaseExpression, BaseFunction { + override var type: String /* "ArrowFunctionExpression" */ + var expression: Boolean + override var body: dynamic /* BlockStatement | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface YieldExpression : BaseExpression { + override var type: String /* "YieldExpression" */ + var argument: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var delegate: Boolean +} + +internal external interface TemplateLiteral : BaseExpression { + override var type: String /* "TemplateLiteral" */ + var quasis: Array + var expressions: Array +} + +internal external interface TaggedTemplateExpression : BaseExpression { + override var type: String /* "TaggedTemplateExpression" */ + var tag: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var quasi: TemplateLiteral +} + +internal external interface `T$2` { + var cooked: String + var raw: String +} + +internal external interface TemplateElement : BaseNode { + override var type: String /* "TemplateElement" */ + var tail: Boolean + var value: `T$2` +} + +internal external interface AssignmentProperty : Property { + override var value: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + override var kind: String /* "init" */ + override var method: Boolean +} + +internal external interface ObjectPattern : BasePattern { + override var type: String /* "ObjectPattern" */ + var properties: Array +} + +internal external interface ArrayPattern : BasePattern { + override var type: String /* "ArrayPattern" */ + var elements: Array +} + +internal external interface RestElement : BasePattern { + override var type: String /* "RestElement" */ + var argument: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface AssignmentPattern : BasePattern { + override var type: String /* "AssignmentPattern" */ + var left: dynamic /* Identifier | ObjectPattern | ArrayPattern | RestElement | AssignmentPattern | MemberExpression */ + get() = definedExternally + set(value) = definedExternally + var right: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface BaseClass : BaseNode { + var superClass: dynamic /* ThisExpression? | ArrayExpression? | ObjectExpression? | FunctionExpression? | ArrowFunctionExpression? | YieldExpression? | SimpleLiteral? | RegExpLiteral? | UnaryExpression? | UpdateExpression? | BinaryExpression? | AssignmentExpression? | LogicalExpression? | MemberExpression? | ConditionalExpression? | SimpleCallExpression? | NewExpression? | SequenceExpression? | TemplateLiteral? | TaggedTemplateExpression? | ClassExpression? | MetaProperty? | Identifier? | AwaitExpression? | ImportExpression? | ChainExpression? */ + get() = definedExternally + set(value) = definedExternally + var body: ClassBody +} + +internal external interface ClassBody : BaseNode { + override var type: String /* "ClassBody" */ + var body: Array +} + +internal external interface MethodDefinition : BaseNode { + override var type: String /* "MethodDefinition" */ + var key: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally + var value: FunctionExpression + var kind: String /* "constructor" | "method" | "get" | "set" */ + var computed: Boolean + var static: Boolean +} + +internal external interface ClassDeclaration : BaseClass, BaseDeclaration { + override var type: String /* "ClassDeclaration" */ + var id: Identifier? +} + +internal external interface ClassExpression : BaseClass, BaseExpression { + override var type: String /* "ClassExpression" */ + var id: Identifier? + get() = definedExternally + set(value) = definedExternally +} + +internal external interface MetaProperty : BaseExpression { + override var type: String /* "MetaProperty" */ + var meta: Identifier + var property: Identifier +} + +internal external interface BaseModuleDeclaration : BaseNode + +internal external interface BaseModuleSpecifier : BaseNode { + var local: Identifier +} + +internal external interface ImportDeclaration : BaseModuleDeclaration { + override var type: String /* "ImportDeclaration" */ + var specifiers: Array + var source: dynamic /* SimpleLiteral | RegExpLiteral */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ImportSpecifier : BaseModuleSpecifier { + override var type: String /* "ImportSpecifier" */ + var imported: Identifier +} + +internal external interface ImportExpression : BaseExpression { + override var type: String /* "ImportExpression" */ + var source: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ImportDefaultSpecifier : BaseModuleSpecifier { + override var type: String /* "ImportDefaultSpecifier" */ +} + +internal external interface ImportNamespaceSpecifier : BaseModuleSpecifier { + override var type: String /* "ImportNamespaceSpecifier" */ +} + +internal external interface ExportNamedDeclaration : BaseModuleDeclaration { + override var type: String /* "ExportNamedDeclaration" */ + var declaration: dynamic /* FunctionDeclaration? | VariableDeclaration? | ClassDeclaration? */ + get() = definedExternally + set(value) = definedExternally + var specifiers: Array + var source: dynamic /* SimpleLiteral? | RegExpLiteral? */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ExportSpecifier : BaseModuleSpecifier { + override var type: String /* "ExportSpecifier" */ + var exported: Identifier +} + +internal external interface ExportDefaultDeclaration : BaseModuleDeclaration { + override var type: String /* "ExportDefaultDeclaration" */ + var declaration: dynamic /* FunctionDeclaration | VariableDeclaration | ClassDeclaration | ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface ExportAllDeclaration : BaseModuleDeclaration { + override var type: String /* "ExportAllDeclaration" */ + var source: dynamic /* SimpleLiteral | RegExpLiteral */ + get() = definedExternally + set(value) = definedExternally +} + +internal external interface AwaitExpression : BaseExpression { + override var type: String /* "AwaitExpression" */ + var argument: dynamic /* ThisExpression | ArrayExpression | ObjectExpression | FunctionExpression | ArrowFunctionExpression | YieldExpression | SimpleLiteral | RegExpLiteral | UnaryExpression | UpdateExpression | BinaryExpression | AssignmentExpression | LogicalExpression | MemberExpression | ConditionalExpression | SimpleCallExpression | NewExpression | SequenceExpression | TemplateLiteral | TaggedTemplateExpression | ClassExpression | MetaProperty | Identifier | AwaitExpression | ImportExpression | ChainExpression */ + get() = definedExternally + set(value) = definedExternally +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/stream/stream.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/stream/stream.kt new file mode 100644 index 000000000..b3c65a758 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/stream/stream.kt @@ -0,0 +1,7 @@ +package kscience.kmath.estree.internal.stream + +import kscience.kmath.estree.internal.emitter.Emitter + +internal open external class Stream : Emitter { + open fun pipe(dest: Any, options: Any): Any +} diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es2015.iterable.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es2015.iterable.kt new file mode 100644 index 000000000..22d4dd8e0 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es2015.iterable.kt @@ -0,0 +1,25 @@ +package kscience.kmath.estree.internal.tsstdlib + +internal external interface IteratorYieldResult { + var done: Boolean? + get() = definedExternally + set(value) = definedExternally + var value: TYield +} + +internal external interface IteratorReturnResult { + var done: Boolean + var value: TReturn +} + +internal external interface Iterator { + fun next(vararg args: Any /* JsTuple<> | JsTuple */): dynamic /* IteratorYieldResult | IteratorReturnResult */ + val `return`: ((value: TReturn) -> dynamic)? + val `throw`: ((e: Any) -> dynamic)? +} + +internal typealias Iterator__1 = Iterator + +internal external interface Iterable + +internal external interface IterableIterator : Iterator__1 diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es5.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es5.kt new file mode 100644 index 000000000..70f6d9702 --- /dev/null +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/internal/tsstdlib/lib.es5.kt @@ -0,0 +1,82 @@ +@file:Suppress("UNUSED_TYPEALIAS_PARAMETER", "DEPRECATION") + +package kscience.kmath.estree.internal.tsstdlib + +import kotlin.js.RegExp + +internal typealias RegExpMatchArray = Array + +internal typealias RegExpExecArray = Array + +internal external interface RegExpConstructor { + @nativeInvoke + operator fun invoke(pattern: RegExp, flags: String = definedExternally): RegExp + + @nativeInvoke + operator fun invoke(pattern: RegExp): RegExp + + @nativeInvoke + operator fun invoke(pattern: String, flags: String = definedExternally): RegExp + + @nativeInvoke + operator fun invoke(pattern: String): RegExp + var prototype: RegExp + var `$1`: String + var `$2`: String + var `$3`: String + var `$4`: String + var `$5`: String + var `$6`: String + var `$7`: String + var `$8`: String + var `$9`: String + var lastMatch: String +} + +internal external interface ConcatArray { + var length: Number + + @nativeGetter + operator fun get(n: Number): T? + + @nativeSetter + operator fun set(n: Number, value: T) + fun join(separator: String = definedExternally): String + fun slice(start: Number = definedExternally, end: Number = definedExternally): Array +} + +internal external interface ArrayConstructor { + fun from(iterable: Iterable): Array + fun from(iterable: ArrayLike): Array + fun from(iterable: Iterable, mapfn: (v: T, k: Number) -> U, thisArg: Any = definedExternally): Array + fun from(iterable: Iterable, mapfn: (v: T, k: Number) -> U): Array + fun from(iterable: ArrayLike, mapfn: (v: T, k: Number) -> U, thisArg: Any = definedExternally): Array + fun from(iterable: ArrayLike, mapfn: (v: T, k: Number) -> U): Array + fun of(vararg items: T): Array + + @nativeInvoke + operator fun invoke(arrayLength: Number = definedExternally): Array + + @nativeInvoke + operator fun invoke(): Array + + @nativeInvoke + operator fun invoke(arrayLength: Number): Array + + @nativeInvoke + operator fun invoke(vararg items: T): Array + fun isArray(arg: Any): Boolean + var prototype: Array +} + +internal external interface ArrayLike { + var length: Number + + @nativeGetter + operator fun get(n: Number): T? + + @nativeSetter + operator fun set(n: Number, value: T) +} + +internal typealias Extract = Any diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt new file mode 100644 index 000000000..b9be02d49 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt @@ -0,0 +1,115 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.* +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.ByteRing +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.RealField +import kscience.kmath.operations.toComplex +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestESTreeConsistencyWithInterpreter { + @Test + fun mstSpace() { + val res1 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }("x" to MST.Numeric(2)) + + val res2 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to MST.Numeric(2)) + + assertEquals(res1, res2) + } + + @Test + fun byteRing() { + val res1 = ByteRing.mstInRing { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }("x" to 3.toByte()) + + val res2 = ByteRing.mstInRing { + binaryOperationFunction("+")( + unaryOperationFunction("+")( + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + number(1) + ) * number(2) + }.compile()("x" to 3.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun realField() { + val res1 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0) + + val res2 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0) + + assertEquals(res1, res2) + } + + @Test + fun complexField() { + val res1 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0.toComplex()) + + val res2 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0.toComplex()) + + assertEquals(res1, res2) + } +} diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeOperationsSupport.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeOperationsSupport.kt new file mode 100644 index 000000000..72a4669d9 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeOperationsSupport.kt @@ -0,0 +1,41 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.mstInExtendedField +import kscience.kmath.ast.mstInField +import kscience.kmath.ast.mstInSpace +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestESTreeOperationsSupport { + @Test + fun testUnaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") }.compile() + val res = expression("x" to 2.0) + assertEquals(-2.0, res) + } + + @Test + fun testBinaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile() + val res = expression("x" to 2.0) + assertEquals(-1.0, res) + } + + @Test + fun testConstProductInvocation() { + val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) + assertEquals(4.0, res) + } + + @Test + fun testMultipleCalls() { + val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile() + val r = Random(0) + var s = 0.0 + repeat(1000000) { s += e("x" to r.nextDouble()) } + println(s) + } +} diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeSpecialization.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeSpecialization.kt new file mode 100644 index 000000000..9d0d17e58 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeSpecialization.kt @@ -0,0 +1,54 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.mstInField +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestESTreeSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperationFunction("+")(symbol("x")) }.compile() + assertEquals(2.0, expr("x" to 2.0)) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperationFunction("-")(symbol("x")) }.compile() + assertEquals(-2.0, expr("x" to 2.0)) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperationFunction("+")(symbol("x"), symbol("x")) }.compile() + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperationFunction("sin")(symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperationFunction("-")(symbol("x"), symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 2.0)) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperationFunction("/")(symbol("x"), symbol("x")) }.compile() + assertEquals(1.0, expr("x" to 2.0)) + } + + @Test + fun testPower() { + val expr = RealField + .mstInField { binaryOperationFunction("pow")(symbol("x"), number(2)) } + .compile() + + assertEquals(4.0, expr("x" to 2.0)) + } +} diff --git a/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeVariables.kt b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeVariables.kt new file mode 100644 index 000000000..846120ee2 --- /dev/null +++ b/kmath-ast/src/jsTest/kotlin/kscience/kmath/estree/TestESTreeVariables.kt @@ -0,0 +1,22 @@ +package kscience.kmath.estree + +import kscience.kmath.ast.mstInRing +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestESTreeVariables { + @Test + fun testVariable() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testUndefinedVariableFails() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() + assertFailsWith { expr() } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt new file mode 100644 index 000000000..55cdec243 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt @@ -0,0 +1,87 @@ +package kscience.kmath.asm + +import kscience.kmath.asm.internal.AsmBuilder +import kscience.kmath.asm.internal.buildName +import kscience.kmath.ast.MST +import kscience.kmath.ast.MST.* +import kscience.kmath.ast.MstExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.operations.Algebra +import kscience.kmath.operations.NumericAlgebra + +/** + * Compiles given MST to an Expression using AST compiler. + * + * @param type the target type. + * @param algebra the target algebra. + * @return the compiled expression. + * @author Alexander Nozik + */ +@PublishedApi +internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { + fun AsmBuilder.visit(node: MST): Unit = when (node) { + is Symbolic -> { + val symbol = try { + algebra.symbol(node.value) + } catch (ignored: IllegalStateException) { + null + } + + if (symbol != null) + loadObjectConstant(symbol as Any) + else + loadVariable(node.value) + } + + is Numeric -> loadNumberConstant(node.value) + + is Unary -> when { + algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( + algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value))) + + else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) } + } + + is Binary -> when { + algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant( + algebra.binaryOperationFunction(node.operation) + .invoke(algebra.number(node.left.value), algebra.number(node.right.value)) + ) + + algebra is NumericAlgebra && node.left is Numeric -> buildCall( + algebra.leftSideNumberOperationFunction(node.operation)) { + visit(node.left) + visit(node.right) + } + + algebra is NumericAlgebra && node.right is Numeric -> buildCall( + algebra.rightSideNumberOperationFunction(node.operation)) { + visit(node.left) + visit(node.right) + } + + else -> buildCall(algebra.binaryOperationFunction(node.operation)) { + visit(node.left) + visit(node.right) + } + } + } + + return AsmBuilder(type, buildName(this)) { visit(this@compileWith) }.instance +} + +/** + * Compiles an [MST] to ASM using given algebra. + * + * @author Alexander Nozik. + */ +public inline fun Algebra.expression(mst: MST): Expression = + mst.compileWith(T::class.java, this) + +/** + * Optimizes performance of an [MstExpression] using ASM codegen. + * + * @author Alexander Nozik. + */ +public inline fun MstExpression>.compile(): Expression = + mst.compileWith(T::class.java, algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt new file mode 100644 index 000000000..93d8d1143 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -0,0 +1,345 @@ +package kscience.kmath.asm.internal + +import kscience.kmath.asm.internal.AsmBuilder.ClassLoader +import kscience.kmath.ast.MST +import kscience.kmath.expressions.Expression +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.Type.* +import org.objectweb.asm.commons.InstructionAdapter +import java.lang.invoke.MethodHandles +import java.lang.invoke.MethodType +import java.lang.reflect.Modifier +import java.util.stream.Collectors.toMap +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. + * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. + * + * @property T the type of AsmExpression to unwrap. + * @property className the unique class name of new loaded class. + * @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0. + * @author Iaroslav Postovalov + */ +internal class AsmBuilder( + classOfT: Class<*>, + private val className: String, + private val callbackAtInvokeL0: AsmBuilder.() -> Unit, +) { + /** + * Internal classloader of [AsmBuilder] with alias to define class from byte array. + */ + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + /** + * The instance of [ClassLoader] used by this builder. + */ + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + + /** + * ASM type for [T]. + */ + private val tType: Type = classOfT.asm + + /** + * ASM type for new class. + */ + private val classType: Type = getObjectType(className.replace(oldChar = '.', newChar = '/')) + + /** + * List of constants to provide to the subclass. + */ + private val constants: MutableList = mutableListOf() + + /** + * Method visitor of `invoke` method of the subclass. + */ + private lateinit var invokeMethodVisitor: InstructionAdapter + + /** + * Subclasses, loads and instantiates [Expression] for given parameters. + * + * The built instance is cached. + */ + @Suppress("UNCHECKED_CAST") + val instance: Expression by lazy { + val hasConstants: Boolean + + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { + visit( + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName), + ) + + visitMethod( + ACC_PUBLIC or ACC_FINAL, + "invoke", + getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}${if (Modifier.isFinal(classOfT.modifiers)) "" else "+"}${tType.descriptor}>;)${tType.descriptor}", + null, + ).instructionAdapter { + invokeMethodVisitor = this + visitCode() + val l0 = label() + callbackAtInvokeL0() + areturn(tType) + val l1 = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + l0, + l1, + 0, + ) + + visitLocalVariable( + "arguments", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", + l0, + l1, + 1, + ) + + visitMaxs(0, 2) + visitEnd() + } + + visitMethod( + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, + "invoke", + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null, + ).instructionAdapter { + visitCode() + val l0 = label() + load(0, OBJECT_TYPE) + load(1, MAP_TYPE) + invokevirtual(classType.internalName, "invoke", getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val l1 = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + l0, + l1, + 0, + ) + + visitMaxs(0, 2) + visitEnd() + } + + hasConstants = constants.isNotEmpty() + + if (hasConstants) + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd, + ) + + visitMethod( + ACC_PUBLIC, + "", + getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), + null, + null, + ).instructionAdapter { + val l0 = label() + load(0, classType) + invokespecial(OBJECT_TYPE.internalName, "", getMethodDescriptor(VOID_TYPE), false) + label() + load(0, classType) + + if (hasConstants) { + label() + load(0, classType) + load(1, OBJECT_ARRAY_TYPE) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + } + + label() + visitInsn(RETURN) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, 0) + + if (hasConstants) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, 1) + + visitMaxs(0, 3) + visitEnd() + } + + visitEnd() + } + + val cls = classLoader.defineClass(className, classWriter.toByteArray()) + // java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + val l = MethodHandles.publicLookup() + + if (hasConstants) + l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array::class.java)) + .invoke(constants.toTypedArray()) as Expression + else + l.findConstructor(cls, MethodType.methodType(Void.TYPE)).invoke() as Expression + } + + /** + * Loads [java.lang.Object] constant from constants. + */ + fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { + val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex + invokeMethodVisitor.load(0, classType) + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + if (type != OBJECT_TYPE) checkcast(type) + } + + /** + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive + * constant from the constant pool. + */ + fun loadNumberConstant(value: Number) { + val boxed = value.javaClass.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] + + if (primitive != null) { + when (primitive) { + BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + + val r = PRIMITIVES_TO_BOXED.getValue(primitive) + + invokeMethodVisitor.invokestatic( + r.internalName, + "valueOf", + getMethodDescriptor(r, primitive), + false, + ) + + return + } + + loadObjectConstant(value, boxed) + } + + /** + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. + */ + fun loadVariable(name: String): Unit = invokeMethodVisitor.run { + load(1, MAP_TYPE) + aconst(name) + + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), + false, + ) + + checkcast(tType) + } + + inline fun buildCall(function: Function, parameters: AsmBuilder.() -> Unit) { + contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } + val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces } + + val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount + ?: error("Provided function object doesn't contain invoke method") + + val type = getType(`interface`) + loadObjectConstant(function, type) + parameters(this) + + invokeMethodVisitor.invokeinterface( + type.internalName, + "invoke", + getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE }), + ) + + invokeMethodVisitor.checkcast(tType) + } + + companion object { + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ + private val BOXED_TO_PRIMITIVES: Map by lazy { + hashMapOf( + Byte::class.java.asm to BYTE_TYPE, + Short::class.java.asm to SHORT_TYPE, + Integer::class.java.asm to INT_TYPE, + Long::class.java.asm to LONG_TYPE, + Float::class.java.asm to FLOAT_TYPE, + Double::class.java.asm to DOUBLE_TYPE, + ) + } + + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ + private val PRIMITIVES_TO_BOXED: Map by lazy { + BOXED_TO_PRIMITIVES.entries.stream().collect( + toMap(Map.Entry::value, Map.Entry::key), + ) + } + + /** + * ASM type for [Expression]. + */ + val EXPRESSION_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Expression") } + + /** + * ASM type for [java.util.Map]. + */ + val MAP_TYPE: Type by lazy { getObjectType("java/util/Map") } + + /** + * ASM type for [java.lang.Object]. + */ + val OBJECT_TYPE: Type by lazy { getObjectType("java/lang/Object") } + + /** + * ASM type for array of [java.lang.Object]. + */ + val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } + + /** + * ASM type for [java.lang.String]. + */ + val STRING_TYPE: Type by lazy { getObjectType("java/lang/String") } + + /** + * ASM type for MapIntrinsics. + */ + val MAP_INTRINSICS_TYPE: Type by lazy { getObjectType("kscience/kmath/asm/internal/MapIntrinsics") } + + /** + * ASM Type for [kscience.kmath.expressions.Symbol]. + */ + val SYMBOL_TYPE: Type by lazy { getObjectType("kscience/kmath/expressions/Symbol") } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt new file mode 100644 index 000000000..6d5d19d42 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt @@ -0,0 +1,93 @@ +package kscience.kmath.asm.internal + +import kscience.kmath.ast.MST +import kscience.kmath.expressions.Expression +import org.objectweb.asm.* +import org.objectweb.asm.commons.InstructionAdapter +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * Returns ASM [Type] for given [Class]. + * + * @author Iaroslav Postovalov + */ +internal inline val Class<*>.asm: Type + get() = Type.getType(this) + +/** + * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. + * + * @author Iaroslav Postovalov + */ +internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array { + contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) } + return if (predicate(this)) arrayOf(this) else emptyArray() +} + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor]. + * + * @author Iaroslav Postovalov + */ +private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) + +/** + * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. + * + * @author Iaroslav Postovalov + */ +internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return instructionAdapter().apply(block) +} + +/** + * Constructs a [Label], then applies it to this visitor. + * + * @author Iaroslav Postovalov + */ +internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) } + +/** + * Creates a class name for [Expression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + * + * @author Iaroslav Postovalov + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { + val name = "kscience.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} + +@Suppress("FunctionName") +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return ClassWriter(flags).apply(block) +} + +/** + * Invokes [visitField] and applies [block] to the [FieldVisitor]. + * + * @author Iaroslav Postovalov + */ +internal inline fun ClassWriter.visitField( + access: Int, + name: String, + descriptor: String, + signature: String?, + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return visitField(access, name, descriptor, signature, value).apply(block) +} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt new file mode 100644 index 000000000..588b9611a --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt @@ -0,0 +1,13 @@ +@file:JvmName("MapIntrinsics") + +package kscience.kmath.asm.internal + +import kscience.kmath.expressions.StringSymbol +import kscience.kmath.expressions.Symbol + +/** + * Gets value with given [key] or throws [NoSuchElementException] whenever it is not present. + * + * @author Iaroslav Postovalov + */ +internal fun Map.getOrFail(key: String): V = getValue(StringSymbol(key)) diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt similarity index 56% rename from kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt rename to kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt index cba335a8d..0b66e2c31 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/parser.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/ast/parser.kt @@ -1,4 +1,6 @@ -package scientifik.kmath.ast +// TODO move to common when https://github.com/h0tk3y/better-parse/pull/33 is merged + +package kscience.kmath.ast import com.github.h0tk3y.betterParse.combinators.* import com.github.h0tk3y.betterParse.grammar.Grammar @@ -7,51 +9,54 @@ import com.github.h0tk3y.betterParse.grammar.parser import com.github.h0tk3y.betterParse.grammar.tryParseToEnd import com.github.h0tk3y.betterParse.lexer.Token import com.github.h0tk3y.betterParse.lexer.TokenMatch +import com.github.h0tk3y.betterParse.lexer.literalToken import com.github.h0tk3y.betterParse.lexer.regexToken import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.Parser -import scientifik.kmath.operations.FieldOperations -import scientifik.kmath.operations.PowerOperations -import scientifik.kmath.operations.RingOperations -import scientifik.kmath.operations.SpaceOperations +import kscience.kmath.operations.FieldOperations +import kscience.kmath.operations.PowerOperations +import kscience.kmath.operations.RingOperations +import kscience.kmath.operations.SpaceOperations /** - * TODO move to core + * better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4. + * + * @author Alexander Nozik and Iaroslav Postovalov */ -object ArithmeticsEvaluator : Grammar() { +public object ArithmeticsEvaluator : Grammar() { // TODO replace with "...".toRegex() when better-parse 0.4.1 is released private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?") private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*") - private val lpar: Token by regexToken("\\(") - private val rpar: Token by regexToken("\\)") - private val comma: Token by regexToken(",") - private val mul: Token by regexToken("\\*") - private val pow: Token by regexToken("\\^") - private val div: Token by regexToken("/") - private val minus: Token by regexToken("-") - private val plus: Token by regexToken("\\+") + private val lpar: Token by literalToken("(") + private val rpar: Token by literalToken(")") + private val comma: Token by literalToken(",") + private val mul: Token by literalToken("*") + private val pow: Token by literalToken("^") + private val div: Token by literalToken("/") + private val minus: Token by literalToken("-") + private val plus: Token by literalToken("+") private val ws: Token by regexToken("\\s+", ignore = true) private val number: Parser by num use { MST.Numeric(text.toDouble()) } private val singular: Parser by id use { MST.Symbolic(text) } - private val unaryFunction: Parser by (id and skip(lpar) and parser(::subSumChain) and skip(rpar)) + private val unaryFunction: Parser by (id and -lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar) .map { (id, term) -> MST.Unary(id.text, term) } private val binaryFunction: Parser by id - .and(skip(lpar)) - .and(parser(::subSumChain)) - .and(skip(comma)) - .and(parser(::subSumChain)) - .and(skip(rpar)) + .and(-lpar) + .and(parser(ArithmeticsEvaluator::subSumChain)) + .and(-comma) + .and(parser(ArithmeticsEvaluator::subSumChain)) + .and(-rpar) .map { (id, left, right) -> MST.Binary(id.text, left, right) } private val term: Parser by number .or(binaryFunction) .or(unaryFunction) .or(singular) - .or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) - .or(skip(lpar) and parser(::subSumChain) and skip(rpar)) + .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) + .or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar) private val powChain: Parser by leftAssociative(term = term, operator = pow) { a, _, b -> MST.Binary(PowerOperations.POW_OPERATION, a, b) @@ -81,17 +86,17 @@ object ArithmeticsEvaluator : Grammar() { } /** - * Tries to parse the string into [MST]. + * Tries to parse the string into [MST] using [ArithmeticsEvaluator]. Returns [ParseResult] representing expression or error. * * @receiver the string to parse. * @return the [MST] node. */ -fun String.tryParseMath(): ParseResult = ArithmeticsEvaluator.tryParseToEnd(this) +public fun String.tryParseMath(): ParseResult = ArithmeticsEvaluator.tryParseToEnd(this) /** - * Parses the string into [MST]. + * Parses the string into [MST] using [ArithmeticsEvaluator]. * * @receiver the string to parse. * @return the [MST] node. */ -fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) +public fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt deleted file mode 100644 index ee0ea15ff..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ /dev/null @@ -1,64 +0,0 @@ -package scientifik.kmath.asm - -import scientifik.kmath.asm.internal.AsmBuilder -import scientifik.kmath.asm.internal.MstType -import scientifik.kmath.asm.internal.buildAlgebraOperationCall -import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.ast.MST -import scientifik.kmath.ast.MstExpression -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import kotlin.reflect.KClass - -/** - * Compile given MST to an Expression using AST compiler - */ -fun MST.compileWith(type: KClass, algebra: Algebra): Expression { - fun AsmBuilder.visit(node: MST) { - when (node) { - is MST.Symbolic -> { - val symbol = try { - algebra.symbol(node.value) - } catch (ignored: Throwable) { - null - } - - if (symbol != null) - loadTConstant(symbol) - else - loadVariable(node.value) - } - - is MST.Numeric -> loadNumeric(node.value) - - is MST.Unary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "unaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.value)) - ) { visit(node.value) } - - is MST.Binary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "binaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right)) - ) { - visit(node.left) - visit(node.right) - } - } - } - - return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() -} - -/** - * Compile an [MST] to ASM using given algebra - */ -inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) - -/** - * Optimize performance of an [MstExpression] using ASM codegen - */ -inline fun MstExpression.compile(): Expression = mst.compileWith(T::class, algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt deleted file mode 100644 index f8c159baf..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ /dev/null @@ -1,568 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.* -import org.objectweb.asm.commons.InstructionAdapter -import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader -import scientifik.kmath.ast.MST -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.NumericAlgebra -import java.util.* -import java.util.stream.Collectors -import kotlin.reflect.KClass - -/** - * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. - * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. - * - * @property T the type of AsmExpression to unwrap. - * @property algebra the algebra the applied AsmExpressions use. - * @property className the unique class name of new loaded class. - * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. - */ -internal class AsmBuilder internal constructor( - private val classOfT: KClass<*>, - private val algebra: Algebra, - private val className: String, - private val invokeLabel0Visitor: AsmBuilder.() -> Unit -) { - /** - * Internal classloader of [AsmBuilder] with alias to define class from byte array. - */ - private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { - internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) - } - - /** - * The instance of [ClassLoader] used by this builder. - */ - private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - - /** - * ASM Type for [algebra]. - */ - private val tAlgebraType: Type = algebra::class.asm - - /** - * ASM type for [T]. - */ - internal val tType: Type = classOfT.asm - - /** - * ASM type for new class. - */ - private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! - - /** - * Index of `this` variable in invoke method of the built subclass. - */ - private val invokeThisVar: Int = 0 - - /** - * Index of `arguments` variable in invoke method of the built subclass. - */ - private val invokeArgumentsVar: Int = 1 - - /** - * List of constants to provide to the subclass. - */ - private val constants: MutableList = mutableListOf() - - /** - * Method visitor of `invoke` method of the subclass. - */ - private lateinit var invokeMethodVisitor: InstructionAdapter - - /** - * State if this [AsmBuilder] needs to generate constants field. - */ - private var hasConstants: Boolean = true - - /** - * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. - */ - internal var primitiveMode: Boolean = false - - /** - * Primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMask: Type = OBJECT_TYPE - - /** - * Boxed primitive type to apple for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMaskBoxed: Type = OBJECT_TYPE - - /** - * Stack of useful objects types on stack to verify types. - */ - private val typeStack: ArrayDeque = ArrayDeque() - - /** - * Stack of useful objects types on stack expected by algebra calls. - */ - internal val expectationStack: ArrayDeque = ArrayDeque(listOf(tType)) - - /** - * The cache for instance built by this builder. - */ - private var generatedInstance: Expression? = null - - /** - * Subclasses, loads and instantiates [Expression] for given parameters. - * - * The built instance is cached. - */ - @Suppress("UNCHECKED_CAST") - internal fun getInstance(): Expression { - generatedInstance?.let { return it } - - if (SIGNATURE_LETTERS.containsKey(classOfT)) { - primitiveMode = true - primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) - primitiveMaskBoxed = tType - } - - val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { - visit( - V1_8, - ACC_PUBLIC or ACC_FINAL or ACC_SUPER, - classType.internalName, - "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", - OBJECT_TYPE.internalName, - arrayOf(EXPRESSION_TYPE.internalName) - ) - - visitMethod( - ACC_PUBLIC or ACC_FINAL, - "invoke", - Type.getMethodDescriptor(tType, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", - null - ).instructionAdapter { - invokeMethodVisitor = this - visitCode() - val l0 = label() - invokeLabel0Visitor() - areturn(tType) - val l1 = label() - - visitLocalVariable( - "this", - classType.descriptor, - null, - l0, - l1, - invokeThisVar - ) - - visitLocalVariable( - "arguments", - MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", - l0, - l1, - invokeArgumentsVar - ) - - visitMaxs(0, 2) - visitEnd() - } - - visitMethod( - ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, - "invoke", - Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), - null, - null - ).instructionAdapter { - val thisVar = 0 - val argumentsVar = 1 - visitCode() - val l0 = label() - load(thisVar, OBJECT_TYPE) - load(argumentsVar, MAP_TYPE) - invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) - areturn(tType) - val l1 = label() - - visitLocalVariable( - "this", - classType.descriptor, - null, - l0, - l1, - thisVar - ) - - visitMaxs(0, 2) - visitEnd() - } - - hasConstants = constants.isNotEmpty() - - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "algebra", - descriptor = tAlgebraType.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd - ) - - if (hasConstants) - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "constants", - descriptor = OBJECT_ARRAY_TYPE.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd - ) - - visitMethod( - ACC_PUBLIC, - "", - - Type.getMethodDescriptor( - Type.VOID_TYPE, - tAlgebraType, - *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), - - null, - null - ).instructionAdapter { - val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 - val l0 = label() - load(thisVar, classType) - invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) - label() - load(thisVar, classType) - load(algebraVar, tAlgebraType) - putfield(classType.internalName, "algebra", tAlgebraType.descriptor) - - if (hasConstants) { - label() - load(thisVar, classType) - load(constantsVar, OBJECT_ARRAY_TYPE) - putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) - } - - label() - visitInsn(RETURN) - val l4 = label() - visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) - - visitLocalVariable( - "algebra", - tAlgebraType.descriptor, - null, - l0, - l4, - algebraVar - ) - - if (hasConstants) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) - - visitMaxs(0, 3) - visitEnd() - } - - visitEnd() - } - - val new = classLoader - .defineClass(className, classWriter.toByteArray()) - .constructors - .first() - .newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression - - generatedInstance = new - return new - } - - /** - * Loads a [T] constant from [constants]. - */ - internal fun loadTConstant(value: T) { - if (classOfT in INLINABLE_NUMBERS) { - val expectedType = expectationStack.pop() - val mustBeBoxed = expectedType.sort == Type.OBJECT - loadNumberConstant(value as Number, mustBeBoxed) - - if (mustBeBoxed) - invokeMethodVisitor.checkcast(tType) - - if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) - return - } - - loadObjectConstant(value as Any, tType) - } - - /** - * Boxes the current value and pushes it. - */ - private fun box(primitive: Type) { - val r = PRIMITIVES_TO_BOXED.getValue(primitive) - - invokeMethodVisitor.invokestatic( - r.internalName, - "valueOf", - Type.getMethodDescriptor(r, primitive), - false - ) - } - - /** - * Unboxes the current boxed value and pushes it. - */ - private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual( - NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(primitive), - Type.getMethodDescriptor(primitive), - false - ) - - /** - * Loads [java.lang.Object] constant from constants. - */ - private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { - val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - loadThis() - getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) - iconst(idx) - visitInsn(AALOAD) - checkcast(type) - } - - internal fun loadNumeric(value: Number) { - if (expectationStack.peek() == NUMBER_TYPE) { - loadNumberConstant(value, true) - expectationStack.pop() - typeStack.push(NUMBER_TYPE) - } else (algebra as? NumericAlgebra)?.number(value)?.let { loadTConstant(it) } - ?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.") - } - - /** - * Loads this variable. - */ - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) - - /** - * Either loads a numeric constant [value] from the class's constants field or boxes a primitive - * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded - * from it). - */ - private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { - val boxed = value::class.asm - val primitive = BOXED_TO_PRIMITIVES[boxed] - - if (primitive != null) { - when (primitive) { - Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) - Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) - Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) - Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - } - - if (mustBeBoxed) - box(primitive) - - return - } - - loadObjectConstant(value, boxed) - - if (!mustBeBoxed) - unboxTo(primitiveMask) - } - - /** - * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be - * provided. - */ - internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - load(invokeArgumentsVar, MAP_TYPE) - aconst(name) - - if (defaultValue != null) - loadTConstant(defaultValue) - - invokestatic( - MAP_INTRINSICS_TYPE.internalName, - "getOrFail", - - Type.getMethodDescriptor( - OBJECT_TYPE, - MAP_TYPE, - OBJECT_TYPE, - *OBJECT_TYPE.wrapToArrayIf { defaultValue != null }), - false - ) - - checkcast(tType) - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } - } - - /** - * Loads algebra from according field of the class and casts it to class of [algebra] provided. - */ - internal fun loadAlgebra() { - loadThis() - invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) - } - - /** - * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is - * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be - * called before the arguments and this operation. - * - * The result is casted to [T] automatically. - */ - internal fun invokeAlgebraOperation( - owner: String, - method: String, - descriptor: String, - expectedArity: Int, - opcode: Int = INVOKEINTERFACE - ) { - run loop@{ - repeat(expectedArity) { - if (typeStack.isEmpty()) return@loop - typeStack.pop() - } - } - - invokeMethodVisitor.visitMethodInsn( - opcode, - owner, - method, - descriptor, - opcode == INVOKEINTERFACE - ) - - invokeMethodVisitor.checkcast(tType) - val isLastExpr = expectationStack.size == 1 - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } - } - - /** - * Writes a LDC Instruction with string constant provided. - */ - internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) - - internal companion object { - /** - * Maps JVM primitive numbers boxed types to their primitive ASM types. - */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { - hashMapOf( - java.lang.Byte::class to Type.BYTE_TYPE, - java.lang.Short::class to Type.SHORT_TYPE, - java.lang.Integer::class to Type.INT_TYPE, - java.lang.Long::class to Type.LONG_TYPE, - java.lang.Float::class to Type.FLOAT_TYPE, - java.lang.Double::class to Type.DOUBLE_TYPE - ) - } - - /** - * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. - */ - private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } - - /** - * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. - */ - private val PRIMITIVES_TO_BOXED: Map by lazy { - BOXED_TO_PRIMITIVES.entries.stream().collect( - Collectors.toMap( - Map.Entry::value, - Map.Entry::key - ) - ) - } - - /** - * Maps primitive ASM types to [Number] functions unboxing them. - */ - private val NUMBER_CONVERTER_METHODS: Map by lazy { - hashMapOf( - Type.BYTE_TYPE to "byteValue", - Type.SHORT_TYPE to "shortValue", - Type.INT_TYPE to "intValue", - Type.LONG_TYPE to "longValue", - Type.FLOAT_TYPE to "floatValue", - Type.DOUBLE_TYPE to "doubleValue" - ) - } - - /** - * Provides boxed number types values of which can be stored in JVM bytecode constant pool. - */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - - /** - * ASM type for [Expression]. - */ - internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } - - /** - * ASM type for [java.lang.Number]. - */ - internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } - - /** - * ASM type for [java.util.Map]. - */ - internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } - - /** - * ASM type for [java.lang.Object]. - */ - internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } - - /** - * ASM type for array of [java.lang.Object]. - */ - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } - - /** - * ASM type for [Algebra]. - */ - internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } - - /** - * ASM type for [java.lang.String]. - */ - internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } - - /** - * ASM type for MapIntrinsics. - */ - internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } - } -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt deleted file mode 100644 index bf73d304b..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt +++ /dev/null @@ -1,17 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.ast.MST - -internal enum class MstType { - GENERAL, - NUMBER; - - companion object { - fun fromMst(mst: MST): MstType { - if (mst is MST.Numeric) - return NUMBER - - return GENERAL - } - } -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt deleted file mode 100644 index 6f51fe855..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt +++ /dev/null @@ -1,191 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.INVOKEVIRTUAL -import org.objectweb.asm.commons.InstructionAdapter -import scientifik.kmath.ast.MST -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import java.lang.reflect.Method -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract -import kotlin.reflect.KClass - -private val methodNameAdapters: Map, String> by lazy { - hashMapOf( - "+" to 2 to "add", - "*" to 2 to "multiply", - "/" to 2 to "divide", - "+" to 1 to "unaryPlus", - "-" to 1 to "unaryMinus", - "-" to 2 to "minus" - ) -} - -internal val KClass<*>.asm: Type - get() = Type.getType(java) - -/** - * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. - */ -internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array { - contract { callsInPlace(predicate, InvocationKind.EXACTLY_ONCE) } - return if (predicate(this)) arrayOf(this) else emptyArray() -} - -/** - * Creates an [InstructionAdapter] from this [MethodVisitor]. - */ -private fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) - -/** - * Creates an [InstructionAdapter] from this [MethodVisitor] and applies [block] to it. - */ -internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return instructionAdapter().apply(block) -} - -/** - * Constructs a [Label], then applies it to this visitor. - */ -internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) } - -/** - * Creates a class name for [Expression] subclassed to implement [mst] provided. - * - * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there - * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. - */ -internal tailrec fun buildName(mst: MST, collision: Int = 0): String { - val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(mst, collision + 1) -} - -@Suppress("FunctionName") -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return ClassWriter(flags).apply(block) -} - -internal inline fun ClassWriter.visitField( - access: Int, - name: String, - descriptor: String, - signature: String?, - value: Any?, - block: FieldVisitor.() -> Unit -): FieldVisitor { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return visitField(access, name, descriptor, signature, value).apply(block) -} - -private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = - context.javaClass.methods.find { method -> - val nameValid = method.name == name - val arityValid = method.parameters.size == parameterTypes.size - val notBridgeInPrimitive = !(primitiveMode && method.isBridge) - - val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) -> - !(mstType != MstType.NUMBER && type == java.lang.Number::class.java) - } - - nameValid && arityValid && notBridgeInPrimitive && paramsValid - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds - * type expectation stack for needed arity. - * - * @return `true` if contains, else `false`. - */ -private fun AsmBuilder.buildExpectationStack( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes) - - if (specific != null) - mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) } - else - repeat(arity) { expectationStack.push(tType) } - - return specific != null -} - -private fun AsmBuilder.mapTypes(method: Method, parameterTypes: Array): List = method - .parameterTypes - .zip(parameterTypes) - .map { (type, mstType) -> - when { - type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE - else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed - } - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts - * [AsmBuilder.invokeAlgebraOperation] of this method. - * - * @return `true` if contains, else `false`. - */ -private fun AsmBuilder.tryInvokeSpecific( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val theName = methodNameAdapters[name to arity] ?: name - val spec = findSpecific(context, theName, parameterTypes) ?: return false - val owner = context::class.asm - - invokeAlgebraOperation( - owner = owner.internalName, - method = theName, - descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()), - expectedArity = arity, - opcode = INVOKEVIRTUAL - ) - - return true -} - -/** - * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. - */ -internal inline fun AsmBuilder.buildAlgebraOperationCall( - context: Algebra, - name: String, - fallbackMethodName: String, - parameterTypes: Array, - parameters: AsmBuilder.() -> Unit -) { - contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } - val arity = parameterTypes.size - loadAlgebra() - if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) - parameters() - - if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = fallbackMethodName, - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - *Array(arity) { AsmBuilder.OBJECT_TYPE } - ), - - expectedArity = arity - ) -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt deleted file mode 100644 index 80e83c1bf..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt +++ /dev/null @@ -1,7 +0,0 @@ -@file:JvmName("MapIntrinsics") - -package scientifik.kmath.asm.internal - -@JvmOverloads -internal fun Map.getOrFail(key: K, default: V? = null): V = - this[key] ?: default ?: error("Parameter not found: $key") diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt similarity index 52% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt rename to kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt index 3acc6eb28..ae180bf3f 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt @@ -1,24 +1,20 @@ -package scietifik.kmath.asm +package kscience.kmath.asm -import scientifik.kmath.asm.compile -import scientifik.kmath.ast.mstInField -import scientifik.kmath.ast.mstInRing -import scientifik.kmath.ast.mstInSpace -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.ByteRing -import scientifik.kmath.operations.RealField +import kscience.kmath.ast.* +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.ByteRing +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.RealField +import kscience.kmath.operations.toComplex import kotlin.test.Test import kotlin.test.assertEquals -internal class TestAsmAlgebras { +internal class TestAsmConsistencyWithInterpreter { @Test - fun space() { - val res1 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + fun mstSpace() { + val res1 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -27,14 +23,11 @@ internal class TestAsmAlgebras { number(1) ) + symbol("x") + zero - }("x" to 2.toByte()) + }("x" to MST.Numeric(2)) - val res2 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + val res2 = MstSpace.mstInSpace { + binaryOperationFunction("+")( + unaryOperationFunction("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -43,19 +36,16 @@ internal class TestAsmAlgebras { number(1) ) + symbol("x") + zero - }.compile()("x" to 2.toByte()) + }.compile()("x" to MST.Numeric(2)) assertEquals(res1, res2) } @Test - fun ring() { + fun byteRing() { val res1 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperationFunction("+")( + unaryOperationFunction("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 @@ -67,17 +57,13 @@ internal class TestAsmAlgebras { }("x" to 3.toByte()) val res2 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperationFunction("+")( + unaryOperationFunction("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 ) + 1.toByte()))) * 3.0 - 1.toByte() ), - number(1) ) * number(2) }.compile()("x" to 3.toByte()) @@ -86,10 +72,9 @@ internal class TestAsmAlgebras { } @Test - fun field() { + fun realField() { val res1 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one @@ -97,8 +82,7 @@ internal class TestAsmAlgebras { }("x" to 2.0) val res2 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one @@ -107,4 +91,25 @@ internal class TestAsmAlgebras { assertEquals(res1, res2) } + + @Test + fun complexField() { + val res1 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }("x" to 2.0.toComplex()) + + val res2 = ComplexField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + number(1) / 2 + number(2.0) * one + ) + zero + }.compile()("x" to 2.0.toComplex()) + + assertEquals(res1, res2) + } } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmOperationsSupport.kt similarity index 51% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt rename to kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmOperationsSupport.kt index 36c254c38..2ce52aa87 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmOperationsSupport.kt @@ -1,14 +1,15 @@ -package scietifik.kmath.asm +package kscience.kmath.asm -import scientifik.kmath.asm.compile -import scientifik.kmath.ast.mstInField -import scientifik.kmath.ast.mstInSpace -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.RealField +import kscience.kmath.ast.mstInExtendedField +import kscience.kmath.ast.mstInField +import kscience.kmath.ast.mstInSpace +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals -internal class TestAsmExpressions { +internal class TestAsmOperationsSupport { @Test fun testUnaryOperationInvocation() { val expression = RealField.mstInSpace { -symbol("x") }.compile() @@ -28,4 +29,13 @@ internal class TestAsmExpressions { val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) assertEquals(4.0, res) } + + @Test + fun testMultipleCalls() { + val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile() + val r = Random(0) + var s = 0.0 + repeat(1000000) { s += e("x" to r.nextDouble()) } + println(s) + } } diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt new file mode 100644 index 000000000..602c54651 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt @@ -0,0 +1,54 @@ +package kscience.kmath.asm + +import kscience.kmath.ast.mstInField +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestAsmSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperationFunction("+")(symbol("x")) }.compile() + assertEquals(2.0, expr("x" to 2.0)) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperationFunction("-")(symbol("x")) }.compile() + assertEquals(-2.0, expr("x" to 2.0)) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperationFunction("+")(symbol("x"), symbol("x")) }.compile() + assertEquals(4.0, expr("x" to 2.0)) + } + + @Test + fun testSine() { + val expr = RealField.mstInField { unaryOperationFunction("sin")(symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 0.0)) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperationFunction("-")(symbol("x"), symbol("x")) }.compile() + assertEquals(0.0, expr("x" to 2.0)) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperationFunction("/")(symbol("x"), symbol("x")) }.compile() + assertEquals(1.0, expr("x" to 2.0)) + } + + @Test + fun testPower() { + val expr = RealField + .mstInField { binaryOperationFunction("pow")(symbol("x"), number(2)) } + .compile() + + assertEquals(4.0, expr("x" to 2.0)) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt new file mode 100644 index 000000000..c91568dbf --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt @@ -0,0 +1,22 @@ +package kscience.kmath.asm + +import kscience.kmath.ast.mstInRing +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +internal class TestAsmVariables { + @Test + fun testVariable() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() + assertEquals(1.toByte(), expr("x" to 1.toByte())) + } + + @Test + fun testUndefinedVariableFails() { + val expr = ByteRing.mstInRing { symbol("x") }.compile() + assertFailsWith { expr() } + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserPrecedenceTest.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserPrecedenceTest.kt similarity index 81% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserPrecedenceTest.kt rename to kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserPrecedenceTest.kt index 9bdbb12c9..561fe51bd 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserPrecedenceTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserPrecedenceTest.kt @@ -1,9 +1,7 @@ -package scietifik.kmath.ast +package kscience.kmath.ast -import scientifik.kmath.ast.evaluate -import scientifik.kmath.ast.parseMath -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.RealField +import kscience.kmath.operations.Field +import kscience.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt similarity index 62% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt rename to kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt index 9179c3428..3aa5392c8 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/ast/ParserTest.kt @@ -1,13 +1,10 @@ -package scietifik.kmath.ast +package kscience.kmath.ast -import scientifik.kmath.ast.evaluate -import scientifik.kmath.ast.mstInField -import scientifik.kmath.ast.parseMath -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.ComplexField -import scientifik.kmath.operations.RealField +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.Algebra +import kscience.kmath.operations.Complex +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals @@ -45,12 +42,15 @@ internal class ParserTest { val magicalAlgebra = object : Algebra { override fun symbol(value: String): String = value - override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError() - - override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) { - "magic" -> "$left ★ $right" - else -> throw NotImplementedError() + override fun unaryOperationFunction(operation: String): (arg: String) -> String { + throw NotImplementedError() } + + override fun binaryOperationFunction(operation: String): (left: String, right: String) -> String = + when (operation) { + "magic" -> { left, right -> "$left ★ $right" } + else -> throw NotImplementedError() + } } val mst = "magic(a, b)".parseMath() diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt deleted file mode 100644 index a88431e9d..000000000 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scietifik.kmath.asm - -import scientifik.kmath.asm.compile -import scientifik.kmath.ast.mstInField -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.RealField -import kotlin.test.Test -import kotlin.test.assertEquals - -internal class TestAsmSpecialization { - @Test - fun testUnaryPlus() { - val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() - assertEquals(2.0, expr("x" to 2.0)) - } - - @Test - fun testUnaryMinus() { - val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() - assertEquals(-2.0, expr("x" to 2.0)) - } - - @Test - fun testAdd() { - val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() - assertEquals(4.0, expr("x" to 2.0)) - } - - @Test - fun testSine() { - val expr = RealField.mstInField { unaryOperation("sin", symbol("x")) }.compile() - assertEquals(0.0, expr("x" to 0.0)) - } - - @Test - fun testMinus() { - val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() - assertEquals(0.0, expr("x" to 2.0)) - } - - @Test - fun testDivide() { - val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() - assertEquals(1.0, expr("x" to 2.0)) - } - - @Test - fun testPower() { - val expr = RealField - .mstInField { binaryOperation("power", symbol("x"), number(2)) } - .compile() - - assertEquals(4.0, expr("x" to 2.0)) - } -} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt deleted file mode 100644 index aafc75448..000000000 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmVariables.kt +++ /dev/null @@ -1,22 +0,0 @@ -package scietifik.kmath.asm - -import scientifik.kmath.ast.mstInRing -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.ByteRing -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith - -internal class TestAsmVariables { - @Test - fun testVariableWithoutDefault() { - val expr = ByteRing.mstInRing { symbol("x") } - assertEquals(1.toByte(), expr("x" to 1.toByte())) - } - - @Test - fun testVariableWithoutDefaultFails() { - val expr = ByteRing.mstInRing { symbol("x") } - assertFailsWith { expr() } - } -} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt deleted file mode 100644 index 75659cc35..000000000 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ /dev/null @@ -1,25 +0,0 @@ -package scietifik.kmath.ast - -import scientifik.kmath.asm.compile -import scientifik.kmath.asm.expression -import scientifik.kmath.ast.mstInField -import scientifik.kmath.ast.parseMath -import scientifik.kmath.expressions.invoke -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.ComplexField -import kotlin.test.Test -import kotlin.test.assertEquals - -internal class AsmTest { - @Test - fun `compile MST`() { - val res = ComplexField.expression("2+2*(2+2)".parseMath())() - assertEquals(Complex(10.0, 0.0), res) - } - - @Test - fun `compile MSTExpression`() { - val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()() - assertEquals(Complex(10.0, 0.0), res) - } -} diff --git a/kmath-commons/build.gradle.kts b/kmath-commons/build.gradle.kts index 63c832b7c..6a44c92f2 100644 --- a/kmath-commons/build.gradle.kts +++ b/kmath-commons/build.gradle.kts @@ -1,12 +1,12 @@ -plugins { id("scientifik.jvm") } +plugins { + id("ru.mipt.npm.jvm") +} description = "Commons math binding for kmath" dependencies { api(project(":kmath-core")) api(project(":kmath-coroutines")) - api(project(":kmath-prob")) + api(project(":kmath-stat")) api(project(":kmath-functions")) api("org.apache.commons:commons-math3:3.6.1") } - -kotlin.sourceSets.all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt new file mode 100644 index 000000000..2912ddc4c --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -0,0 +1,126 @@ +package kscience.kmath.commons.expressions + +import kscience.kmath.expressions.* +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.ExtendedField +import kscience.kmath.operations.RingWithNumbers +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure + +/** + * A field over commons-math [DerivativeStructure]. + * + * @property order The derivation order. + * @property bindings The map of bindings values. All bindings are considered free parameters + */ +@OptIn(UnstableKMathAPI::class) +public class DerivativeStructureField( + public val order: Int, + bindings: Map, +) : ExtendedField, ExpressionAlgebra, RingWithNumbers { + public val numberOfVariables: Int = bindings.size + + public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } + public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } + + override fun number(value: Number): DerivativeStructure = const(value.toDouble()) + + /** + * A class that implements both [DerivativeStructure] and a [Symbol] + */ + public inner class DerivativeStructureSymbol( + size: Int, + index: Int, + symbol: Symbol, + value: Double, + ) : DerivativeStructure(size, order, index, value), Symbol { + override val identity: String = symbol.identity + override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() + } + + /** + * Identity-based symbol bindings map + */ + private val variables: Map = bindings.entries.mapIndexed { index, (key, value) -> + key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value) + }.toMap() + + override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) + + public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] + + public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) + + override fun symbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value)) + + public fun DerivativeStructure.derivative(symbols: List): Double { + require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" } + val ordersCount = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray()) + } + + public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList()) + + public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) + + public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) { + is Double -> a.multiply(k) + is Int -> a.multiply(k) + else -> a.multiply(k.toDouble()) + } + + public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b) + public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) + public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() + public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() + public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() + public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() + public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() + public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh() + public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh() + public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh() + public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh() + public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh() + public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh() + + public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { + is Double -> arg.pow(pow) + is Int -> arg.pow(pow) + else -> arg.pow(pow.toDouble()) + } + + public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) + public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() + public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() + + public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) + public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) + public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this + public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this + + public companion object : + AutoDiffProcessor> { + public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression> = + DerivativeStructureExpression(function) + } +} + + +/** + * A constructs that creates a derivative structure with required order on-demand + */ +public class DerivativeStructureExpression( + public val function: DerivativeStructureField.() -> DerivativeStructure, +) : DifferentiableExpression> { + public override operator fun invoke(arguments: Map): Double = + DerivativeStructureField(0, arguments).function().value + + /** + * Get the derivative expression with given orders + */ + public override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) } + } +} diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt new file mode 100644 index 000000000..850446afa --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -0,0 +1,89 @@ +package kscience.kmath.commons.linear + +import kscience.kmath.linear.DiagonalFeature +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.Point +import kscience.kmath.linear.origin +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.Matrix +import org.apache.commons.math3.linear.* +import kotlin.reflect.KClass +import kotlin.reflect.cast + +public inline class CMMatrix(public val origin: RealMatrix) : Matrix { + public override val rowNum: Int get() = origin.rowDimension + public override val colNum: Int get() = origin.columnDimension + + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = when (type) { + DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null + else -> null + }?.let { type.cast(it) } + + public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) +} + + +public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this) + +public class CMVector(public val origin: RealVector) : Point { + public override val size: Int get() = origin.dimension + + public override operator fun get(index: Int): Double = origin.getEntry(index) + + public override operator fun iterator(): Iterator = origin.toArray().iterator() +} + +public fun Point.toCM(): CMVector = if (this is CMVector) this else { + val array = DoubleArray(size) { this[it] } + CMVector(ArrayRealVector(array)) +} + +public fun RealVector.toPoint(): CMVector = CMVector(this) + +public object CMMatrixContext : MatrixContext { + public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { + val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } } + return CMMatrix(Array2DRowRealMatrix(array)) + } + + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toCM(): CMMatrix = when (val matrix = origin) { + is CMMatrix -> matrix + else -> { + //TODO add feature analysis + val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } + CMMatrix(Array2DRowRealMatrix(array)) + } + } + + public override fun Matrix.dot(other: Matrix): CMMatrix = + CMMatrix(toCM().origin.multiply(other.toCM().origin)) + + public override fun Matrix.dot(vector: Point): CMVector = + CMVector(toCM().origin.preMultiply(vector.toCM().origin)) + + public override operator fun Matrix.unaryMinus(): CMMatrix = + produce(rowNum, colNum) { i, j -> -get(i, j) } + + public override fun add(a: Matrix, b: Matrix): CMMatrix = + CMMatrix(a.toCM().origin.multiply(b.toCM().origin)) + + public override operator fun Matrix.minus(b: Matrix): CMMatrix = + CMMatrix(toCM().origin.subtract(b.toCM().origin)) + + public override fun multiply(a: Matrix, k: Number): CMMatrix = + CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble())) + + public override operator fun Matrix.times(value: Double): CMMatrix = + produce(rowNum, colNum) { i, j -> get(i, j) * value } +} + +public operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = + CMMatrix(origin.add(other.origin)) + +public operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = + CMMatrix(origin.subtract(other.origin)) + +public infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = + CMMatrix(origin.multiply(other.origin)) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMSolver.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMSolver.kt new file mode 100644 index 000000000..210014e1a --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMSolver.kt @@ -0,0 +1,41 @@ +package kscience.kmath.commons.linear + +import kscience.kmath.linear.Point +import kscience.kmath.structures.Matrix +import org.apache.commons.math3.linear.* + +public enum class CMDecomposition { + LUP, + QR, + RRQR, + EIGEN, + CHOLESKY +} + +public fun CMMatrixContext.solver( + a: Matrix, + decomposition: CMDecomposition = CMDecomposition.LUP +): DecompositionSolver = when (decomposition) { + CMDecomposition.LUP -> LUDecomposition(a.toCM().origin).solver + CMDecomposition.RRQR -> RRQRDecomposition(a.toCM().origin).solver + CMDecomposition.QR -> QRDecomposition(a.toCM().origin).solver + CMDecomposition.EIGEN -> EigenDecomposition(a.toCM().origin).solver + CMDecomposition.CHOLESKY -> CholeskyDecomposition(a.toCM().origin).solver +} + +public fun CMMatrixContext.solve( + a: Matrix, + b: Matrix, + decomposition: CMDecomposition = CMDecomposition.LUP +): CMMatrix = solver(a, decomposition).solve(b.toCM().origin).asMatrix() + +public fun CMMatrixContext.solve( + a: Matrix, + b: Point, + decomposition: CMDecomposition = CMDecomposition.LUP +): CMVector = solver(a, decomposition).solve(b.toCM().origin).toPoint() + +public fun CMMatrixContext.inverse( + a: Matrix, + decomposition: CMDecomposition = CMDecomposition.LUP +): CMMatrix = solver(a, decomposition).inverse.asMatrix() diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt new file mode 100644 index 000000000..d6f79529a --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -0,0 +1,110 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.expressions.* +import kscience.kmath.stat.OptimizationFeature +import kscience.kmath.stat.OptimizationProblem +import kscience.kmath.stat.OptimizationProblemFactory +import kscience.kmath.stat.OptimizationResult +import org.apache.commons.math3.optim.* +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType +import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient +import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer +import kotlin.reflect.KClass + +public operator fun PointValuePair.component1(): DoubleArray = point +public operator fun PointValuePair.component2(): Double = value + +public class CMOptimizationProblem(override val symbols: List, ) : + OptimizationProblem, SymbolIndexer, OptimizationFeature { + private val optimizationData: HashMap, OptimizationData> = HashMap() + private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null + public var convergenceChecker: ConvergenceChecker = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE, + DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER) + + public fun addOptimizationData(data: OptimizationData) { + optimizationData[data::class] = data + } + + init { + addOptimizationData(MaxEval.unlimited()) + } + + public fun exportOptimizationData(): List = optimizationData.values.toList() + + public override fun initialGuess(map: Map): Unit { + addOptimizationData(InitialGuess(map.toDoubleArray())) + } + + public override fun expression(expression: Expression): Unit { + val objectiveFunction = ObjectiveFunction { + val args = it.toMap() + expression(args) + } + addOptimizationData(objectiveFunction) + } + + public override fun diffExpression(expression: DifferentiableExpression>) { + expression(expression) + val gradientFunction = ObjectiveFunctionGradient { + val args = it.toMap() + DoubleArray(symbols.size) { index -> + expression.derivative(symbols[index])(args) + } + } + addOptimizationData(gradientFunction) + if (optimizatorBuilder == null) { + optimizatorBuilder = { + NonLinearConjugateGradientOptimizer( + NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, + convergenceChecker + ) + } + } + } + + public fun simplex(simplex: AbstractSimplex) { + addOptimizationData(simplex) + //Set optimization builder to simplex if it is not present + if (optimizatorBuilder == null) { + optimizatorBuilder = { SimplexOptimizer(convergenceChecker) } + } + } + + public fun simplexSteps(steps: Map) { + simplex(NelderMeadSimplex(steps.toDoubleArray())) + } + + public fun goal(goalType: GoalType) { + addOptimizationData(goalType) + } + + public fun optimizer(block: () -> MultivariateOptimizer) { + optimizatorBuilder = block + } + + override fun update(result: OptimizationResult) { + initialGuess(result.point) + } + + override fun optimize(): OptimizationResult { + val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined") + val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray()) + return OptimizationResult(point.toMap(), value, setOf(this)) + } + + public companion object : OptimizationProblemFactory { + public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_MAX_ITER: Int = 1000 + + override fun build(symbols: List): CMOptimizationProblem = CMOptimizationProblem(symbols) + } +} + +public fun CMOptimizationProblem.initialGuess(vararg pairs: Pair): Unit = initialGuess(pairs.toMap()) +public fun CMOptimizationProblem.simplexSteps(vararg pairs: Pair): Unit = simplexSteps(pairs.toMap()) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt new file mode 100644 index 000000000..b8e8bfd4b --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt @@ -0,0 +1,67 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.commons.expressions.DerivativeStructureField +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol +import kscience.kmath.stat.Fitting +import kscience.kmath.stat.OptimizationResult +import kscience.kmath.stat.optimizeWith +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.asBuffer +import org.apache.commons.math3.analysis.differentiation.DerivativeStructure +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType + +/** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation + */ +public fun Fitting.chiSquared( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, +): DifferentiableExpression> = chiSquared(DerivativeStructureField, x, y, yErr, model) + +/** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation + */ +public fun Fitting.chiSquared( + x: Iterable, + y: Iterable, + yErr: Iterable, + model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, +): DifferentiableExpression> = chiSquared( + DerivativeStructureField, + x.toList().asBuffer(), + y.toList().asBuffer(), + yErr.toList().asBuffer(), + model +) + +/** + * Optimize expression without derivatives + */ +public fun Expression.optimize( + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) + +/** + * Optimize differentiable expression + */ +public fun DifferentiableExpression>.optimize( + vararg symbols: Symbol, + configuration: CMOptimizationProblem.() -> Unit, +): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) + +public fun DifferentiableExpression>.minimize( + vararg startPoint: Pair, + configuration: CMOptimizationProblem.() -> Unit = {}, +): OptimizationResult { + require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration) + problem.diffExpression(this) + problem.initialGuess(startPoint.toMap()) + problem.goal(GoalType.MINIMIZE) + return problem.optimize() +} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt new file mode 100644 index 000000000..1eab5f2bd --- /dev/null +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -0,0 +1,34 @@ +package kscience.kmath.commons.random + +import kscience.kmath.stat.RandomGenerator + +public class CMRandomGeneratorWrapper( + public val factory: (IntArray) -> RandomGenerator, +) : org.apache.commons.math3.random.RandomGenerator { + private var generator: RandomGenerator = factory(intArrayOf()) + + public override fun nextBoolean(): Boolean = generator.nextBoolean() + public override fun nextFloat(): Float = generator.nextDouble().toFloat() + + public override fun setSeed(seed: Int) { + generator = factory(intArrayOf(seed)) + } + + public override fun setSeed(seed: IntArray) { + generator = factory(seed) + } + + public override fun setSeed(seed: Long) { + setSeed(seed.toInt()) + } + + public override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + public override fun nextInt(): Int = generator.nextInt() + public override fun nextInt(n: Int): Int = generator.nextInt(n) + public override fun nextGaussian(): Double = TODO() + public override fun nextDouble(): Double = generator.nextDouble() + public override fun nextLong(): Long = generator.nextLong() +} diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/transform/Transformations.kt similarity index 81% rename from kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt rename to kmath-commons/src/main/kotlin/kscience/kmath/commons/transform/Transformations.kt index eb1b5b69a..cd2896be6 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/transform/Transformations.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/transform/Transformations.kt @@ -1,20 +1,19 @@ -package scientifik.kmath.commons.transform +package kscience.kmath.commons.transform import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map +import kscience.kmath.operations.Complex +import kscience.kmath.streaming.chunked +import kscience.kmath.streaming.spread +import kscience.kmath.structures.* import org.apache.commons.math3.transform.* -import scientifik.kmath.operations.Complex -import scientifik.kmath.streaming.chunked -import scientifik.kmath.streaming.spread -import scientifik.kmath.structures.* /** * Streaming and buffer transformations */ -object Transformations { - +public object Transformations { private fun Buffer.toArray(): Array = Array(size) { org.apache.commons.math3.complex.Complex(get(it).re, get(it).im) } @@ -32,35 +31,35 @@ object Transformations { Complex(value.real, value.imaginary) } - fun fourier( + public fun fourier( normalization: DftNormalization = DftNormalization.STANDARD, direction: TransformType = TransformType.FORWARD ): SuspendBufferTransform = { FastFourierTransformer(normalization).transform(it.toArray(), direction).asBuffer() } - fun realFourier( + public fun realFourier( normalization: DftNormalization = DftNormalization.STANDARD, direction: TransformType = TransformType.FORWARD ): SuspendBufferTransform = { FastFourierTransformer(normalization).transform(it.asArray(), direction).asBuffer() } - fun sine( + public fun sine( normalization: DstNormalization = DstNormalization.STANDARD_DST_I, direction: TransformType = TransformType.FORWARD ): SuspendBufferTransform = { FastSineTransformer(normalization).transform(it.asArray(), direction).asBuffer() } - fun cosine( + public fun cosine( normalization: DctNormalization = DctNormalization.STANDARD_DCT_I, direction: TransformType = TransformType.FORWARD ): SuspendBufferTransform = { FastCosineTransformer(normalization).transform(it.asArray(), direction).asBuffer() } - fun hadamard( + public fun hadamard( direction: TransformType = TransformType.FORWARD ): SuspendBufferTransform = { FastHadamardTransformer().transform(it.asArray(), direction).asBuffer() @@ -71,7 +70,7 @@ object Transformations { * Process given [Flow] with commons-math fft transformation */ @FlowPreview -fun Flow>.FFT( +public fun Flow>.FFT( normalization: DftNormalization = DftNormalization.STANDARD, direction: TransformType = TransformType.FORWARD ): Flow> { @@ -81,7 +80,7 @@ fun Flow>.FFT( @FlowPreview @JvmName("realFFT") -fun Flow>.FFT( +public fun Flow>.FFT( normalization: DftNormalization = DftNormalization.STANDARD, direction: TransformType = TransformType.FORWARD ): Flow> { @@ -90,20 +89,18 @@ fun Flow>.FFT( } /** - * Process a continous flow of real numbers in FFT splitting it in chunks of [bufferSize]. + * Process a continuous flow of real numbers in FFT splitting it in chunks of [bufferSize]. */ @FlowPreview @JvmName("realFFT") -fun Flow.FFT( +public fun Flow.FFT( bufferSize: Int = Int.MAX_VALUE, normalization: DftNormalization = DftNormalization.STANDARD, direction: TransformType = TransformType.FORWARD -): Flow { - return chunked(bufferSize).FFT(normalization,direction).spread() -} +): Flow = chunked(bufferSize).FFT(normalization, direction).spread() /** * Map a complex flow into real flow by taking real part of each number */ @FlowPreview -fun Flow.real(): Flow = map{it.re} \ No newline at end of file +public fun Flow.real(): Flow = map { it.re } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt deleted file mode 100644 index 9119991e5..000000000 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ /dev/null @@ -1,137 +0,0 @@ -package scientifik.kmath.commons.expressions - -import org.apache.commons.math3.analysis.differentiation.DerivativeStructure -import scientifik.kmath.expressions.Expression -import scientifik.kmath.expressions.ExpressionAlgebra -import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.invoke -import kotlin.properties.ReadOnlyProperty -import kotlin.reflect.KProperty - -/** - * A field wrapping commons-math derivative structures - */ -class DerivativeStructureField( - val order: Int, - val parameters: Map -) : ExtendedField { - override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } - override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } - - private val variables: Map = parameters.mapValues { (key, value) -> - DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value) - } - - val variable: ReadOnlyProperty = object : ReadOnlyProperty { - override fun getValue(thisRef: Any?, property: KProperty<*>): DerivativeStructure = - variables[property.name] ?: error("A variable with name ${property.name} does not exist") - } - - fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure = - variables[name] ?: default ?: error("A variable with name $name does not exist") - - fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble()) - - fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { - return deriv(mapOf(parName to order)) - } - - fun DerivativeStructure.deriv(orders: Map): Double { - return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray()) - } - - fun DerivativeStructure.deriv(vararg orders: Pair): Double = deriv(mapOf(*orders)) - - override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) - - override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) { - is Double -> a.multiply(k) - is Int -> a.multiply(k) - else -> a.multiply(k.toDouble()) - } - - override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b) - - override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) - - override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() - override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() - override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() - override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() - override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() - override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() - - override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh() - override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh() - override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh() - override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh() - override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh() - override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh() - - override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { - is Double -> arg.pow(pow) - is Int -> arg.pow(pow) - else -> arg.pow(pow.toDouble()) - } - - fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) - override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() - override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() - - override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) - override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) - override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this - override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this -} - -/** - * A constructs that creates a derivative structure with required order on-demand - */ -class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression { - override operator fun invoke(arguments: Map): Double = DerivativeStructureField( - 0, - arguments - ).run(function).value - - /** - * Get the derivative expression with given orders - * TODO make result [DiffExpression] - */ - fun derivative(orders: Map): Expression = object : Expression { - override operator fun invoke(arguments: Map): Double = - (DerivativeStructureField(orders.values.max() ?: 0, arguments)) { function().deriv(orders) } - } - - //TODO add gradient and maybe other vector operators -} - -fun DiffExpression.derivative(vararg orders: Pair): Expression = derivative(mapOf(*orders)) -fun DiffExpression.derivative(name: String): Expression = derivative(name to 1) - -/** - * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) - */ -object DiffExpressionAlgebra : ExpressionAlgebra, Field { - override fun variable(name: String, default: Double?): DiffExpression = - DiffExpression { variable(name, default?.const()) } - - override fun const(value: Double): DiffExpression = - DiffExpression { value.const() } - - override fun add(a: DiffExpression, b: DiffExpression): DiffExpression = - DiffExpression { a.function(this) + b.function(this) } - - override val zero: DiffExpression = DiffExpression { 0.0.const() } - - override fun multiply(a: DiffExpression, k: Number): DiffExpression = - DiffExpression { a.function(this) * k } - - override val one: DiffExpression = DiffExpression { 1.0.const() } - - override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression = - DiffExpression { a.function(this) * b.function(this) } - - override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression = - DiffExpression { a.function(this) / b.function(this) } -} diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt deleted file mode 100644 index f0bbdbe65..000000000 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMMatrix.kt +++ /dev/null @@ -1,93 +0,0 @@ -package scientifik.kmath.commons.linear - -import org.apache.commons.math3.linear.* -import scientifik.kmath.linear.* -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.NDStructure - -class CMMatrix(val origin: RealMatrix, features: Set? = null) : - FeaturedMatrix { - override val rowNum: Int get() = origin.rowDimension - override val colNum: Int get() = origin.columnDimension - - override val features: Set = features ?: sequence { - if (origin is DiagonalMatrix) yield(DiagonalFeature) - }.toHashSet() - - override fun suggestFeature(vararg features: MatrixFeature): CMMatrix = - CMMatrix(origin, this.features + features) - - override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) - - override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) - } - - override fun hashCode(): Int { - var result = origin.hashCode() - result = 31 * result + features.hashCode() - return result - } -} - -fun Matrix.toCM(): CMMatrix = if (this is CMMatrix) { - this -} else { - //TODO add feature analysis - val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } - CMMatrix(Array2DRowRealMatrix(array)) -} - -fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this) - -class CMVector(val origin: RealVector) : Point { - override val size: Int get() = origin.dimension - - override operator fun get(index: Int): Double = origin.getEntry(index) - - override operator fun iterator(): Iterator = origin.toArray().iterator() -} - -fun Point.toCM(): CMVector = if (this is CMVector) this else { - val array = DoubleArray(size) { this[it] } - CMVector(ArrayRealVector(array)) -} - -fun RealVector.toPoint(): CMVector = CMVector(this) - -object CMMatrixContext : MatrixContext { - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { - val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } } - return CMMatrix(Array2DRowRealMatrix(array)) - } - - override fun Matrix.dot(other: Matrix): CMMatrix = - CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) - - override fun Matrix.dot(vector: Point): CMVector = - CMVector(this.toCM().origin.preMultiply(vector.toCM().origin)) - - override operator fun Matrix.unaryMinus(): CMMatrix = - produce(rowNum, colNum) { i, j -> -get(i, j) } - - override fun add(a: Matrix, b: Matrix): CMMatrix = - CMMatrix(a.toCM().origin.multiply(b.toCM().origin)) - - override operator fun Matrix.minus(b: Matrix): CMMatrix = - CMMatrix(this.toCM().origin.subtract(b.toCM().origin)) - - override fun multiply(a: Matrix, k: Number): CMMatrix = - CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble())) - - override operator fun Matrix.times(value: Double): Matrix = - produce(rowNum, colNum) { i, j -> get(i, j) * value } -} - -operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = - CMMatrix(this.origin.add(other.origin)) - -operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = - CMMatrix(this.origin.subtract(other.origin)) - -infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = - CMMatrix(this.origin.multiply(other.origin)) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMSolver.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMSolver.kt deleted file mode 100644 index 77b688e31..000000000 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/linear/CMSolver.kt +++ /dev/null @@ -1,40 +0,0 @@ -package scientifik.kmath.commons.linear - -import org.apache.commons.math3.linear.* -import scientifik.kmath.linear.Point -import scientifik.kmath.structures.Matrix - -enum class CMDecomposition { - LUP, - QR, - RRQR, - EIGEN, - CHOLESKY -} - - -fun CMMatrixContext.solver(a: Matrix, decomposition: CMDecomposition = CMDecomposition.LUP) = - when (decomposition) { - CMDecomposition.LUP -> LUDecomposition(a.toCM().origin).solver - CMDecomposition.RRQR -> RRQRDecomposition(a.toCM().origin).solver - CMDecomposition.QR -> QRDecomposition(a.toCM().origin).solver - CMDecomposition.EIGEN -> EigenDecomposition(a.toCM().origin).solver - CMDecomposition.CHOLESKY -> CholeskyDecomposition(a.toCM().origin).solver - } - -fun CMMatrixContext.solve( - a: Matrix, - b: Matrix, - decomposition: CMDecomposition = CMDecomposition.LUP -) = solver(a, decomposition).solve(b.toCM().origin).asMatrix() - -fun CMMatrixContext.solve( - a: Matrix, - b: Point, - decomposition: CMDecomposition = CMDecomposition.LUP -) = solver(a, decomposition).solve(b.toCM().origin).toPoint() - -fun CMMatrixContext.inverse( - a: Matrix, - decomposition: CMDecomposition = CMDecomposition.LUP -) = solver(a, decomposition).inverse.asMatrix() diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt deleted file mode 100644 index cb2b5dd9c..000000000 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/random/CMRandomGeneratorWrapper.kt +++ /dev/null @@ -1,33 +0,0 @@ -package scientifik.kmath.commons.random - -import scientifik.kmath.prob.RandomGenerator - -class CMRandomGeneratorWrapper(val factory: (IntArray) -> RandomGenerator) : - org.apache.commons.math3.random.RandomGenerator { - private var generator: RandomGenerator = factory(intArrayOf()) - - override fun nextBoolean(): Boolean = generator.nextBoolean() - override fun nextFloat(): Float = generator.nextDouble().toFloat() - - override fun setSeed(seed: Int) { - generator = factory(intArrayOf(seed)) - } - - override fun setSeed(seed: IntArray) { - generator = factory(seed) - } - - override fun setSeed(seed: Long) { - setSeed(seed.toInt()) - } - - override fun nextBytes(bytes: ByteArray) { - generator.fillBytes(bytes) - } - - override fun nextInt(): Int = generator.nextInt() - override fun nextInt(n: Int): Int = generator.nextInt(n) - override fun nextGaussian(): Double = TODO() - override fun nextDouble(): Double = generator.nextDouble() - override fun nextLong(): Long = generator.nextLong() -} diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt new file mode 100644 index 000000000..7511a38ed --- /dev/null +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt @@ -0,0 +1,50 @@ +package kscience.kmath.commons.expressions + +import kscience.kmath.expressions.* +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFails + +internal inline fun diff( + order: Int, + vararg parameters: Pair, + block: DerivativeStructureField.() -> Unit, +): Unit { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + DerivativeStructureField(order, mapOf(*parameters)).run(block) +} + +internal class AutoDiffTest { + private val x by symbol + private val y by symbol + + @Test + fun derivativeStructureFieldTest() { + diff(2, x to 1.0, y to 1.0) { + val x = bind(x)//by binding() + val y = symbol("y") + val z = x * (-sin(x * y) + y) + 2.0 + println(z.derivative(x)) + println(z.derivative(y,x)) + assertEquals(z.derivative(x, y), z.derivative(y, x)) + //check that improper order cause failure + assertFails { z.derivative(x,x,y) } + } + } + + @Test + fun autoDifTest() { + val f = DerivativeStructureExpression { + val x by binding() + val y by binding() + x.pow(2) + 2 * x * y + y.pow(2) + 1 + } + + assertEquals(10.0, f(x to 1.0, y to 2.0)) + assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0)) + assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0)) + assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0)) + } +} diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt new file mode 100644 index 000000000..3290c8f32 --- /dev/null +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -0,0 +1,68 @@ +package kscience.kmath.commons.optimization + +import kscience.kmath.commons.expressions.DerivativeStructureExpression +import kscience.kmath.expressions.symbol +import kscience.kmath.stat.Distribution +import kscience.kmath.stat.Fitting +import kscience.kmath.stat.RandomGenerator +import kscience.kmath.stat.normal +import org.junit.jupiter.api.Test +import kotlin.math.pow + +internal class OptimizeTest { + val x by symbol + val y by symbol + + val normal = DerivativeStructureExpression { + exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2) + } + + @Test + fun testGradientOptimization() { + val result = normal.optimize(x, y) { + initialGuess(x to 1.0, y to 1.0) + //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function + } + println(result.point) + println(result.value) + } + + @Test + fun testSimplexOptimization() { + val result = normal.optimize(x, y) { + initialGuess(x to 1.0, y to 1.0) + simplexSteps(x to 2.0, y to 0.5) + //this sets simplex optimizer + } + println(result.point) + println(result.value) + } + + @Test + fun testCmFit() { + val a by symbol + val b by symbol + val c by symbol + + val sigma = 1.0 + val generator = Distribution.normal(0.0, sigma) + val chain = generator.sample(RandomGenerator.default(112667)) + val x = (1..100).map(Int::toDouble) + + val y = x.map { + it.pow(2) + it + 1 + chain.nextDouble() + } + + val yErr = List(x.size) { sigma } + + val chi2 = Fitting.chiSquared(x, y, yErr) { x1 -> + val cWithDefault = bindOrNull(c) ?: one + bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault + } + + val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) + println(result) + println("Chi2/dof = ${result.value / (x.size - 3)}") + } + +} \ No newline at end of file diff --git a/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt deleted file mode 100644 index bbdcff2fc..000000000 --- a/kmath-commons/src/test/kotlin/scientifik/kmath/commons/expressions/AutoDiffTest.kt +++ /dev/null @@ -1,36 +0,0 @@ -package scientifik.kmath.commons.expressions - -import scientifik.kmath.expressions.invoke -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract -import kotlin.test.Test -import kotlin.test.assertEquals - -inline fun diff(order: Int, vararg parameters: Pair, block: DerivativeStructureField.() -> R): R { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return DerivativeStructureField(order, mapOf(*parameters)).run(block) -} - -class AutoDiffTest { - @Test - fun derivativeStructureFieldTest() { - val res = diff(3, "x" to 1.0, "y" to 1.0) { - val x by variable - val y = variable("y") - val z = x * (-sin(x * y) + y) - z.deriv("x") - } - } - - @Test - fun autoDifTest() { - val f = DiffExpression { - val x by variable - val y by variable - x.pow(2) + 2 * x * y + y.pow(2) + 1 - } - - assertEquals(10.0, f("x" to 1.0, "y" to 2.0)) - assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0)) - } -} \ No newline at end of file diff --git a/kmath-core/README.md b/kmath-core/README.md index aed33a257..7882e5252 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -2,39 +2,48 @@ The core features of KMath: -- Algebraic structures: contexts and elements. -- ND structures. -- Buffers. -- Functional Expressions. -- Domains. -- Automatic differentiation. + - [algebras](src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : Algebraic structures: contexts and elements + - [nd](src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Many-dimensional structures + - [buffers](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure + - [expressions](src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions + - [domains](src/commonMain/kotlin/kscience/kmath/domains) : Domains + - [autodif](src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation + > #### Artifact: -> This module is distributed in the artifact `scientifik:kmath-core:0.1.4-dev-8`. -> +> +> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-4`. +> +> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion) +> +> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion) +> > **Gradle:** > > ```gradle > repositories { -> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } +> maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' } -> maven { url https://dl.bintray.com/hotkeytlt/maven' } +> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } +> > } > > dependencies { -> implementation 'scientifik:kmath-core:0.1.4-dev-8' +> implementation 'kscience.kmath:kmath-core:0.2.0-dev-4' > } > ``` > **Gradle Kotlin DSL:** > > ```kotlin > repositories { -> maven("https://dl.bintray.com/mipt-npm/scientifik") +> maven("https://dl.bintray.com/kotlin/kotlin-eap") +> maven("https://dl.bintray.com/mipt-npm/kscience") > maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/hotkeytlt/maven") > } > -> dependencies {`` -> implementation("scientifik:kmath-core:0.1.4-dev-8") +> dependencies { +> implementation("kscience.kmath:kmath-core:0.2.0-dev-4") > } > ``` diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 18c0cc771..9ed7e690b 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -1,11 +1,52 @@ plugins { - id("scientifik.mpp") + id("ru.mipt.npm.mpp") + id("ru.mipt.npm.native") } -kotlin.sourceSets { - commonMain { - dependencies { - api(project(":kmath-memory")) - } +kotlin.sourceSets.commonMain { + dependencies { + api(project(":kmath-memory")) } } + +readme { + description = "Core classes, algebra definitions, basic linear algebra" + maturity = ru.mipt.npm.gradle.Maturity.DEVELOPMENT + propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) + + feature( + id = "algebras", + description = "Algebraic structures: contexts and elements", + ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt" + ) + + feature( + id = "nd", + description = "Many-dimensional structures", + ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt" + ) + + feature( + id = "buffers", + description = "One-dimensional structure", + ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt" + ) + + feature( + id = "expressions", + description = "Functional Expressions", + ref = "src/commonMain/kotlin/kscience/kmath/expressions" + ) + + feature( + id = "domains", + description = "Domains", + ref = "src/commonMain/kotlin/kscience/kmath/domains" + ) + + feature( + id = "autodif", + description = "Automatic differentiation", + ref = "src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt" + ) +} diff --git a/kmath-core/docs/README-TEMPLATE.md b/kmath-core/docs/README-TEMPLATE.md new file mode 100644 index 000000000..83d1ebdce --- /dev/null +++ b/kmath-core/docs/README-TEMPLATE.md @@ -0,0 +1,7 @@ +# The Core Module (`kmath-core`) + +The core features of KMath: + +${features} + +${artifact} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/Domain.kt similarity index 54% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/domains/Domain.kt index 341383bfb..5c3cff2c5 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/Domain.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/Domain.kt @@ -1,20 +1,20 @@ -package scientifik.kmath.domains +package kscience.kmath.domains -import scientifik.kmath.linear.Point +import kscience.kmath.linear.Point /** * A simple geometric domain. * * @param T the type of element of this domain. */ -interface Domain { +public interface Domain { /** * Checks if the specified point is contained in this domain. */ - operator fun contains(point: Point): Boolean + public operator fun contains(point: Point): Boolean /** * Number of hyperspace dimensions. */ - val dimension: Int + public val dimension: Int } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/HyperSquareDomain.kt similarity index 54% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/domains/HyperSquareDomain.kt index 66798c42f..b45cf6bf5 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/HyperSquareDomain.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/HyperSquareDomain.kt @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package scientifik.kmath.domains +package kscience.kmath.domains -import scientifik.kmath.linear.Point -import scientifik.kmath.structures.RealBuffer -import scientifik.kmath.structures.indices +import kscience.kmath.linear.Point +import kscience.kmath.structures.RealBuffer +import kscience.kmath.structures.indices /** * @@ -25,23 +25,22 @@ import scientifik.kmath.structures.indices * * @author Alexander Nozik */ -class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { +public class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBuffer) : RealDomain { + public override val dimension: Int get() = lower.size - override operator fun contains(point: Point): Boolean = point.indices.all { i -> + public override operator fun contains(point: Point): Boolean = point.indices.all { i -> point[i] in lower[i]..upper[i] } - override val dimension: Int get() = lower.size + public override fun getLowerBound(num: Int, point: Point): Double? = lower[num] - override fun getLowerBound(num: Int, point: Point): Double? = lower[num] + public override fun getLowerBound(num: Int): Double? = lower[num] - override fun getLowerBound(num: Int): Double? = lower[num] + public override fun getUpperBound(num: Int, point: Point): Double? = upper[num] - override fun getUpperBound(num: Int, point: Point): Double? = upper[num] + public override fun getUpperBound(num: Int): Double? = upper[num] - override fun getUpperBound(num: Int): Double? = upper[num] - - override fun nearestInDomain(point: Point): Point { + public override fun nearestInDomain(point: Point): Point { val res = DoubleArray(point.size) { i -> when { point[i] < lower[i] -> lower[i] @@ -53,16 +52,14 @@ class HyperSquareDomain(private val lower: RealBuffer, private val upper: RealBu return RealBuffer(*res) } - override fun volume(): Double { + public override fun volume(): Double { var res = 1.0 + for (i in 0 until dimension) { - if (lower[i].isInfinite() || upper[i].isInfinite()) { - return Double.POSITIVE_INFINITY - } - if (upper[i] > lower[i]) { - res *= upper[i] - lower[i] - } + if (lower[i].isInfinite() || upper[i].isInfinite()) return Double.POSITIVE_INFINITY + if (upper[i] > lower[i]) res *= upper[i] - lower[i] } + return res } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/RealDomain.kt similarity index 71% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/domains/RealDomain.kt index 7507ccd59..369b093bb 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/RealDomain.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/RealDomain.kt @@ -13,17 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package scientifik.kmath.domains +package kscience.kmath.domains -import scientifik.kmath.linear.Point +import kscience.kmath.linear.Point /** * n-dimensional volume * * @author Alexander Nozik */ -interface RealDomain : Domain { - fun nearestInDomain(point: Point): Point +public interface RealDomain : Domain { + public fun nearestInDomain(point: Point): Point /** * The lower edge for the domain going down from point @@ -31,7 +31,7 @@ interface RealDomain : Domain { * @param point * @return */ - fun getLowerBound(num: Int, point: Point): Double? + public fun getLowerBound(num: Int, point: Point): Double? /** * The upper edge of the domain going up from point @@ -39,25 +39,25 @@ interface RealDomain : Domain { * @param point * @return */ - fun getUpperBound(num: Int, point: Point): Double? + public fun getUpperBound(num: Int, point: Point): Double? /** * Global lower edge * @param num * @return */ - fun getLowerBound(num: Int): Double? + public fun getLowerBound(num: Int): Double? /** * Global upper edge * @param num * @return */ - fun getUpperBound(num: Int): Double? + public fun getUpperBound(num: Int): Double? /** * Hyper volume * @return */ - fun volume(): Double + public fun volume(): Double } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/UnconstrainedDomain.kt new file mode 100644 index 000000000..e2efb51ab --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/UnconstrainedDomain.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kscience.kmath.domains + +import kscience.kmath.linear.Point + +public class UnconstrainedDomain(public override val dimension: Int) : RealDomain { + public override operator fun contains(point: Point): Boolean = true + + public override fun getLowerBound(num: Int, point: Point): Double? = Double.NEGATIVE_INFINITY + + public override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY + + public override fun getUpperBound(num: Int, point: Point): Double? = Double.POSITIVE_INFINITY + + public override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY + + public override fun nearestInDomain(point: Point): Point = point + + public override fun volume(): Double = Double.POSITIVE_INFINITY +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/UnivariateDomain.kt new file mode 100644 index 000000000..bf090f2e5 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/domains/UnivariateDomain.kt @@ -0,0 +1,49 @@ +package kscience.kmath.domains + +import kscience.kmath.linear.Point +import kscience.kmath.structures.asBuffer + +public inline class UnivariateDomain(public val range: ClosedFloatingPointRange) : RealDomain { + public override val dimension: Int + get() = 1 + + public operator fun contains(d: Double): Boolean = range.contains(d) + + public override operator fun contains(point: Point): Boolean { + require(point.size == 0) + return contains(point[0]) + } + + public override fun nearestInDomain(point: Point): Point { + require(point.size == 1) + val value = point[0] + + return when { + value in range -> point + value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() + else -> doubleArrayOf(range.start).asBuffer() + } + } + + public override fun getLowerBound(num: Int, point: Point): Double? { + require(num == 0) + return range.start + } + + public override fun getUpperBound(num: Int, point: Point): Double? { + require(num == 0) + return range.endInclusive + } + + public override fun getLowerBound(num: Int): Double? { + require(num == 0) + return range.start + } + + public override fun getUpperBound(num: Int): Double? { + require(num == 0) + return range.endInclusive + } + + public override fun volume(): Double = range.endInclusive - range.start +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt new file mode 100644 index 000000000..abce9c4ec --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -0,0 +1,48 @@ +package kscience.kmath.expressions + +/** + * Represents expression which structure can be differentiated. + * + * @param T the type this expression takes as argument and returns. + * @param R the type of expression this expression can be differentiated to. + */ +public interface DifferentiableExpression> : Expression { + /** + * Differentiates this expression by ordered collection of [symbols]. + * + * @param symbols the symbols. + * @return the derivative or `null`. + */ + public fun derivativeOrNull(symbols: List): R? +} + +public fun > DifferentiableExpression.derivative(symbols: List): R = + derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") + +public fun > DifferentiableExpression.derivative(vararg symbols: Symbol): R = + derivative(symbols.toList()) + +public fun > DifferentiableExpression.derivative(name: String): R = + derivative(StringSymbol(name)) + +/** + * A [DifferentiableExpression] that defines only first derivatives + */ +public abstract class FirstDerivativeExpression> : DifferentiableExpression { + /** + * Returns first derivative of this expression by given [symbol]. + */ + public abstract fun derivativeOrNull(symbol: Symbol): R? + + public final override fun derivativeOrNull(symbols: List): R? { + val dSymbol = symbols.firstOrNull() ?: return null + return derivativeOrNull(dSymbol) + } +} + +/** + * A factory that converts an expression in autodiff variables to a [DifferentiableExpression] + */ +public fun interface AutoDiffProcessor, out R : Expression> { + public fun process(function: A.() -> I): DifferentiableExpression +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt new file mode 100644 index 000000000..63bbc9312 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -0,0 +1,115 @@ +package kscience.kmath.expressions + +import kscience.kmath.operations.Algebra +import kotlin.jvm.JvmName +import kotlin.properties.ReadOnlyProperty +import kotlin.reflect.KProperty + +/** + * A marker interface for a symbol. A symbol mus have an identity + */ +public interface Symbol { + /** + * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol. + */ + public val identity: String + + public companion object : ReadOnlyProperty { + //TODO deprecate and replace by top level function after fix of https://youtrack.jetbrains.com/issue/KT-40121 + override fun getValue(thisRef: Any?, property: KProperty<*>): Symbol { + return StringSymbol(property.name) + } + } +} + +/** + * A [Symbol] with a [String] identity + */ +public inline class StringSymbol(override val identity: String) : Symbol { + override fun toString(): String = identity +} + +/** + * An elementary function that could be invoked on a map of arguments. + * + * @param T the type this expression takes as argument and returns. + */ +public fun interface Expression { + /** + * Calls this expression from arguments. + * + * @param arguments the map of arguments. + * @return the value. + */ + public operator fun invoke(arguments: Map): T +} + +/** + * Calls this expression without providing any arguments. + * + * @return a value. + */ +public operator fun Expression.invoke(): T = invoke(emptyMap()) + +/** + * Calls this expression from arguments. + * + * @param pairs the pairs of arguments to values. + * @return a value. + */ +@JvmName("callBySymbol") +public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) + +/** + * Calls this expression from arguments. + * + * @param pairs the pairs of arguments' names to values. + * @return a value. + */ +@JvmName("callByString") +public operator fun Expression.invoke(vararg pairs: Pair): T = + invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) + + +/** + * A context for expression construction + * + * @param T type of the constants for the expression + * @param E type of the actual expression state + */ +public interface ExpressionAlgebra : Algebra { + /** + * Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context. + */ + public fun bindOrNull(symbol: Symbol): E? + + /** + * Bind a string to a context using [StringSymbol] + */ + override fun symbol(value: String): E = bind(StringSymbol(value)) + + /** + * A constant expression which does not depend on arguments + */ + public fun const(value: T): E +} + +/** + * Bind a given [Symbol] to this context variable and produce context-specific object. + */ +public fun ExpressionAlgebra.bind(symbol: Symbol): E = + bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this") + +/** + * A delegate to create a symbol with a string identity in this scope + */ +public val symbol: ReadOnlyProperty get() = Symbol +//TODO does not work directly on native due to https://youtrack.jetbrains.com/issue/KT-40121 + + +/** + * Bind a symbol by name inside the [ExpressionAlgebra] + */ +public fun ExpressionAlgebra.binding(): ReadOnlyProperty = ReadOnlyProperty { _, property -> + bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist") +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt new file mode 100644 index 000000000..1a3668855 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -0,0 +1,164 @@ +package kscience.kmath.expressions + +import kscience.kmath.operations.* + +/** + * A context class for [Expression] construction. + * + * @param algebra The algebra to provide for Expressions built. + */ +public abstract class FunctionalExpressionAlgebra>( + public val algebra: A, +) : ExpressionAlgebra> { + /** + * Builds an Expression of constant expression which does not depend on arguments. + */ + public override fun const(value: T): Expression = Expression { value } + + /** + * Builds an Expression to access a variable. + */ + public override fun bindOrNull(symbol: Symbol): Expression? = Expression { arguments -> + arguments[symbol] ?: error("Argument not found: $symbol") + } + + /** + * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. + */ + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + { left, right -> + Expression { arguments -> + algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments)) + } + } + + /** + * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. + */ + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg -> + Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) } + } +} + +/** + * A context class for [Expression] construction for [Space] algebras. + */ +public open class FunctionalExpressionSpace>( + algebra: A, +) : FunctionalExpressionAlgebra(algebra), Space> { + public override val zero: Expression get() = const(algebra.zero) + + /** + * Builds an Expression of addition of two another expressions. + */ + public override fun add(a: Expression, b: Expression): Expression = + binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b) + + /** + * Builds an Expression of multiplication of expression by number. + */ + public override fun multiply(a: Expression, k: Number): Expression = Expression { arguments -> + algebra.multiply(a.invoke(arguments), k) + } + + public operator fun Expression.plus(arg: T): Expression = this + const(arg) + public operator fun Expression.minus(arg: T): Expression = this - const(arg) + public operator fun T.plus(arg: Expression): Expression = arg + this + public operator fun T.minus(arg: Expression): Expression = arg - this + + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) + + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) +} + +public open class FunctionalExpressionRing>( + algebra: A, +) : FunctionalExpressionSpace(algebra), Ring> { + public override val one: Expression + get() = const(algebra.one) + + /** + * Builds an Expression of multiplication of two expressions. + */ + public override fun multiply(a: Expression, b: Expression): Expression = + binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + + public operator fun Expression.times(arg: T): Expression = this * const(arg) + public operator fun T.times(arg: Expression): Expression = arg * this + + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) + + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) +} + +public open class FunctionalExpressionField>( + algebra: A, +) : FunctionalExpressionRing(algebra), Field> { + /** + * Builds an Expression of division an expression by another one. + */ + public override fun divide(a: Expression, b: Expression): Expression = + binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + + public operator fun Expression.div(arg: T): Expression = this / const(arg) + public operator fun T.div(arg: Expression): Expression = arg / this + + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) + + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) +} + +public open class FunctionalExpressionExtendedField>( + algebra: A, +) : FunctionalExpressionField(algebra), ExtendedField> { + + override fun number(value: Number): Expression = const(algebra.number(value)) + + public override fun sin(arg: Expression): Expression = + unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) + + public override fun cos(arg: Expression): Expression = + unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) + + public override fun asin(arg: Expression): Expression = + unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg) + + public override fun acos(arg: Expression): Expression = + unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg) + + public override fun atan(arg: Expression): Expression = + unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg) + + public override fun power(arg: Expression, pow: Number): Expression = + binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) + + public override fun exp(arg: Expression): Expression = + unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) + + public override fun ln(arg: Expression): Expression = + unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) + + public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + super.unaryOperationFunction(operation) + + public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + super.binaryOperationFunction(operation) +} + +public inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = + FunctionalExpressionSpace(this).block() + +public inline fun > A.expressionInRing(block: FunctionalExpressionRing.() -> Expression): Expression = + FunctionalExpressionRing(this).block() + +public inline fun > A.expressionInField(block: FunctionalExpressionField.() -> Expression): Expression = + FunctionalExpressionField(this).block() + +public inline fun > A.expressionInExtendedField(block: FunctionalExpressionExtendedField.() -> Expression): Expression = + FunctionalExpressionExtendedField(this).block() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt new file mode 100644 index 000000000..0621e82bd --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -0,0 +1,395 @@ +package kscience.kmath.expressions + +import kscience.kmath.linear.Point +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.* +import kscience.kmath.structures.asBuffer +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/* + * Implementation of backward-mode automatic differentiation. + * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d + */ + + +public open class AutoDiffValue(public val value: T) + +/** + * Represents result of [simpleAutoDiff] call. + * + * @param T the non-nullable type of value. + * @param value the value of result. + * @property simpleAutoDiff The mapping of differentiated variables to their derivatives. + * @property context The field over [T]. + */ +public class DerivationResult( + public val value: T, + private val derivativeValues: Map, + public val context: Field, +) { + /** + * Returns derivative of [variable] or returns [Ring.zero] in [context]. + */ + public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero + + /** + * Computes the divergence. + */ + public fun div(): T = context { sum(derivativeValues.values) } +} + +/** + * Computes the gradient for variables in given order. + */ +public fun DerivationResult.grad(vararg variables: Symbol): Point { + check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } + return variables.map(::derivative).asBuffer() +} + +/** + * Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code. + * + * The partial derivatives are placed in argument `d` variable + * + * Example: + * ``` + * val x by symbol // define variable(s) and their values + * val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context + * assertEquals(17.0, y.x) // the value of result (y) + * assertEquals(9.0, x.d) // dy/dx + * ``` + * + * @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to. + * @return the result of differentiation. + */ +public fun > F.simpleAutoDiff( + bindings: Map, + body: SimpleAutoDiffField.() -> AutoDiffValue, +): DerivationResult { + contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } + + return SimpleAutoDiffField(this, bindings).differentiate(body) +} + +public fun > F.simpleAutoDiff( + vararg bindings: Pair, + body: SimpleAutoDiffField.() -> AutoDiffValue, +): DerivationResult = simpleAutoDiff(bindings.toMap(), body) + +/** + * Represents field in context of which functions can be derived. + */ +@OptIn(UnstableKMathAPI::class) +public open class SimpleAutoDiffField>( + public val context: F, + bindings: Map, +) : Field>, ExpressionAlgebra>, RingWithNumbers> { + public override val zero: AutoDiffValue + get() = const(context.zero) + + public override val one: AutoDiffValue + get() = const(context.one) + + // this stack contains pairs of blocks and values to apply them to + private var stack: Array = arrayOfNulls(8) + private var sp: Int = 0 + private val derivatives: MutableMap, T> = hashMapOf() + + private val bindings: Map> = bindings.entries.associate { + it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero) + } + + /** + * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result + * with respect to this variable. + * + * @param T the non-nullable type of value. + * @property value The value of this variable. + */ + private class AutoDiffVariableWithDerivative( + override val identity: String, + value: T, + var d: T, + ) : AutoDiffValue(value), Symbol { + override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() + } + + public override fun bindOrNull(symbol: Symbol): AutoDiffValue? = bindings[symbol.identity] + + private fun getDerivative(variable: AutoDiffValue): T = + (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero + + private fun setDerivative(variable: AutoDiffValue, value: T) { + if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value + } + + @Suppress("UNCHECKED_CAST") + private fun runBackwardPass() { + while (sp > 0) { + val value = stack[--sp] + val block = stack[--sp] as F.(Any?) -> Unit + context.block(value) + } + } + + override fun const(value: T): AutoDiffValue = AutoDiffValue(value) + + /** + * A variable accessing inner state of derivatives. + * Use this value in inner builders to avoid creating additional derivative bindings. + */ + public var AutoDiffValue.d: T + get() = getDerivative(this) + set(value) = setDerivative(this, value) + + public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) + + /** + * Performs update of derivative after the rest of the formula in the back-pass. + * + * For example, implementation of `sin` function is: + * + * ``` + * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result + * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function + * } + * ``` + */ + @Suppress("UNCHECKED_CAST") + public fun derive(value: R, block: F.(R) -> Unit): R { + // save block to stack for backward pass + if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) + stack[sp++] = block + stack[sp++] = value + return value + } + + + internal fun differentiate(function: SimpleAutoDiffField.() -> AutoDiffValue): DerivationResult { + val result = function() + result.d = context.one // computing derivative w.r.t result + runBackwardPass() + return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) + } + + // Overloads for Double constants + + public override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue = + derive(const { this@plus.toDouble() * one + b.value }) { z -> + b.d += z.d + } + + public override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this) + + public override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue = + derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } + + public override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = + derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } + + + // Basic math (+, -, *, /) + + public override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = + derive(const { a.value + b.value }) { z -> + a.d += z.d + b.d += z.d + } + + public override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = + derive(const { a.value * b.value }) { z -> + a.d += z.d * b.value + b.d += z.d * a.value + } + + public override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = + derive(const { a.value / b.value }) { z -> + a.d += z.d / b.value + b.d -= z.d * a.value / (b.value * b.value) + } + + public override fun multiply(a: AutoDiffValue, k: Number): AutoDiffValue = + derive(const { k.toDouble() * a.value }) { z -> + a.d += z.d * k.toDouble() + } +} + +/** + * A constructs that creates a derivative structure with required order on-demand + */ +public class SimpleAutoDiffExpression>( + public val field: F, + public val function: SimpleAutoDiffField.() -> AutoDiffValue, +) : FirstDerivativeExpression>() { + public override operator fun invoke(arguments: Map): T { + //val bindings = arguments.entries.map { it.key.bind(it.value) } + return SimpleAutoDiffField(field, arguments).function().value + } + + public override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> + //val bindings = arguments.entries.map { it.key.bind(it.value) } + val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function) + derivationResult.derivative(symbol) + } +} + +/** + * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] + */ +public fun > simpleAutoDiff(field: F): AutoDiffProcessor, SimpleAutoDiffField, Expression> = + AutoDiffProcessor { function -> + SimpleAutoDiffExpression(field, function) + } + +// Extensions for differentiation of various basic mathematical functions + +// x ^ 2 +public fun > SimpleAutoDiffField.sqr(x: AutoDiffValue): AutoDiffValue = + derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } + +// x ^ 1/2 +public fun > SimpleAutoDiffField.sqrt(x: AutoDiffValue): AutoDiffValue = + derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } + +// x ^ y (const) +public fun > SimpleAutoDiffField.pow( + x: AutoDiffValue, + y: Double, +): AutoDiffValue = + derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } + +public fun > SimpleAutoDiffField.pow( + x: AutoDiffValue, + y: Int, +): AutoDiffValue = pow(x, y.toDouble()) + +// exp(x) +public fun > SimpleAutoDiffField.exp(x: AutoDiffValue): AutoDiffValue = + derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } + +// ln(x) +public fun > SimpleAutoDiffField.ln(x: AutoDiffValue): AutoDiffValue = + derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } + +// x ^ y (any) +public fun > SimpleAutoDiffField.pow( + x: AutoDiffValue, + y: AutoDiffValue, +): AutoDiffValue = + exp(y * ln(x)) + +// sin(x) +public fun > SimpleAutoDiffField.sin(x: AutoDiffValue): AutoDiffValue = + derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } + +// cos(x) +public fun > SimpleAutoDiffField.cos(x: AutoDiffValue): AutoDiffValue = + derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } + +public fun > SimpleAutoDiffField.tan(x: AutoDiffValue): AutoDiffValue = + derive(const { tan(x.value) }) { z -> + val c = cos(x.value) + x.d += z.d / (c * c) + } + +public fun > SimpleAutoDiffField.asin(x: AutoDiffValue): AutoDiffValue = + derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } + +public fun > SimpleAutoDiffField.acos(x: AutoDiffValue): AutoDiffValue = + derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } + +public fun > SimpleAutoDiffField.atan(x: AutoDiffValue): AutoDiffValue = + derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } + +public fun > SimpleAutoDiffField.sinh(x: AutoDiffValue): AutoDiffValue = + derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) } + +public fun > SimpleAutoDiffField.cosh(x: AutoDiffValue): AutoDiffValue = + derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) } + +public fun > SimpleAutoDiffField.tanh(x: AutoDiffValue): AutoDiffValue = + derive(const { tanh(x.value) }) { z -> + val c = cosh(x.value) + x.d += z.d / (c * c) + } + +public fun > SimpleAutoDiffField.asinh(x: AutoDiffValue): AutoDiffValue = + derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } + +public fun > SimpleAutoDiffField.acosh(x: AutoDiffValue): AutoDiffValue = + derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } + +public fun > SimpleAutoDiffField.atanh(x: AutoDiffValue): AutoDiffValue = + derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } + +public class SimpleAutoDiffExtendedField>( + context: F, + bindings: Map, +) : ExtendedField>, SimpleAutoDiffField(context, bindings) { + // x ^ 2 + public fun sqr(x: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sqr(x) + + // x ^ 1/2 + public override fun sqrt(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sqrt(arg) + + // x ^ y (const) + public override fun power(arg: AutoDiffValue, pow: Number): AutoDiffValue = + (this as SimpleAutoDiffField).pow(arg, pow.toDouble()) + + // exp(x) + public override fun exp(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).exp(arg) + + // ln(x) + public override fun ln(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).ln(arg) + + // x ^ y (any) + public fun pow( + x: AutoDiffValue, + y: AutoDiffValue, + ): AutoDiffValue = exp(y * ln(x)) + + // sin(x) + public override fun sin(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sin(arg) + + // cos(x) + public override fun cos(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).cos(arg) + + public override fun tan(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).tan(arg) + + public override fun asin(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).asin(arg) + + public override fun acos(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).acos(arg) + + public override fun atan(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).atan(arg) + + public override fun sinh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).sinh(arg) + + public override fun cosh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).cosh(arg) + + public override fun tanh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).tanh(arg) + + public override fun asinh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).asinh(arg) + + public override fun acosh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).acosh(arg) + + public override fun atanh(arg: AutoDiffValue): AutoDiffValue = + (this as SimpleAutoDiffField).atanh(arg) +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt new file mode 100644 index 000000000..6c61c7c7d --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SymbolIndexer.kt @@ -0,0 +1,61 @@ +package kscience.kmath.expressions + +import kscience.kmath.linear.Point +import kscience.kmath.structures.BufferFactory +import kscience.kmath.structures.Structure2D + +/** + * An environment to easy transform indexed variables to symbols and back. + * TODO requires multi-receivers to be beutiful + */ +public interface SymbolIndexer { + public val symbols: List + public fun indexOf(symbol: Symbol): Int = symbols.indexOf(symbol) + + public operator fun List.get(symbol: Symbol): T { + require(size == symbols.size) { "The input list size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + + public operator fun Array.get(symbol: Symbol): T { + require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + + public operator fun DoubleArray.get(symbol: Symbol): Double { + require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + + public operator fun Point.get(symbol: Symbol): T { + require(size == symbols.size) { "The input buffer size for indexer should be ${symbols.size} but $size found" } + return get(this@SymbolIndexer.indexOf(symbol)) + } + + public fun DoubleArray.toMap(): Map { + require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } + return symbols.indices.associate { symbols[it] to get(it) } + } + + public operator fun Structure2D.get(rowSymbol: Symbol, columnSymbol: Symbol): T = + get(indexOf(rowSymbol), indexOf(columnSymbol)) + + + public fun Map.toList(): List = symbols.map { getValue(it) } + + public fun Map.toPoint(bufferFactory: BufferFactory): Point = + bufferFactory(symbols.size) { getValue(symbols[it]) } + + public fun Map.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) } +} + +public inline class SimpleSymbolIndexer(override val symbols: List) : SymbolIndexer + +/** + * Execute the block with symbol indexer based on given symbol order + */ +public inline fun withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R = + with(SimpleSymbolIndexer(symbols.toList()), block) + +public inline fun withSymbols(symbols: Collection, block: SymbolIndexer.() -> R): R = + with(SimpleSymbolIndexer(symbols.toList()), block) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt similarity index 51% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt index 8d0b82a89..1603bc21d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/expressionBuilders.kt @@ -1,17 +1,17 @@ -package scientifik.kmath.expressions +package kscience.kmath.expressions -import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space -import kotlin.contracts.ExperimentalContracts +import kscience.kmath.operations.ExtendedField +import kscience.kmath.operations.Field +import kscience.kmath.operations.Ring +import kscience.kmath.operations.Space import kotlin.contracts.InvocationKind import kotlin.contracts.contract + /** * Creates a functional expression with this [Space]. */ -inline fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression { +public inline fun Space.spaceExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return FunctionalExpressionSpace(this).block() } @@ -19,7 +19,7 @@ inline fun Space.spaceExpression(block: FunctionalExpressionSpace Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression { +public inline fun Ring.ringExpression(block: FunctionalExpressionRing>.() -> Expression): Expression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return FunctionalExpressionRing(this).block() } @@ -27,7 +27,7 @@ inline fun Ring.ringExpression(block: FunctionalExpressionRing /** * Creates a functional expression with this [Field]. */ -inline fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression { +public inline fun Field.fieldExpression(block: FunctionalExpressionField>.() -> Expression): Expression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return FunctionalExpressionField(this).block() } @@ -35,7 +35,7 @@ inline fun Field.fieldExpression(block: FunctionalExpressionField ExtendedField.extendedFieldExpression(block: FunctionalExpressionExtendedField>.() -> Expression): Expression { +public inline fun ExtendedField.extendedFieldExpression(block: FunctionalExpressionExtendedField>.() -> Expression): Expression { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return FunctionalExpressionExtendedField(this).block() } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt new file mode 100644 index 000000000..a74d948fc --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt @@ -0,0 +1,68 @@ +package kscience.kmath.linear + +import kscience.kmath.operations.Ring +import kscience.kmath.structures.* + +/** + * Basic implementation of Matrix space based on [NDStructure] + */ +public class BufferMatrixContext>( + public override val elementContext: R, + private val bufferFactory: BufferFactory, +) : GenericMatrixContext> { + public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix { + val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + return BufferMatrix(rows, columns, buffer) + } + + public override fun point(size: Int, initializer: (Int) -> T): Point = bufferFactory(size, initializer) + + public companion object +} + +public class BufferMatrix( + public override val rowNum: Int, + public override val colNum: Int, + public val buffer: Buffer, +) : Matrix { + + init { + require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" } + } + + override val shape: IntArray get() = intArrayOf(rowNum, colNum) + + public override operator fun get(index: IntArray): T = get(index[0], index[1]) + public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] + + public override fun elements(): Sequence> = sequence { + for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) + } + + public override fun equals(other: Any?): Boolean { + if (this === other) return true + + return when (other) { + is NDStructure<*> -> NDStructure.contentEquals(this, other) + else -> false + } + } + + override fun hashCode(): Int { + var result = rowNum + result = 31 * result + colNum + result = 31 * result + buffer.hashCode() + return result + } + + public override fun toString(): String { + return if (rowNum <= 5 && colNum <= 5) + "Matrix(rowsNum = $rowNum, colNum = $colNum)\n" + + rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer -> + buffer.asSequence().joinToString(separator = "\t") { it.toString() } + } + else "Matrix(rowsNum = $rowNum, colNum = $colNum)" + } + + +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LinearAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LinearAlgebra.kt new file mode 100644 index 000000000..034decc2f --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LinearAlgebra.kt @@ -0,0 +1,27 @@ +package kscience.kmath.linear + +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.VirtualBuffer + +public typealias Point = Buffer + +/** + * A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors + */ +public interface LinearSolver { + public fun solve(a: Matrix, b: Matrix): Matrix + public fun solve(a: Matrix, b: Point): Point = solve(a, b.asMatrix()).asPoint() + public fun inverse(a: Matrix): Matrix +} + +/** + * Convert matrix to vector if it is possible + */ +public fun Matrix.asPoint(): Point = + if (this.colNum == 1) + VirtualBuffer(rowNum) { get(it, 0) } + else + error("Can't convert matrix with more than one column to vector") + +public fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(size, 1) { i, _ -> get(i) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt similarity index 55% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt index f3e4f648f..5cf7c8f70 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt @@ -1,38 +1,31 @@ -package scientifik.kmath.linear +package kscience.kmath.linear -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.BufferAccessor2D -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.Structure2D -import kotlin.reflect.KClass +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.* +import kscience.kmath.structures.* /** - * Common implementation of [LUPDecompositionFeature] + * Common implementation of [LupDecompositionFeature]. */ -class LUPDecomposition( - val context: GenericMatrixContext>, - val lu: Structure2D, - val pivot: IntArray, - private val even: Boolean -) : LUPDecompositionFeature, DeterminantFeature { - - val elementContext: Field get() = context.elementContext - +public class LupDecomposition( + public val context: MatrixContext>, + public val elementContext: Field, + public val lu: Matrix, + public val pivot: IntArray, + private val even: Boolean, +) : LupDecompositionFeature, DeterminantFeature { /** * Returns the matrix L of the decomposition. * * L is a lower-triangular matrix with [Ring.one] in diagonal */ - override val l: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> + override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> when { j < i -> lu[i, j] j == i -> elementContext.one else -> elementContext.zero } - } + } + LFeature /** @@ -40,10 +33,9 @@ class LUPDecomposition( * * U is an upper-triangular matrix including the diagonal */ - override val u: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> + override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j >= i) lu[i, j] else elementContext.zero - } - + } + UFeature /** * Returns the P rows permutation matrix. @@ -51,11 +43,10 @@ class LUPDecomposition( * P is a sparse matrix with exactly one element set to [Ring.one] in * each row and each column, all other elements being set to [Ring.zero]. */ - override val p: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j == pivot[i]) elementContext.one else elementContext.zero } - /** * Return the determinant of the matrix * @return determinant of the matrix @@ -66,27 +57,25 @@ class LUPDecomposition( } -fun , F : Field> GenericMatrixContext.abs(value: T): T = +@PublishedApi +internal fun , F : Field> GenericMatrixContext.abs(value: T): T = if (value > elementContext.zero) value else elementContext { -value } - /** - * Create a lup decomposition of generic matrix + * Create a lup decomposition of generic matrix. */ -fun , F : Field> GenericMatrixContext.lup( - type: KClass, +public fun > MatrixContext>.lup( + factory: MutableBufferFactory, + elementContext: Field, matrix: Matrix, - checkSingular: (T) -> Boolean -): LUPDecomposition { - if (matrix.rowNum != matrix.colNum) { - error("LU decomposition supports only square matrices") - } - + checkSingular: (T) -> Boolean, +): LupDecomposition { + require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" } val m = matrix.colNum val pivot = IntArray(matrix.rowNum) //TODO just waits for KEEP-176 - BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { + BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run { elementContext { val lu = create(matrix) @@ -118,14 +107,14 @@ fun , F : Field> GenericMatrixContext.lup( luRow[col] = sum // maintain best permutation choice - if (this@lup.abs(sum) > largest) { - largest = this@lup.abs(sum) + if (abs(sum) > largest) { + largest = abs(sum) max = row } } // Singularity check - check(!checkSingular(this@lup.abs(lu[max, col]))) { "The matrix is singular" } + check(!checkSingular(abs(lu[max, col]))) { "The matrix is singular" } // Pivot if necessary if (max != col) { @@ -149,23 +138,26 @@ fun , F : Field> GenericMatrixContext.lup( for (row in col + 1 until m) lu[row, col] /= luDiag } - return LUPDecomposition(this@lup, lu.collect(), pivot, even) + return LupDecomposition(this@lup, elementContext, lu.collect(), pivot, even) } } } -inline fun , F : Field> GenericMatrixContext.lup( +public inline fun , F : Field> GenericMatrixContext>.lup( matrix: Matrix, - noinline checkSingular: (T) -> Boolean -): LUPDecomposition = lup(T::class, matrix, checkSingular) + noinline checkSingular: (T) -> Boolean, +): LupDecomposition = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular) -fun GenericMatrixContext.lup(matrix: Matrix): LUPDecomposition = - lup(Double::class, matrix) { it < 1e-11 } +public fun MatrixContext>.lup(matrix: Matrix): LupDecomposition = + lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 } -fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Matrix { +public fun LupDecomposition.solveWithLUP( + factory: MutableBufferFactory, + matrix: Matrix, +): Matrix { require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } - BufferAccessor2D(type, matrix.rowNum, matrix.colNum).run { + BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run { elementContext { // Apply permutations to b val bp = create { _, _ -> zero } @@ -207,27 +199,41 @@ fun LUPDecomposition.solve(type: KClass, matrix: Matrix): Mat } } -inline fun LUPDecomposition.solve(matrix: Matrix): Matrix = solve(T::class, matrix) +public inline fun LupDecomposition.solveWithLUP(matrix: Matrix): Matrix = + solveWithLUP(MutableBuffer.Companion::auto, matrix) /** - * Solve a linear equation **a*x = b** + * Solve a linear equation **a*x = b** using LUP decomposition */ -inline fun , F : Field> GenericMatrixContext.solve( +@OptIn(UnstableKMathAPI::class) +public inline fun , F : Field> GenericMatrixContext>.solveWithLUP( a: Matrix, b: Matrix, - noinline checkSingular: (T) -> Boolean + noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, + noinline checkSingular: (T) -> Boolean, ): Matrix { // Use existing decomposition if it is provided by matrix - val decomposition = a.getFeature() ?: lup(T::class, a, checkSingular) - return decomposition.solve(T::class, b) + val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular) + return decomposition.solveWithLUP(bufferFactory, b) } -fun RealMatrixContext.solve(a: Matrix, b: Matrix): Matrix = solve(a, b) { it < 1e-11 } - -inline fun , F : Field> GenericMatrixContext.inverse( +public inline fun , F : Field> GenericMatrixContext>.inverseWithLUP( matrix: Matrix, - noinline checkSingular: (T) -> Boolean -): Matrix = solve(matrix, one(matrix.rowNum, matrix.colNum), checkSingular) + noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, + noinline checkSingular: (T) -> Boolean, +): Matrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) -fun RealMatrixContext.inverse(matrix: Matrix): Matrix = - solve(matrix, one(matrix.rowNum, matrix.colNum)) { it < 1e-11 } + +@OptIn(UnstableKMathAPI::class) +public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): Matrix { + // Use existing decomposition if it is provided by matrix + val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real + val decomposition: LupDecomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 } + return decomposition.solveWithLUP(bufferFactory, b) +} + +/** + * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. + */ +public fun RealMatrixContext.inverseWithLUP(matrix: Matrix): Matrix = + solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum)) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt similarity index 52% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt index 390362f8c..c0c209248 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt @@ -1,12 +1,9 @@ -package scientifik.kmath.linear +package kscience.kmath.linear -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.Structure2D -import scientifik.kmath.structures.asBuffer +import kscience.kmath.structures.* -class MatrixBuilder(val rows: Int, val columns: Int) { - operator fun invoke(vararg elements: T): FeaturedMatrix { +public class MatrixBuilder(public val rows: Int, public val columns: Int) { + public operator fun invoke(vararg elements: T): Matrix { require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } val buffer = elements.asBuffer() return BufferMatrix(rows, columns, buffer) @@ -15,32 +12,32 @@ class MatrixBuilder(val rows: Int, val columns: Int) { //TODO add specific matrix builder functions like diagonal, etc } -fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) +public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) -fun Structure2D.Companion.row(vararg values: T): FeaturedMatrix { +public fun Structure2D.Companion.row(vararg values: T): Matrix { val buffer = values.asBuffer() return BufferMatrix(1, values.size, buffer) } -inline fun Structure2D.Companion.row( +public inline fun Structure2D.Companion.row( size: Int, factory: BufferFactory = Buffer.Companion::auto, noinline builder: (Int) -> T -): FeaturedMatrix { +): Matrix { val buffer = factory(size, builder) return BufferMatrix(1, size, buffer) } -fun Structure2D.Companion.column(vararg values: T): FeaturedMatrix { +public fun Structure2D.Companion.column(vararg values: T): Matrix { val buffer = values.asBuffer() return BufferMatrix(values.size, 1, buffer) } -inline fun Structure2D.Companion.column( +public inline fun Structure2D.Companion.column( size: Int, factory: BufferFactory = Buffer.Companion::auto, noinline builder: (Int) -> T -): FeaturedMatrix { +): Matrix { val buffer = factory(size, builder) return BufferMatrix(size, 1, buffer) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt new file mode 100644 index 000000000..59a41f840 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -0,0 +1,138 @@ +package kscience.kmath.linear + +import kscience.kmath.operations.Ring +import kscience.kmath.operations.SpaceOperations +import kscience.kmath.operations.invoke +import kscience.kmath.operations.sum +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.BufferFactory +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.asSequence + +/** + * Basic operations on matrices. Operates on [Matrix] + */ +public interface MatrixContext> : SpaceOperations> { + /** + * Produce a matrix with this context and given dimensions + */ + public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M + + /** + * Produce a point compatible with matrix space (and possibly optimized for it) + */ + public fun point(size: Int, initializer: (Int) -> T): Point = Buffer.boxing(size, initializer) + + @Suppress("UNCHECKED_CAST") + public override fun binaryOperationFunction(operation: String): (left: Matrix, right: Matrix) -> M = + when (operation) { + "dot" -> { left, right -> left dot right } + else -> super.binaryOperationFunction(operation) as (Matrix, Matrix) -> M + } + + /** + * Computes the dot product of this matrix and another one. + * + * @receiver the multiplicand. + * @param other the multiplier. + * @return the dot product. + */ + public infix fun Matrix.dot(other: Matrix): M + + /** + * Computes the dot product of this matrix and a vector. + * + * @receiver the multiplicand. + * @param vector the multiplier. + * @return the dot product. + */ + public infix fun Matrix.dot(vector: Point): Point + + /** + * Multiplies a matrix by its element. + * + * @receiver the multiplicand. + * @param value the multiplier. + * @receiver the product. + */ + public operator fun Matrix.times(value: T): M + + /** + * Multiplies an element by a matrix of it. + * + * @receiver the multiplicand. + * @param value the multiplier. + * @receiver the product. + */ + public operator fun T.times(m: Matrix): M = m * this + + public companion object { + + /** + * A structured matrix with custom buffer + */ + public fun > buffered( + ring: R, + bufferFactory: BufferFactory = Buffer.Companion::boxing, + ): GenericMatrixContext> = BufferMatrixContext(ring, bufferFactory) + + /** + * Automatic buffered matrix, unboxed if it is possible + */ + public inline fun > auto(ring: R): GenericMatrixContext> = + buffered(ring, Buffer.Companion::auto) + } +} + +public interface GenericMatrixContext, out M : Matrix> : MatrixContext { + /** + * The ring context for matrix elements + */ + public val elementContext: R + + public override infix fun Matrix.dot(other: Matrix): M { + //TODO add typed error + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + + return produce(rowNum, other.colNum) { i, j -> + val row = rows[i] + val column = other.columns[j] + elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) } + } + } + + public override infix fun Matrix.dot(vector: Point): Point { + //TODO add typed error + require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } + + return point(rowNum) { i -> + val row = rows[i] + elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) } + } + } + + public override operator fun Matrix.unaryMinus(): M = + produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } } + + public override fun add(a: Matrix, b: Matrix): M { + require(a.rowNum == b.rowNum && a.colNum == b.colNum) { + "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]" + } + + return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } } + } + + public override operator fun Matrix.minus(b: Matrix): M { + require(rowNum == b.rowNum && colNum == b.colNum) { + "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]" + } + + return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } } + } + + public override fun multiply(a: Matrix, k: Number): M = + produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } + + public override operator fun Matrix.times(value: T): M = + produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt new file mode 100644 index 000000000..e61feec6c --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt @@ -0,0 +1,158 @@ +package kscience.kmath.linear + +import kscience.kmath.structures.Matrix + +/** + * A marker interface representing some properties of matrices or additional transformations of them. Features are used + * to optimize matrix operations performance in some cases or retrieve the APIs. + */ +public interface MatrixFeature + +/** + * Matrices with this feature are considered to have only diagonal non-null elements. + */ +public interface DiagonalFeature : MatrixFeature{ + public companion object: DiagonalFeature +} + +/** + * Matrices with this feature have all zero elements. + */ +public object ZeroFeature : DiagonalFeature + +/** + * Matrices with this feature have unit elements on diagonal and zero elements in all other places. + */ +public object UnitFeature : DiagonalFeature + +/** + * Matrices with this feature can be inverted: [inverse] = `a`-1 where `a` is the owning matrix. + * + * @param T the type of matrices' items. + */ +public interface InverseMatrixFeature : MatrixFeature { + /** + * The inverse matrix of the matrix that owns this feature. + */ + public val inverse: Matrix +} + +/** + * Matrices with this feature can compute their determinant. + */ +public interface DeterminantFeature : MatrixFeature { + /** + * The determinant of the matrix that owns this feature. + */ + public val determinant: T +} + +/** + * Produces a [DeterminantFeature] where the [DeterminantFeature.determinant] is [determinant]. + * + * @param determinant the value of determinant. + * @return a new [DeterminantFeature]. + */ +@Suppress("FunctionName") +public fun DeterminantFeature(determinant: T): DeterminantFeature = object : DeterminantFeature { + override val determinant: T = determinant +} + +/** + * Matrices with this feature are lower triangular ones. + */ +public object LFeature : MatrixFeature + +/** + * Matrices with this feature are upper triangular ones. + */ +public object UFeature : MatrixFeature + +/** + * Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where + * *a* is the owning matrix. + * + * @param T the type of matrices' items. + */ +public interface LupDecompositionFeature : MatrixFeature { + /** + * The lower triangular matrix in this decomposition. It may have [LFeature]. + */ + public val l: Matrix + + /** + * The upper triangular matrix in this decomposition. It may have [UFeature]. + */ + public val u: Matrix + + /** + * The permutation matrix in this decomposition. + */ + public val p: Matrix +} + +/** + * Matrices with this feature are orthogonal ones: *a · aT = u* where *a* is the owning matrix, *u* + * is the unit matrix ([UnitFeature]). + */ +public object OrthogonalFeature : MatrixFeature + +/** + * Matrices with this feature support QR factorization: *a = [q] · [r]* where *a* is the owning matrix. + * + * @param T the type of matrices' items. + */ +public interface QRDecompositionFeature : MatrixFeature { + /** + * The orthogonal matrix in this decomposition. It may have [OrthogonalFeature]. + */ + public val q: Matrix + + /** + * The upper triangular matrix in this decomposition. It may have [UFeature]. + */ + public val r: Matrix +} + +/** + * Matrices with this feature support Cholesky factorization: *a = [l] · [l]H* where *a* is the + * owning matrix. + * + * @param T the type of matrices' items. + */ +public interface CholeskyDecompositionFeature : MatrixFeature { + /** + * The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature]. + */ + public val l: Matrix +} + +/** + * Matrices with this feature support SVD: *a = [u] · [s] · [v]H* where *a* is the owning + * matrix. + * + * @param T the type of matrices' items. + */ +public interface SingularValueDecompositionFeature : MatrixFeature { + /** + * The matrix in this decomposition. It is unitary, and it consists from left singular vectors. + */ + public val u: Matrix + + /** + * The matrix in this decomposition. Its main diagonal elements are singular values. + */ + public val s: Matrix + + /** + * The matrix in this decomposition. It is unitary, and it consists from right singular vectors. + */ + public val v: Matrix + + /** + * The buffer of singular values of this SVD. + */ + public val singularValues: Point +} + +//TODO add sparse matrix feature diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt new file mode 100644 index 000000000..362db1fe7 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt @@ -0,0 +1,105 @@ +package kscience.kmath.linear + +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.Ring +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.Structure2D +import kscience.kmath.structures.asBuffer +import kscience.kmath.structures.getFeature +import kotlin.math.sqrt +import kotlin.reflect.KClass +import kotlin.reflect.safeCast + +/** + * A [Matrix] that holds [MatrixFeature] objects. + * + * @param T the type of items. + */ +public class MatrixWrapper internal constructor( + public val origin: Matrix, + public val features: Set, +) : Matrix by origin { + + /** + * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria + */ + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = type.safeCast(features.find { type.isInstance(it) }) + ?: origin.getFeature(type) + + override fun equals(other: Any?): Boolean = origin == other + override fun hashCode(): Int = origin.hashCode() + override fun toString(): String { + return "MatrixWrapper(matrix=$origin, features=$features)" + } +} + +/** + * Return the original matrix. If this is a wrapper, return its origin. If not, this matrix. + * Origin does not necessary store all features. + */ +@UnstableKMathAPI +public val Matrix.origin: Matrix get() = (this as? MatrixWrapper)?.origin ?: this + +/** + * Add a single feature to a [Matrix] + */ +public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { + MatrixWrapper(origin, features + newFeature) +} else { + MatrixWrapper(this, setOf(newFeature)) +} + +/** + * Add a collection of features to a [Matrix] + */ +public operator fun Matrix.plus(newFeatures: Collection): MatrixWrapper = + if (this is MatrixWrapper) { + MatrixWrapper(origin, features + newFeatures) + } else { + MatrixWrapper(this, newFeatures.toSet()) + } + +public inline fun Structure2D.Companion.real( + rows: Int, + columns: Int, + initializer: (Int, Int) -> Double, +): BufferMatrix = MatrixContext.real.produce(rows, columns, initializer) + +/** + * Build a square matrix from given elements. + */ +public fun Structure2D.Companion.square(vararg elements: T): Matrix { + val size: Int = sqrt(elements.size.toDouble()).toInt() + require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } + val buffer = elements.asBuffer() + return BufferMatrix(size, size, buffer) +} + +/** + * Diagonal matrix of ones. The matrix is virtual no actual matrix is created + */ +public fun > GenericMatrixContext.one(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { i, j -> + if (i == j) elementContext.one else elementContext.zero + } + UnitFeature + + +/** + * A virtual matrix of zeroes + */ +public fun > GenericMatrixContext.zero(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + ZeroFeature + +public class TransposedFeature(public val original: Matrix) : MatrixFeature + +/** + * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` + */ +@OptIn(UnstableKMathAPI::class) +public fun Matrix.transpose(): Matrix { + return getFeature>()?.original ?: VirtualMatrix( + colNum, + rowNum, + ) { i, j -> get(j, i) } + TransposedFeature(this) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt new file mode 100644 index 000000000..8e197672f --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt @@ -0,0 +1,68 @@ +package kscience.kmath.linear + +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.RealBuffer + +@Suppress("OVERRIDE_BY_INLINE") +public object RealMatrixContext : MatrixContext> { + + public override inline fun produce( + rows: Int, + columns: Int, + initializer: (i: Int, j: Int) -> Double, + ): BufferMatrix { + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + return BufferMatrix(rows, columns, buffer) + } + + private fun Matrix.wrap(): BufferMatrix = if (this is BufferMatrix) this else { + produce(rowNum, colNum) { i, j -> get(i, j) } + } + + public fun one(rows: Int, columns: Int): Matrix = VirtualMatrix(rows, columns) { i, j -> + if (i == j) 1.0 else 0.0 + } + DiagonalFeature + + public override infix fun Matrix.dot(other: Matrix): BufferMatrix { + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + return produce(rowNum, other.colNum) { i, j -> + var res = 0.0 + for (l in 0 until colNum) { + res += get(i, l) * other.get(l, j) + } + res + } + } + + public override infix fun Matrix.dot(vector: Point): Point { + require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } + return RealBuffer(rowNum) { i -> + var res = 0.0 + for (j in 0 until colNum) { + res += get(i, j) * vector[j] + } + res + } + } + + override fun add(a: Matrix, b: Matrix): BufferMatrix { + require(a.rowNum == b.rowNum) { "Row number mismatch in matrix addition. Left side: ${a.rowNum}, right side: ${b.rowNum}" } + require(a.colNum == b.colNum) { "Column number mismatch in matrix addition. Left side: ${a.colNum}, right side: ${b.colNum}" } + return produce(a.rowNum, a.colNum) { i, j -> + a[i, j] + b[i, j] + } + } + + override fun Matrix.times(value: Double): BufferMatrix = + produce(rowNum, colNum) { i, j -> get(i, j) * value } + + + override fun multiply(a: Matrix, k: Number): BufferMatrix = + produce(a.rowNum, a.colNum) { i, j -> a[i, j] * k.toDouble() } +} + + +/** + * Partially optimized real-valued matrix + */ +public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VectorSpace.kt similarity index 60% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VectorSpace.kt index 82e5c7ef6..2a3b8f5d1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VectorSpace.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VectorSpace.kt @@ -1,21 +1,21 @@ -package scientifik.kmath.linear +package kscience.kmath.linear -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.BufferFactory +import kscience.kmath.operations.RealField +import kscience.kmath.operations.Space +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.BufferFactory /** * A linear space for vectors. * Could be used on any point-like structure */ -interface VectorSpace> : Space> { - val size: Int - val space: S +public interface VectorSpace> : Space> { + public val size: Int + public val space: S override val zero: Point get() = produce { space.zero } - fun produce(initializer: (Int) -> T): Point + public fun produce(initializer: S.(Int) -> T): Point /** * Produce a space-element of this vector space for expressions @@ -28,13 +28,13 @@ interface VectorSpace> : Space> { //TODO add basis - companion object { + public companion object { private val realSpaceCache: MutableMap> = hashMapOf() /** * Non-boxing double vector space */ - fun real(size: Int): BufferVectorSpace = realSpaceCache.getOrPut(size) { + public fun real(size: Int): BufferVectorSpace = realSpaceCache.getOrPut(size) { BufferVectorSpace( size, RealField, @@ -45,26 +45,26 @@ interface VectorSpace> : Space> { /** * A structured vector space with custom buffer */ - fun > buffered( + public fun > buffered( size: Int, space: S, - bufferFactory: BufferFactory = Buffer.Companion::boxing + bufferFactory: BufferFactory = Buffer.Companion::boxing, ): BufferVectorSpace = BufferVectorSpace(size, space, bufferFactory) /** * Automatic buffered vector, unboxed if it is possible */ - inline fun > auto(size: Int, space: S): VectorSpace = + public inline fun > auto(size: Int, space: S): VectorSpace = buffered(size, space, Buffer.Companion::auto) } } -class BufferVectorSpace>( +public class BufferVectorSpace>( override val size: Int, override val space: S, - val bufferFactory: BufferFactory + public val bufferFactory: BufferFactory, ) : VectorSpace { - override fun produce(initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) + override fun produce(initializer: S.(Int) -> T): Buffer = bufferFactory(size) { space.initializer(it) } //override fun produceElement(initializer: (Int) -> T): Vector = BufferVector(this, produce(initializer)) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt new file mode 100644 index 000000000..0269a64d1 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt @@ -0,0 +1,33 @@ +package kscience.kmath.linear + +import kscience.kmath.structures.Matrix + +public class VirtualMatrix( + override val rowNum: Int, + override val colNum: Int, + public val generator: (i: Int, j: Int) -> T +) : Matrix { + + override val shape: IntArray get() = intArrayOf(rowNum, colNum) + + override operator fun get(i: Int, j: Int): T = generator(i, j) + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is Matrix<*>) return false + + if (rowNum != other.rowNum) return false + if (colNum != other.colNum) return false + + return elements().all { (index, value) -> value == other[index] } + } + + override fun hashCode(): Int { + var result = rowNum + result = 31 * result + colNum + result = 31 * result + generator.hashCode() + return result + } + + +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/annotations.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/annotations.kt new file mode 100644 index 000000000..d70ac7b39 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/annotations.kt @@ -0,0 +1,4 @@ +package kscience.kmath.misc + +@RequiresOptIn("This API is unstable and could change in future", RequiresOptIn.Level.WARNING) +public annotation class UnstableKMathAPI \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/cumulative.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/cumulative.kt new file mode 100644 index 000000000..72d2f2388 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/misc/cumulative.kt @@ -0,0 +1,74 @@ +package kscience.kmath.misc + +import kscience.kmath.operations.Space +import kscience.kmath.operations.invoke +import kotlin.jvm.JvmName + +/** + * Generic cumulative operation on iterator. + * + * @param T the type of initial iterable. + * @param R the type of resulting iterable. + * @param initial lazy evaluated. + */ +public inline fun Iterator.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator = + object : Iterator { + var state: R = initial + + override fun hasNext(): Boolean = this@cumulative.hasNext() + + override fun next(): R { + state = operation(state, this@cumulative.next()) + return state + } + } + +public inline fun Iterable.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable = + Iterable { this@cumulative.iterator().cumulative(initial, operation) } + +public inline fun Sequence.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence = + Sequence { this@cumulative.iterator().cumulative(initial, operation) } + +public fun List.cumulative(initial: R, operation: (R, T) -> R): List = + iterator().cumulative(initial, operation).asSequence().toList() + +//Cumulative sum + +/** + * Cumulative sum with custom space + */ +public fun Iterable.cumulativeSum(space: Space): Iterable = + space { cumulative(zero) { element: T, sum: T -> sum + element } } + +@JvmName("cumulativeSumOfDouble") +public fun Iterable.cumulativeSum(): Iterable = cumulative(0.0) { element, sum -> sum + element } + +@JvmName("cumulativeSumOfInt") +public fun Iterable.cumulativeSum(): Iterable = cumulative(0) { element, sum -> sum + element } + +@JvmName("cumulativeSumOfLong") +public fun Iterable.cumulativeSum(): Iterable = cumulative(0L) { element, sum -> sum + element } + +public fun Sequence.cumulativeSum(space: Space): Sequence = + space { cumulative(zero) { element: T, sum: T -> sum + element } } + +@JvmName("cumulativeSumOfDouble") +public fun Sequence.cumulativeSum(): Sequence = cumulative(0.0) { element, sum -> sum + element } + +@JvmName("cumulativeSumOfInt") +public fun Sequence.cumulativeSum(): Sequence = cumulative(0) { element, sum -> sum + element } + +@JvmName("cumulativeSumOfLong") +public fun Sequence.cumulativeSum(): Sequence = cumulative(0L) { element, sum -> sum + element } + +public fun List.cumulativeSum(space: Space): List = + space { cumulative(zero) { element: T, sum: T -> sum + element } } + +@JvmName("cumulativeSumOfDouble") +public fun List.cumulativeSum(): List = cumulative(0.0) { element, sum -> sum + element } + +@JvmName("cumulativeSumOfInt") +public fun List.cumulativeSum(): List = cumulative(0) { element, sum -> sum + element } + +@JvmName("cumulativeSumOfLong") +public fun List.cumulativeSum(): List = cumulative(0L) { element, sum -> sum + element } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt new file mode 100644 index 000000000..e7eb2770d --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -0,0 +1,325 @@ +package kscience.kmath.operations + +/** + * Stub for DSL the [Algebra] is. + */ +@DslMarker +public annotation class KMathContext + +/** + * Represents an algebraic structure. + * + * @param T the type of element of this structure. + */ +public interface Algebra { + /** + * Wraps a raw string to [T] object. This method is designed for three purposes: + * + * 1. Mathematical constants (`e`, `pi`). + * 2. Variables for expression-like contexts (`a`, `b`, `c`...). + * 3. Literals (`{1, 2}`, (`(3; 4)`)). + * + * In case if algebra can't parse the string, this method must throw [kotlin.IllegalStateException]. + * + * @param value the raw string. + * @return an object. + */ + public fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") + + /** + * Dynamically dispatches an unary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second `unaryOperation` overload: + * i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun unaryOperationFunction(operation: String): (arg: T) -> T = + error("Unary operation $operation not defined in $this") + + /** + * Dynamically invokes an unary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [unaryOperationFunction] overload: + * i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. + * + * @param operation the name of operation. + * @param arg the argument of operation. + * @return a result of operation. + */ + public fun unaryOperation(operation: String, arg: T): T = unaryOperationFunction(operation)(arg) + + /** + * Dynamically dispatches a binary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [binaryOperationFunction] overload: + * i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = + error("Binary operation $operation not defined in $this") + + /** + * Dynamically invokes a binary operation with the certain name. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [binaryOperationFunction] overload: + * i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. + */ + public fun binaryOperation(operation: String, left: T, right: T): T = binaryOperationFunction(operation)(left, right) +} + +/** + * Call a block with an [Algebra] as receiver. + */ +// TODO add contract when KT-32313 is fixed +public inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) + +/** + * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as + * multiplication by scalars. + * + * @param T the type of element of this semispace. + */ +public interface SpaceOperations : Algebra { + /** + * Addition of two elements. + * + * @param a the addend. + * @param b the augend. + * @return the sum. + */ + public fun add(a: T, b: T): T + + /** + * Multiplication of element by scalar. + * + * @param a the multiplier. + * @param k the multiplicand. + * @return the produce. + */ + public fun multiply(a: T, k: Number): T + + // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176 + + /** + * The negation of this element. + * + * @receiver this value. + * @return the additive inverse of this value. + */ + public operator fun T.unaryMinus(): T = multiply(this, -1.0) + + /** + * Returns this value. + * + * @receiver this value. + * @return this value. + */ + public operator fun T.unaryPlus(): T = this + + /** + * Addition of two elements. + * + * @receiver the addend. + * @param b the augend. + * @return the sum. + */ + public operator fun T.plus(b: T): T = add(this, b) + + /** + * Subtraction of two elements. + * + * @receiver the minuend. + * @param b the subtrahend. + * @return the difference. + */ + public operator fun T.minus(b: T): T = add(this, -b) + + /** + * Multiplication of this element by a scalar. + * + * @receiver the multiplier. + * @param k the multiplicand. + * @return the product. + */ + public operator fun T.times(k: Number): T = multiply(this, k.toDouble()) + + /** + * Division of this element by scalar. + * + * @receiver the dividend. + * @param k the divisor. + * @return the quotient. + */ + public operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble()) + + /** + * Multiplication of this number by element. + * + * @receiver the multiplier. + * @param b the multiplicand. + * @return the product. + */ + public operator fun Number.times(b: T): T = b * this + + public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { + PLUS_OPERATION -> { arg -> arg } + MINUS_OPERATION -> { arg -> -arg } + else -> super.unaryOperationFunction(operation) + } + + public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + PLUS_OPERATION -> ::add + MINUS_OPERATION -> { left, right -> left - right } + else -> super.binaryOperationFunction(operation) + } + + public companion object { + /** + * The identifier of addition and unary positive operator. + */ + public const val PLUS_OPERATION: String = "+" + + /** + * The identifier of subtraction and unary negative operator. + */ + public const val MINUS_OPERATION: String = "-" + } +} + +/** + * Represents linear space, i.e. algebraic structure with associative binary operation called "addition" and its neutral + * element as well as multiplication by scalars. + * + * @param T the type of element of this group. + */ +public interface Space : SpaceOperations { + /** + * The neutral element of addition. + */ + public val zero: T +} + +/** + * Represents semiring, i.e. algebraic structure with two associative binary operations called "addition" and + * "multiplication". + * + * @param T the type of element of this semiring. + */ +public interface RingOperations : SpaceOperations { + /** + * Multiplies two elements. + * + * @param a the multiplier. + * @param b the multiplicand. + */ + public fun multiply(a: T, b: T): T + + /** + * Multiplies this element by scalar. + * + * @receiver the multiplier. + * @param b the multiplicand. + */ + public operator fun T.times(b: T): T = multiply(this, b) + + public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + TIMES_OPERATION -> ::multiply + else -> super.binaryOperationFunction(operation) + } + + public companion object { + /** + * The identifier of multiplication. + */ + public const val TIMES_OPERATION: String = "*" + } +} + +/** + * Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and + * "multiplication" and their neutral elements. + * + * @param T the type of element of this ring. + */ +public interface Ring : Space, RingOperations { + /** + * neutral operation for multiplication + */ + public val one: T +} + +/** + * Represents semifield, i.e. algebraic structure with three operations: associative "addition" and "multiplication", + * and "division". + * + * @param T the type of element of this semifield. + */ +public interface FieldOperations : RingOperations { + /** + * Division of two elements. + * + * @param a the dividend. + * @param b the divisor. + * @return the quotient. + */ + public fun divide(a: T, b: T): T + + /** + * Division of two elements. + * + * @receiver the dividend. + * @param b the divisor. + * @return the quotient. + */ + public operator fun T.div(b: T): T = divide(this, b) + + public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + DIV_OPERATION -> ::divide + else -> super.binaryOperationFunction(operation) + } + + public companion object { + /** + * The identifier of division. + */ + public const val DIV_OPERATION: String = "/" + } +} + +/** + * Represents field, i.e. algebraic structure with three operations: associative "addition" and "multiplication", + * and "division" and their neutral elements. + * + * @param T the type of element of this semifield. + */ +public interface Field : Ring, FieldOperations { + /** + * Division of element by scalar. + * + * @receiver the dividend. + * @param b the divisor. + * @return the quotient. + */ + public operator fun Number.div(b: T): T = this * divide(one, b) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/AlgebraElements.kt similarity index 59% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/operations/AlgebraElements.kt index 197897c14..aa572d894 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraElements.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/AlgebraElements.kt @@ -1,15 +1,15 @@ -package scientifik.kmath.operations +package kscience.kmath.operations /** * The generic mathematics elements which is able to store its context * * @param C the type of mathematical context for this element. */ -interface MathElement { +public interface MathElement { /** * The context this element belongs to. */ - val context: C + public val context: C } /** @@ -18,16 +18,16 @@ interface MathElement { * @param T the type wrapped by this wrapper. * @param I the type of this wrapper. */ -interface MathWrapper { +public interface MathWrapper { /** * Unwraps [I] to [T]. */ - fun unwrap(): T + public fun unwrap(): T /** * Wraps [T] to [I]. */ - fun T.wrap(): I + public fun T.wrap(): I } /** @@ -37,14 +37,14 @@ interface MathWrapper { * @param I self type of the element. Needed for static type checking. * @param S the type of space. */ -interface SpaceElement, S : Space> : MathElement, MathWrapper { +public interface SpaceElement, S : Space> : MathElement, MathWrapper { /** * Adds element to this one. * * @param b the augend. * @return the sum. */ - operator fun plus(b: T): I = context.add(unwrap(), b).wrap() + public operator fun plus(b: T): I = context.add(unwrap(), b).wrap() /** * Subtracts element from this one. @@ -52,7 +52,7 @@ interface SpaceElement, S : Space> : MathElement * @param b the subtrahend. * @return the difference. */ - operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap() + public operator fun minus(b: T): I = context.add(unwrap(), context.multiply(b, -1.0)).wrap() /** * Multiplies this element by number. @@ -60,7 +60,7 @@ interface SpaceElement, S : Space> : MathElement * @param k the multiplicand. * @return the product. */ - operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap() + public operator fun times(k: Number): I = context.multiply(unwrap(), k.toDouble()).wrap() /** * Divides this element by number. @@ -68,34 +68,34 @@ interface SpaceElement, S : Space> : MathElement * @param k the divisor. * @return the quotient. */ - operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap() + public operator fun div(k: Number): I = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap() } /** * The element of [Ring]. * - * @param T the type of space operation results. + * @param T the type of ring operation results. * @param I self type of the element. Needed for static type checking. - * @param R the type of space. + * @param R the type of ring. */ -interface RingElement, R : Ring> : SpaceElement { +public interface RingElement, R : Ring> : SpaceElement { /** * Multiplies this element by another one. * * @param b the multiplicand. * @return the product. */ - operator fun times(b: T): I = context.multiply(unwrap(), b).wrap() + public operator fun times(b: T): I = context.multiply(unwrap(), b).wrap() } /** * The element of [Field]. * - * @param T the type of space operation results. + * @param T the type of field operation results. * @param I self type of the element. Needed for static type checking. * @param F the type of field. */ -interface FieldElement, F : Field> : RingElement { +public interface FieldElement, F : Field> : RingElement { override val context: F /** @@ -104,5 +104,5 @@ interface FieldElement, F : Field> : RingElement * @param b the divisor. * @return the quotient. */ - operator fun div(b: T): I = context.divide(unwrap(), b).wrap() + public operator fun div(b: T): I = context.divide(unwrap(), b).wrap() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraExtensions.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/AlgebraExtensions.kt similarity index 69% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraExtensions.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/operations/AlgebraExtensions.kt index 00b16dc98..4527a2a42 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AlgebraExtensions.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/AlgebraExtensions.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.operations +package kscience.kmath.operations /** * Returns the sum of all elements in the iterable in this [Space]. @@ -7,7 +7,7 @@ package scientifik.kmath.operations * @param data the iterable to sum up. * @return the sum. */ -fun Space.sum(data: Iterable): T = data.fold(zero) { left, right -> add(left, right) } +public fun Space.sum(data: Iterable): T = data.fold(zero) { left, right -> add(left, right) } /** * Returns the sum of all elements in the sequence in this [Space]. @@ -16,7 +16,7 @@ fun Space.sum(data: Iterable): T = data.fold(zero) { left, right -> ad * @param data the sequence to sum up. * @return the sum. */ -fun Space.sum(data: Sequence): T = data.fold(zero) { left, right -> add(left, right) } +public fun Space.sum(data: Sequence): T = data.fold(zero) { left, right -> add(left, right) } /** * Returns an average value of elements in the iterable in this [Space]. @@ -24,8 +24,9 @@ fun Space.sum(data: Sequence): T = data.fold(zero) { left, right -> ad * @receiver the algebra that provides addition and division. * @param data the iterable to find average. * @return the average value. + * @author Iaroslav Postovalov */ -fun Space.average(data: Iterable): T = sum(data) / data.count() +public fun Space.average(data: Iterable): T = sum(data) / data.count() /** * Returns an average value of elements in the sequence in this [Space]. @@ -33,8 +34,14 @@ fun Space.average(data: Iterable): T = sum(data) / data.count() * @receiver the algebra that provides addition and division. * @param data the sequence to find average. * @return the average value. + * @author Iaroslav Postovalov */ -fun Space.average(data: Sequence): T = sum(data) / data.count() +public fun Space.average(data: Sequence): T = sum(data) / data.count() + +/** + * Absolute of the comparable [value] + */ +public fun > Space.abs(value: T): T = if (value > zero) value else -value /** * Returns the sum of all elements in the iterable in provided space. @@ -43,7 +50,7 @@ fun Space.average(data: Sequence): T = sum(data) / data.count() * @param space the algebra that provides addition. * @return the sum. */ -fun Iterable.sumWith(space: Space): T = space.sum(this) +public fun Iterable.sumWith(space: Space): T = space.sum(this) /** * Returns the sum of all elements in the sequence in provided space. @@ -52,7 +59,7 @@ fun Iterable.sumWith(space: Space): T = space.sum(this) * @param space the algebra that provides addition. * @return the sum. */ -fun Sequence.sumWith(space: Space): T = space.sum(this) +public fun Sequence.sumWith(space: Space): T = space.sum(this) /** * Returns an average value of elements in the iterable in this [Space]. @@ -60,8 +67,9 @@ fun Sequence.sumWith(space: Space): T = space.sum(this) * @receiver the iterable to find average. * @param space the algebra that provides addition and division. * @return the average value. + * @author Iaroslav Postovalov */ -fun Iterable.averageWith(space: Space): T = space.average(this) +public fun Iterable.averageWith(space: Space): T = space.average(this) /** * Returns an average value of elements in the sequence in this [Space]. @@ -69,8 +77,9 @@ fun Iterable.averageWith(space: Space): T = space.average(this) * @receiver the sequence to find average. * @param space the algebra that provides addition and division. * @return the average value. + * @author Iaroslav Postovalov */ -fun Sequence.averageWith(space: Space): T = space.average(this) +public fun Sequence.averageWith(space: Space): T = space.average(this) //TODO optimized power operation @@ -82,7 +91,7 @@ fun Sequence.averageWith(space: Space): T = space.average(this) * @param power the exponent. * @return the base raised to the power. */ -fun Ring.power(arg: T, power: Int): T { +public fun Ring.power(arg: T, power: Int): T { require(power >= 0) { "The power can't be negative." } require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } if (power == 0) return one @@ -98,8 +107,9 @@ fun Ring.power(arg: T, power: Int): T { * @param arg the base. * @param power the exponent. * @return the base raised to the power. + * @author Iaroslav Postovalov */ -fun Field.power(arg: T, power: Int): T { +public fun Field.power(arg: T, power: Int): T { require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } if (power == 0) return one if (power < 0) return one / (this as Ring).power(arg, -power) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt similarity index 57% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt index 0eed7132e..0be72e80c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt @@ -1,24 +1,24 @@ -package scientifik.kmath.operations +package kscience.kmath.operations -import scientifik.kmath.operations.BigInt.Companion.BASE -import scientifik.kmath.operations.BigInt.Companion.BASE_SIZE -import scientifik.kmath.structures.* -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.BigInt.Companion.BASE +import kscience.kmath.operations.BigInt.Companion.BASE_SIZE +import kscience.kmath.structures.* import kotlin.math.log2 import kotlin.math.max import kotlin.math.min import kotlin.math.sign -typealias Magnitude = UIntArray -typealias TBase = ULong +public typealias Magnitude = UIntArray +public typealias TBase = ULong /** * Kotlin Multiplatform implementation of Big Integer numbers (KBigInteger). * * @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai) */ -object BigIntField : Field { +@OptIn(UnstableKMathAPI::class) +public object BigIntField : Field, RingWithNumbers { override val zero: BigInt = BigInt.ZERO override val one: BigInt = BigInt.ONE @@ -29,113 +29,92 @@ object BigIntField : Field { override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) - operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer") + public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer") - operator fun String.unaryMinus(): BigInt = + public operator fun String.unaryMinus(): BigInt = -(this.parseBigInteger() ?: error("Can't parse $this as big integer")) override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b) } -class BigInt internal constructor( +public class BigInt internal constructor( private val sign: Byte, private val magnitude: Magnitude ) : Comparable { + public override fun compareTo(other: BigInt): Int = when { + (sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0 + sign < other.sign -> -1 + sign > other.sign -> 1 + else -> sign * compareMagnitudes(magnitude, other.magnitude) + } - override fun compareTo(other: BigInt): Int { - return when { - (this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0 - this.sign < other.sign -> -1 - this.sign > other.sign -> 1 - else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude) + public override fun equals(other: Any?): Boolean = + if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type") + + public override fun hashCode(): Int = magnitude.hashCode() + sign + + public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude) + + public operator fun unaryMinus(): BigInt = + if (this.sign == 0.toByte()) this else BigInt((-this.sign).toByte(), this.magnitude) + + public operator fun plus(b: BigInt): BigInt = when { + b.sign == 0.toByte() -> this + sign == 0.toByte() -> b + this == -b -> ZERO + sign == b.sign -> BigInt(sign, addMagnitudes(magnitude, b.magnitude)) + + else -> { + val comp = compareMagnitudes(magnitude, b.magnitude) + + if (comp == 1) + BigInt(sign, subtractMagnitudes(magnitude, b.magnitude)) + else + BigInt((-sign).toByte(), subtractMagnitudes(b.magnitude, magnitude)) } } - override fun equals(other: Any?): Boolean { - if (other is BigInt) { - return this.compareTo(other) == 0 - } else error("Can't compare KBigInteger to a different type") - } + public operator fun minus(b: BigInt): BigInt = this + (-b) - override fun hashCode(): Int { - return magnitude.hashCode() + this.sign - } - - fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude) - - operator fun unaryMinus(): BigInt { - return if (this.sign == 0.toByte()) this else BigInt((-this.sign).toByte(), this.magnitude) - } - - operator fun plus(b: BigInt): BigInt { - return when { - b.sign == 0.toByte() -> this - this.sign == 0.toByte() -> b - this == -b -> ZERO - this.sign == b.sign -> BigInt(this.sign, addMagnitudes(this.magnitude, b.magnitude)) - else -> { - val comp: Int = compareMagnitudes(this.magnitude, b.magnitude) - - if (comp == 1) { - BigInt(this.sign, subtractMagnitudes(this.magnitude, b.magnitude)) - } else { - BigInt((-this.sign).toByte(), subtractMagnitudes(b.magnitude, this.magnitude)) - } - } - } - } - - operator fun minus(b: BigInt): BigInt { - return this + (-b) - } - - operator fun times(b: BigInt): BigInt { - return when { - this.sign == 0.toByte() -> ZERO - b.sign == 0.toByte() -> ZERO + public operator fun times(b: BigInt): BigInt = when { + this.sign == 0.toByte() -> ZERO + b.sign == 0.toByte() -> ZERO // TODO: Karatsuba - else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude)) - } + else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude)) } - operator fun times(other: UInt): BigInt { - return when { - this.sign == 0.toByte() -> ZERO - other == 0U -> ZERO - else -> BigInt(this.sign, multiplyMagnitudeByUInt(this.magnitude, other)) - } + public operator fun times(other: UInt): BigInt = when { + sign == 0.toByte() -> ZERO + other == 0U -> ZERO + else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other)) } - operator fun times(other: Int): BigInt { - return if (other > 0) - this * kotlin.math.abs(other).toUInt() - else - -this * kotlin.math.abs(other).toUInt() - } + public operator fun times(other: Int): BigInt = if (other > 0) + this * kotlin.math.abs(other).toUInt() + else + -this * kotlin.math.abs(other).toUInt() - operator fun div(other: UInt): BigInt { - return BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other)) - } + public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other)) - operator fun div(other: Int): BigInt { - return BigInt( - (this.sign * other.sign).toByte(), - divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt()) - ) - } + public operator fun div(other: Int): BigInt = BigInt( + (this.sign * other.sign).toByte(), + divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt()) + ) private fun division(other: BigInt): Pair { // Long division algorithm: // https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder // TODO: Implement more effective algorithm - var q: BigInt = ZERO - var r: BigInt = ZERO + var q = ZERO + var r = ZERO val bitSize = (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.lastOrNull()?.toFloat() ?: 0f + 1)).toInt() + for (i in bitSize downTo 0) { r = r shl 1 r = r or ((abs(this) shr i) and ONE) + if (r >= abs(other)) { r -= abs(other) q += (ONE shl i) @@ -145,99 +124,84 @@ class BigInt internal constructor( return Pair(BigInt((this.sign * other.sign).toByte(), q.magnitude), r) } - operator fun div(other: BigInt): BigInt { - return this.division(other).first - } + public operator fun div(other: BigInt): BigInt = division(other).first - infix fun shl(i: Int): BigInt { + public infix fun shl(i: Int): BigInt { if (this == ZERO) return ZERO if (i == 0) return this - val fullShifts = i / BASE_SIZE + 1 val relShift = i % BASE_SIZE val shiftLeft = { x: UInt -> if (relShift >= 32) 0U else x shl relShift } val shiftRight = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift) } + val newMagnitude = Magnitude(magnitude.size + fullShifts) - val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts) - - for (j in this.magnitude.indices) { + for (j in magnitude.indices) { newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j]) - if (j != 0) { + + if (j != 0) newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1]) - } } - newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last()) - + newMagnitude[magnitude.size + fullShifts - 1] = shiftRight(magnitude.last()) return BigInt(this.sign, stripLeadingZeros(newMagnitude)) } - infix fun shr(i: Int): BigInt { + public infix fun shr(i: Int): BigInt { if (this == ZERO) return ZERO if (i == 0) return this - val fullShifts = i / BASE_SIZE val relShift = i % BASE_SIZE val shiftRight = { x: UInt -> if (relShift >= 32) 0U else x shr relShift } val shiftLeft = { x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift) } - if (this.magnitude.size - fullShifts <= 0) { - return ZERO - } - val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts) + if (this.magnitude.size - fullShifts <= 0) return ZERO + val newMagnitude: Magnitude = Magnitude(magnitude.size - fullShifts) - for (j in fullShifts until this.magnitude.size) { - newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j]) - if (j != this.magnitude.size - 1) { - newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1]) - } + for (j in fullShifts until magnitude.size) { + newMagnitude[j - fullShifts] = shiftRight(magnitude[j]) + + if (j != magnitude.size - 1) + newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(magnitude[j + 1]) } return BigInt(this.sign, stripLeadingZeros(newMagnitude)) } - infix fun or(other: BigInt): BigInt { + public infix fun or(other: BigInt): BigInt { if (this == ZERO) return other if (other == ZERO) return this - val resSize = max(this.magnitude.size, other.magnitude.size) + val resSize = max(magnitude.size, other.magnitude.size) val newMagnitude: Magnitude = Magnitude(resSize) + for (i in 0 until resSize) { - if (i < this.magnitude.size) { - newMagnitude[i] = newMagnitude[i] or this.magnitude[i] - } - if (i < other.magnitude.size) { - newMagnitude[i] = newMagnitude[i] or other.magnitude[i] - } + if (i < magnitude.size) newMagnitude[i] = newMagnitude[i] or magnitude[i] + if (i < other.magnitude.size) newMagnitude[i] = newMagnitude[i] or other.magnitude[i] } + return BigInt(1, stripLeadingZeros(newMagnitude)) } - infix fun and(other: BigInt): BigInt { + public infix fun and(other: BigInt): BigInt { if ((this == ZERO) or (other == ZERO)) return ZERO val resSize = min(this.magnitude.size, other.magnitude.size) val newMagnitude: Magnitude = Magnitude(resSize) - for (i in 0 until resSize) { - newMagnitude[i] = this.magnitude[i] and other.magnitude[i] - } + for (i in 0 until resSize) newMagnitude[i] = this.magnitude[i] and other.magnitude[i] return BigInt(1, stripLeadingZeros(newMagnitude)) } - operator fun rem(other: Int): Int { + public operator fun rem(other: Int): Int { val res = this - (this / other) * other return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt() } - operator fun rem(other: BigInt): BigInt { - return this - (this / other) * other - } + public operator fun rem(other: BigInt): BigInt = this - (this / other) * other - fun modPow(exponent: BigInt, m: BigInt): BigInt { - return when { - exponent == ZERO -> ONE - exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m - else -> { - val sqRoot = modPow(exponent / 2, m) - (sqRoot * sqRoot) % m - } + public fun modPow(exponent: BigInt, m: BigInt): BigInt = when { + exponent == ZERO -> ONE + exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m + + else -> { + val sqRoot = modPow(exponent / 2, m) + (sqRoot * sqRoot) % m } } @@ -261,11 +225,11 @@ class BigInt internal constructor( return res } - companion object { - const val BASE: ULong = 0xffffffffUL - const val BASE_SIZE: Int = 32 - val ZERO: BigInt = BigInt(0, uintArrayOf()) - val ONE: BigInt = BigInt(1, uintArrayOf(1u)) + public companion object { + public const val BASE: ULong = 0xffffffffUL + public const val BASE_SIZE: Int = 32 + public val ZERO: BigInt = BigInt(0, uintArrayOf()) + public val ONE: BigInt = BigInt(1, uintArrayOf(1u)) private val hexMapping: HashMap = hashMapOf( 0U to "0", 1U to "1", 2U to "2", 3U to "3", @@ -292,9 +256,9 @@ class BigInt internal constructor( } private fun addMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { - val resultLength: Int = max(mag1.size, mag2.size) + 1 + val resultLength = max(mag1.size, mag2.size) + 1 val result = Magnitude(resultLength) - var carry: TBase = 0UL + var carry = 0uL for (i in 0 until resultLength - 1) { val res = when { @@ -302,20 +266,22 @@ class BigInt internal constructor( i >= mag2.size -> mag1[i].toULong() + carry else -> mag1[i].toULong() + mag2[i].toULong() + carry } + result[i] = (res and BASE).toUInt() carry = (res shr BASE_SIZE) } + result[resultLength - 1] = carry.toUInt() return stripLeadingZeros(result) } private fun subtractMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { - val resultLength: Int = mag1.size + val resultLength = mag1.size val result = Magnitude(resultLength) var carry = 0L for (i in 0 until resultLength) { - var res: Long = + var res = if (i < mag2.size) mag1[i].toLong() - mag2[i].toLong() - carry else mag1[i].toLong() - carry @@ -329,13 +295,13 @@ class BigInt internal constructor( } private fun multiplyMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { - val resultLength: Int = mag.size + 1 + val resultLength = mag.size + 1 val result = Magnitude(resultLength) - var carry: ULong = 0UL + var carry = 0uL for (i in mag.indices) { val cur: ULong = carry + mag[i].toULong() * x.toULong() - result[i] = (cur and BASE.toULong()).toUInt() + result[i] = (cur and BASE).toUInt() carry = cur shr BASE_SIZE } result[resultLength - 1] = (carry and BASE).toUInt() @@ -344,16 +310,18 @@ class BigInt internal constructor( } private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { - val resultLength: Int = mag1.size + mag2.size + val resultLength = mag1.size + mag2.size val result = Magnitude(resultLength) for (i in mag1.indices) { - var carry: ULong = 0UL + var carry = 0uL + for (j in mag2.indices) { val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry - result[i + j] = (cur and BASE.toULong()).toUInt() + result[i + j] = (cur and BASE).toUInt() carry = cur shr BASE_SIZE } + result[i + mag2.size] = (carry and BASE).toUInt() } @@ -361,48 +329,46 @@ class BigInt internal constructor( } private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { - val resultLength: Int = mag.size + val resultLength = mag.size val result = Magnitude(resultLength) - var carry: ULong = 0UL + var carry = 0uL for (i in mag.size - 1 downTo 0) { val cur: ULong = mag[i].toULong() + (carry shl BASE_SIZE) result[i] = (cur / x).toUInt() carry = cur % x } + return stripLeadingZeros(result) } - } - } - private fun stripLeadingZeros(mag: Magnitude): Magnitude { - if (mag.isEmpty() || mag.last() != 0U) { - return mag - } - var resSize: Int = mag.size - 1 + if (mag.isEmpty() || mag.last() != 0U) return mag + var resSize = mag.size - 1 + while (mag[resSize] == 0U) { - if (resSize == 0) - break + if (resSize == 0) break resSize -= 1 } + return mag.sliceArray(IntRange(0, resSize)) } -fun abs(x: BigInt): BigInt = x.abs() +public fun abs(x: BigInt): BigInt = x.abs() /** * Convert this [Int] to [BigInt] */ -fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt())) +public fun Int.toBigInt(): BigInt = BigInt(sign.toByte(), uintArrayOf(kotlin.math.abs(this).toUInt())) /** * Convert this [Long] to [BigInt] */ -fun Long.toBigInt(): BigInt = BigInt( - sign.toByte(), stripLeadingZeros( +public fun Long.toBigInt(): BigInt = BigInt( + sign.toByte(), + stripLeadingZeros( uintArrayOf( (kotlin.math.abs(this).toULong() and BASE).toUInt(), ((kotlin.math.abs(this).toULong() shr BASE_SIZE) and BASE).toUInt() @@ -413,12 +379,12 @@ fun Long.toBigInt(): BigInt = BigInt( /** * Convert UInt to [BigInt] */ -fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this)) +public fun UInt.toBigInt(): BigInt = BigInt(1, uintArrayOf(this)) /** * Convert ULong to [BigInt] */ -fun ULong.toBigInt(): BigInt = BigInt( +public fun ULong.toBigInt(): BigInt = BigInt( 1, stripLeadingZeros( uintArrayOf( @@ -431,12 +397,12 @@ fun ULong.toBigInt(): BigInt = BigInt( /** * Create a [BigInt] with this array of magnitudes with protective copy */ -fun UIntArray.toBigInt(sign: Byte): BigInt { +public fun UIntArray.toBigInt(sign: Byte): BigInt { require(sign != 0.toByte() || !isNotEmpty()) return BigInt(sign, copyOf()) } -val hexChToInt: MutableMap = hashMapOf( +private val hexChToInt: MutableMap = hashMapOf( '0' to 0, '1' to 1, '2' to 2, '3' to 3, '4' to 4, '5' to 5, '6' to 6, '7' to 7, '8' to 8, '9' to 9, 'A' to 10, 'B' to 11, @@ -446,9 +412,10 @@ val hexChToInt: MutableMap = hashMapOf( /** * Returns null if a valid number can not be read from a string */ -fun String.parseBigInteger(): BigInt? { +public fun String.parseBigInteger(): BigInt? { val sign: Int val sPositive: String + when { this[0] == '+' -> { sign = +1 @@ -463,43 +430,42 @@ fun String.parseBigInteger(): BigInt? { sign = +1 } } + var res = BigInt.ZERO var digitValue = BigInt.ONE val sPositiveUpper = sPositive.toUpperCase() + if (sPositiveUpper.startsWith("0X")) { // hex representation val sHex = sPositiveUpper.substring(2) + for (ch in sHex.reversed()) { if (ch == '_') continue res += digitValue * (hexChToInt[ch] ?: return null) digitValue *= 16.toBigInt() } - } else { // decimal representation - for (ch in sPositiveUpper.reversed()) { - if (ch == '_') continue - if (ch !in '0'..'9') { - return null - } - res += digitValue * (ch.toInt() - '0'.toInt()) - digitValue *= 10.toBigInt() + } else for (ch in sPositiveUpper.reversed()) { + // decimal representation + if (ch == '_') continue + if (ch !in '0'..'9') { + return null } + res += digitValue * (ch.toInt() - '0'.toInt()) + digitValue *= 10.toBigInt() } + return res * sign } -inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer { - contract { callsInPlace(initializer) } - return boxing(size, initializer) -} +public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer = + boxing(size, initializer) -inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer { - contract { callsInPlace(initializer) } - return boxing(size, initializer) -} +public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer = + boxing(size, initializer) -fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing = +public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing = BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt) -fun NDElement.Companion.bigInt( +public fun NDElement.Companion.bigInt( vararg shape: Int, initializer: BigIntField.(IntArray) -> BigInt ): BufferedNDRingElement = NDAlgebra.bigInt(*shape).produce(initializer) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt similarity index 64% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt index dcfd97d1a..c6409c015 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt @@ -1,25 +1,25 @@ -package scientifik.kmath.operations +package kscience.kmath.operations -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.MemoryBuffer -import scientifik.kmath.structures.MutableBuffer -import scientifik.memory.MemoryReader -import scientifik.memory.MemorySpec -import scientifik.memory.MemoryWriter -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract +import kscience.kmath.memory.MemoryReader +import kscience.kmath.memory.MemorySpec +import kscience.kmath.memory.MemoryWriter +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.MemoryBuffer +import kscience.kmath.structures.MutableBuffer +import kscience.kmath.structures.MutableMemoryBuffer import kotlin.math.* /** * This complex's conjugate. */ -val Complex.conjugate: Complex +public val Complex.conjugate: Complex get() = Complex(re, -im) /** * This complex's reciprocal. */ -val Complex.reciprocal: Complex +public val Complex.reciprocal: Complex get() { val scale = re * re + im * im return Complex(re / scale, -im / scale) @@ -28,13 +28,13 @@ val Complex.reciprocal: Complex /** * Absolute value of complex number. */ -val Complex.r: Double +public val Complex.r: Double get() = sqrt(re * re + im * im) /** * An angle between vector represented by complex number and X axis. */ -val Complex.theta: Double +public val Complex.theta: Double get() = atan(im / re) private val PI_DIV_2 = Complex(PI / 2, 0) @@ -42,14 +42,15 @@ private val PI_DIV_2 = Complex(PI / 2, 0) /** * A field of [Complex]. */ -object ComplexField : ExtendedField, Norm { +@OptIn(UnstableKMathAPI::class) +public object ComplexField : ExtendedField, Norm, RingWithNumbers { override val zero: Complex = 0.0.toComplex() override val one: Complex = 1.0.toComplex() /** * The imaginary unit. */ - val i: Complex = Complex(0.0, 1.0) + public val i: Complex = Complex(0.0, 1.0) override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) @@ -117,7 +118,7 @@ object ComplexField : ExtendedField, Norm { * @param c the augend. * @return the sum. */ - operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c) + public operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c) /** * Subtracts complex number from real one. @@ -126,7 +127,7 @@ object ComplexField : ExtendedField, Norm { * @param c the subtrahend. * @return the difference. */ - operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c) + public operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c) /** * Adds real number to complex one. @@ -135,7 +136,7 @@ object ComplexField : ExtendedField, Norm { * @param d the augend. * @return the sum. */ - operator fun Complex.plus(d: Double): Complex = d + this + public operator fun Complex.plus(d: Double): Complex = d + this /** * Subtracts real number from complex one. @@ -144,7 +145,7 @@ object ComplexField : ExtendedField, Norm { * @param d the subtrahend. * @return the difference. */ - operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex()) + public operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex()) /** * Multiplies real number by complex one. @@ -153,21 +154,22 @@ object ComplexField : ExtendedField, Norm { * @param c the multiplicand. * @receiver the product. */ - operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) + public operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) - override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) + override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) } /** - * Represents complex number. + * Represents `double`-based complex number. * * @property re The real part. * @property im The imaginary part. */ -data class Complex(val re: Double, val im: Double) : FieldElement, Comparable { - constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) +public data class Complex(val re: Double, val im: Double) : FieldElement, + Comparable { + public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) override val context: ComplexField get() = ComplexField @@ -177,11 +179,16 @@ data class Complex(val re: Double, val im: Double) : FieldElement { - override val objectSize: Int = 16 + override fun toString(): String { + return "($re + i*$im)" + } - override fun MemoryReader.read(offset: Int): Complex = - Complex(readDouble(offset), readDouble(offset + 8)) + + public companion object : MemorySpec { + override val objectSize: Int + get() = 16 + + override fun MemoryReader.read(offset: Int): Complex = Complex(readDouble(offset), readDouble(offset + 8)) override fun MemoryWriter.write(offset: Int, value: Complex) { writeDouble(offset, value.re) @@ -190,20 +197,25 @@ data class Complex(val re: Double, val im: Double) : FieldElement Complex): Buffer { - contract { callsInPlace(init) } - return MemoryBuffer.create(Complex, size, init) -} +/** + * Creates a new buffer of complex numbers with the specified [size], where each element is calculated by calling the + * specified [init] function. + */ +public inline fun Buffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer = + MemoryBuffer.create(Complex, size, init) -inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer { - contract { callsInPlace(init) } - return MemoryBuffer.create(Complex, size, init) -} +/** + * Creates a new buffer of complex numbers with the specified [size], where each element is calculated by calling the + * specified [init] function. + */ +public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): MutableBuffer = + MutableMemoryBuffer.create(Complex, size, init) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt new file mode 100644 index 000000000..26f93fae8 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt @@ -0,0 +1,125 @@ +package kscience.kmath.operations + +import kscience.kmath.misc.UnstableKMathAPI + +/** + * An algebraic structure where elements can have numeric representation. + * + * @param T the type of element of this structure. + */ +public interface NumericAlgebra : Algebra { + /** + * Wraps a number to [T] object. + * + * @param value the number to wrap. + * @return an object. + */ + public fun number(value: Number): T + + /** + * Dynamically dispatches a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [leftSideNumberOperation] overload: + * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun leftSideNumberOperationFunction(operation: String): (left: Number, right: T) -> T = + { l, r -> binaryOperationFunction(operation)(number(l), r) } + + /** + * Dynamically invokes a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [leftSideNumberOperation] overload: + * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. + */ + public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = + leftSideNumberOperationFunction(operation)(left, right) + + /** + * Dynamically dispatches a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: + * i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = + { l, r -> binaryOperationFunction(operation)(l, number(r)) } + + /** + * Dynamically invokes a binary operation with the certain name with numeric second argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: + * i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. + */ + public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = + rightSideNumberOperationFunction(operation)(left, right) +} + +/** + * A combination of [NumericAlgebra] and [Ring] that adds intrinsic simple operations on numbers like `T+1` + * TODO to be removed and replaced by extensions after multiple receivers are there + */ +@UnstableKMathAPI +public interface RingWithNumbers: Ring, NumericAlgebra{ + public override fun number(value: Number): T = one * value + + /** + * Addition of element and scalar. + * + * @receiver the addend. + * @param b the augend. + */ + public operator fun T.plus(b: Number): T = this + number(b) + + /** + * Addition of scalar and element. + * + * @receiver the addend. + * @param b the augend. + */ + public operator fun Number.plus(b: T): T = b + this + + /** + * Subtraction of element from number. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + public operator fun T.minus(b: Number): T = this - number(b) + + /** + * Subtraction of number from element. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + public operator fun Number.minus(b: T): T = -b + this +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/OptionalOperations.kt similarity index 52% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/operations/OptionalOperations.kt index 1dac649aa..f31d61ae1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/OptionalOperations.kt @@ -1,234 +1,234 @@ -package scientifik.kmath.operations +package kscience.kmath.operations /** * A container for trigonometric operations for specific type. * * @param T the type of element of this structure. */ -interface TrigonometricOperations : Algebra { +public interface TrigonometricOperations : Algebra { /** * Computes the sine of [arg]. */ - fun sin(arg: T): T + public fun sin(arg: T): T /** * Computes the cosine of [arg]. */ - fun cos(arg: T): T + public fun cos(arg: T): T /** * Computes the tangent of [arg]. */ - fun tan(arg: T): T + public fun tan(arg: T): T /** * Computes the inverse sine of [arg]. */ - fun asin(arg: T): T + public fun asin(arg: T): T /** * Computes the inverse cosine of [arg]. */ - fun acos(arg: T): T + public fun acos(arg: T): T /** * Computes the inverse tangent of [arg]. */ - fun atan(arg: T): T + public fun atan(arg: T): T - companion object { + public companion object { /** * The identifier of sine. */ - const val SIN_OPERATION: String = "sin" + public const val SIN_OPERATION: String = "sin" /** * The identifier of cosine. */ - const val COS_OPERATION: String = "cos" + public const val COS_OPERATION: String = "cos" /** * The identifier of tangent. */ - const val TAN_OPERATION: String = "tan" + public const val TAN_OPERATION: String = "tan" /** * The identifier of inverse sine. */ - const val ASIN_OPERATION: String = "asin" + public const val ASIN_OPERATION: String = "asin" /** * The identifier of inverse cosine. */ - const val ACOS_OPERATION: String = "acos" + public const val ACOS_OPERATION: String = "acos" /** * The identifier of inverse tangent. */ - const val ATAN_OPERATION: String = "atan" + public const val ATAN_OPERATION: String = "atan" } } /** * Computes the sine of [arg]. */ -fun >> sin(arg: T): T = arg.context.sin(arg) +public fun >> sin(arg: T): T = arg.context.sin(arg) /** * Computes the cosine of [arg]. */ -fun >> cos(arg: T): T = arg.context.cos(arg) +public fun >> cos(arg: T): T = arg.context.cos(arg) /** * Computes the tangent of [arg]. */ -fun >> tan(arg: T): T = arg.context.tan(arg) +public fun >> tan(arg: T): T = arg.context.tan(arg) /** * Computes the inverse sine of [arg]. */ -fun >> asin(arg: T): T = arg.context.asin(arg) +public fun >> asin(arg: T): T = arg.context.asin(arg) /** * Computes the inverse cosine of [arg]. */ -fun >> acos(arg: T): T = arg.context.acos(arg) +public fun >> acos(arg: T): T = arg.context.acos(arg) /** * Computes the inverse tangent of [arg]. */ -fun >> atan(arg: T): T = arg.context.atan(arg) +public fun >> atan(arg: T): T = arg.context.atan(arg) /** * A container for hyperbolic trigonometric operations for specific type. * * @param T the type of element of this structure. */ -interface HyperbolicOperations : Algebra { +public interface HyperbolicOperations : Algebra { /** * Computes the hyperbolic sine of [arg]. */ - fun sinh(arg: T): T + public fun sinh(arg: T): T /** * Computes the hyperbolic cosine of [arg]. */ - fun cosh(arg: T): T + public fun cosh(arg: T): T /** * Computes the hyperbolic tangent of [arg]. */ - fun tanh(arg: T): T + public fun tanh(arg: T): T /** * Computes the inverse hyperbolic sine of [arg]. */ - fun asinh(arg: T): T + public fun asinh(arg: T): T /** * Computes the inverse hyperbolic cosine of [arg]. */ - fun acosh(arg: T): T + public fun acosh(arg: T): T /** * Computes the inverse hyperbolic tangent of [arg]. */ - fun atanh(arg: T): T + public fun atanh(arg: T): T - companion object { + public companion object { /** * The identifier of hyperbolic sine. */ - const val SINH_OPERATION: String = "sinh" + public const val SINH_OPERATION: String = "sinh" /** * The identifier of hyperbolic cosine. */ - const val COSH_OPERATION: String = "cosh" + public const val COSH_OPERATION: String = "cosh" /** * The identifier of hyperbolic tangent. */ - const val TANH_OPERATION: String = "tanh" + public const val TANH_OPERATION: String = "tanh" /** * The identifier of inverse hyperbolic sine. */ - const val ASINH_OPERATION: String = "asinh" + public const val ASINH_OPERATION: String = "asinh" /** * The identifier of inverse hyperbolic cosine. */ - const val ACOSH_OPERATION: String = "acosh" + public const val ACOSH_OPERATION: String = "acosh" /** * The identifier of inverse hyperbolic tangent. */ - const val ATANH_OPERATION: String = "atanh" + public const val ATANH_OPERATION: String = "atanh" } } /** * Computes the hyperbolic sine of [arg]. */ -fun >> sinh(arg: T): T = arg.context.sinh(arg) +public fun >> sinh(arg: T): T = arg.context.sinh(arg) /** * Computes the hyperbolic cosine of [arg]. */ -fun >> cosh(arg: T): T = arg.context.cosh(arg) +public fun >> cosh(arg: T): T = arg.context.cosh(arg) /** * Computes the hyperbolic tangent of [arg]. */ -fun >> tanh(arg: T): T = arg.context.tanh(arg) +public fun >> tanh(arg: T): T = arg.context.tanh(arg) /** * Computes the inverse hyperbolic sine of [arg]. */ -fun >> asinh(arg: T): T = arg.context.asinh(arg) +public fun >> asinh(arg: T): T = arg.context.asinh(arg) /** * Computes the inverse hyperbolic cosine of [arg]. */ -fun >> acosh(arg: T): T = arg.context.acosh(arg) +public fun >> acosh(arg: T): T = arg.context.acosh(arg) /** * Computes the inverse hyperbolic tangent of [arg]. */ -fun >> atanh(arg: T): T = arg.context.atanh(arg) +public fun >> atanh(arg: T): T = arg.context.atanh(arg) /** * A context extension to include power operations based on exponentiation. * * @param T the type of element of this structure. */ -interface PowerOperations : Algebra { +public interface PowerOperations : Algebra { /** * Raises [arg] to the power [pow]. */ - fun power(arg: T, pow: Number): T + public fun power(arg: T, pow: Number): T /** * Computes the square root of the value [arg]. */ - fun sqrt(arg: T): T = power(arg, 0.5) + public fun sqrt(arg: T): T = power(arg, 0.5) /** * Raises this value to the power [pow]. */ - infix fun T.pow(pow: Number): T = power(this, pow) + public infix fun T.pow(pow: Number): T = power(this, pow) - companion object { + public companion object { /** * The identifier of exponentiation. */ - const val POW_OPERATION: String = "pow" + public const val POW_OPERATION: String = "pow" /** * The identifier of square root. */ - const val SQRT_OPERATION: String = "sqrt" + public const val SQRT_OPERATION: String = "sqrt" } } @@ -239,56 +239,56 @@ interface PowerOperations : Algebra { * @param power the exponent. * @return the base raised to the power. */ -infix fun >> T.pow(power: Double): T = context.power(this, power) +public infix fun >> T.pow(power: Double): T = context.power(this, power) /** * Computes the square root of the value [arg]. */ -fun >> sqrt(arg: T): T = arg pow 0.5 +public fun >> sqrt(arg: T): T = arg pow 0.5 /** * Computes the square of the value [arg]. */ -fun >> sqr(arg: T): T = arg pow 2.0 +public fun >> sqr(arg: T): T = arg pow 2.0 /** * A container for operations related to `exp` and `ln` functions. * * @param T the type of element of this structure. */ -interface ExponentialOperations : Algebra { +public interface ExponentialOperations : Algebra { /** * Computes Euler's number `e` raised to the power of the value [arg]. */ - fun exp(arg: T): T + public fun exp(arg: T): T /** * Computes the natural logarithm (base `e`) of the value [arg]. */ - fun ln(arg: T): T + public fun ln(arg: T): T - companion object { + public companion object { /** * The identifier of exponential function. */ - const val EXP_OPERATION: String = "exp" + public const val EXP_OPERATION: String = "exp" /** * The identifier of natural logarithm. */ - const val LN_OPERATION: String = "ln" + public const val LN_OPERATION: String = "ln" } } /** * The identifier of exponential function. */ -fun >> exp(arg: T): T = arg.context.exp(arg) +public fun >> exp(arg: T): T = arg.context.exp(arg) /** * The identifier of natural logarithm. */ -fun >> ln(arg: T): T = arg.context.ln(arg) +public fun >> ln(arg: T): T = arg.context.ln(arg) /** * A container for norm functional on element. @@ -296,14 +296,14 @@ fun >> ln(arg: T): T = arg.context. * @param T the type of element having norm defined. * @param R the type of norm. */ -interface Norm { +public interface Norm { /** * Computes the norm of [arg] (i.e. absolute value or vector length). */ - fun norm(arg: T): R + public fun norm(arg: T): R } /** * Computes the norm of [arg] (i.e. absolute value or vector length). */ -fun >, R> norm(arg: T): R = arg.context.norm(arg) +public fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt new file mode 100644 index 000000000..0440d74e8 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt @@ -0,0 +1,280 @@ +package kscience.kmath.operations + +import kotlin.math.pow as kpow + +/** + * Advanced Number-like semifield that implements basic operations. + */ +public interface ExtendedFieldOperations : + FieldOperations, + TrigonometricOperations, + HyperbolicOperations, + PowerOperations, + ExponentialOperations { + public override fun tan(arg: T): T = sin(arg) / cos(arg) + public override fun tanh(arg: T): T = sinh(arg) / cosh(arg) + + public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { + TrigonometricOperations.COS_OPERATION -> ::cos + TrigonometricOperations.SIN_OPERATION -> ::sin + TrigonometricOperations.TAN_OPERATION -> ::tan + TrigonometricOperations.ACOS_OPERATION -> ::acos + TrigonometricOperations.ASIN_OPERATION -> ::asin + TrigonometricOperations.ATAN_OPERATION -> ::atan + HyperbolicOperations.COSH_OPERATION -> ::cosh + HyperbolicOperations.SINH_OPERATION -> ::sinh + HyperbolicOperations.TANH_OPERATION -> ::tanh + HyperbolicOperations.ACOSH_OPERATION -> ::acosh + HyperbolicOperations.ASINH_OPERATION -> ::asinh + HyperbolicOperations.ATANH_OPERATION -> ::atanh + PowerOperations.SQRT_OPERATION -> ::sqrt + ExponentialOperations.EXP_OPERATION -> ::exp + ExponentialOperations.LN_OPERATION -> ::ln + else -> super.unaryOperationFunction(operation) + } +} + +/** + * Advanced Number-like field that implements basic operations. + */ +public interface ExtendedField : ExtendedFieldOperations, Field, NumericAlgebra { + public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2 + public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2 + public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) + public override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg) + public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) + public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 + + public override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.rightSideNumberOperationFunction(operation) + } +} + +/** + * Real field element wrapping double. + * + * @property value the [Double] value wrapped by this [Real]. + * + * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 + */ +public inline class Real(public val value: Double) : FieldElement { + public override val context: RealField + get() = RealField + + public override fun unwrap(): Double = value + public override fun Double.wrap(): Real = Real(value) + + public companion object +} + +/** + * A field for [Double] without boxing. Does not produce appropriate field element. + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public object RealField : ExtendedField, Norm { + public override val zero: Double + get() = 0.0 + + public override val one: Double + get() = 1.0 + + override fun number(value: Number): Double = value.toDouble() + + public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperationFunction(operation) + } + + public override inline fun add(a: Double, b: Double): Double = a + b + public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() + + public override inline fun multiply(a: Double, b: Double): Double = a * b + + public override inline fun divide(a: Double, b: Double): Double = a / b + + public override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) + public override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) + public override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) + public override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) + public override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) + public override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) + + public override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg) + public override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg) + public override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg) + public override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg) + public override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) + public override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) + + public override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) + public override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) + public override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) + + public override inline fun norm(arg: Double): Double = abs(arg) + + public override inline fun Double.unaryMinus(): Double = -this + public override inline fun Double.plus(b: Double): Double = this + b + public override inline fun Double.minus(b: Double): Double = this - b + public override inline fun Double.times(b: Double): Double = this * b + public override inline fun Double.div(b: Double): Double = this / b +} + +/** + * A field for [Float] without boxing. Does not produce appropriate field element. + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public object FloatField : ExtendedField, Norm { + public override val zero: Float + get() = 0.0f + + public override val one: Float + get() = 1.0f + + override fun number(value: Number): Float = value.toFloat() + + public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperationFunction(operation) + } + + public override inline fun add(a: Float, b: Float): Float = a + b + public override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() + + public override inline fun multiply(a: Float, b: Float): Float = a * b + + public override inline fun divide(a: Float, b: Float): Float = a / b + + public override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) + public override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) + public override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) + public override inline fun acos(arg: Float): Float = kotlin.math.acos(arg) + public override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) + public override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) + + public override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg) + public override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg) + public override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg) + public override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg) + public override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg) + public override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg) + + public override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) + public override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) + public override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) + + public override inline fun norm(arg: Float): Float = abs(arg) + + public override inline fun Float.unaryMinus(): Float = -this + public override inline fun Float.plus(b: Float): Float = this + b + public override inline fun Float.minus(b: Float): Float = this - b + public override inline fun Float.times(b: Float): Float = this * b + public override inline fun Float.div(b: Float): Float = this / b +} + +/** + * A field for [Int] without boxing. Does not produce corresponding ring element. + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public object IntRing : Ring, Norm, NumericAlgebra { + public override val zero: Int + get() = 0 + + public override val one: Int + get() = 1 + + override fun number(value: Number): Int = value.toInt() + + public override inline fun add(a: Int, b: Int): Int = a + b + public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a + + public override inline fun multiply(a: Int, b: Int): Int = a * b + + public override inline fun norm(arg: Int): Int = abs(arg) + + public override inline fun Int.unaryMinus(): Int = -this + public override inline fun Int.plus(b: Int): Int = this + b + public override inline fun Int.minus(b: Int): Int = this - b + public override inline fun Int.times(b: Int): Int = this * b +} + +/** + * A field for [Short] without boxing. Does not produce appropriate ring element. + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public object ShortRing : Ring, Norm, NumericAlgebra { + public override val zero: Short + get() = 0 + + public override val one: Short + get() = 1 + + override fun number(value: Number): Short = value.toShort() + + public override inline fun add(a: Short, b: Short): Short = (a + b).toShort() + public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() + + public override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() + + public override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() + + public override inline fun Short.unaryMinus(): Short = (-this).toShort() + public override inline fun Short.plus(b: Short): Short = (this + b).toShort() + public override inline fun Short.minus(b: Short): Short = (this - b).toShort() + public override inline fun Short.times(b: Short): Short = (this * b).toShort() +} + +/** + * A field for [Byte] without boxing. Does not produce appropriate ring element. + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public object ByteRing : Ring, Norm, NumericAlgebra { + public override val zero: Byte + get() = 0 + + public override val one: Byte + get() = 1 + + override fun number(value: Number): Byte = value.toByte() + + public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() + public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() + + public override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() + + public override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() + + public override inline fun Byte.unaryMinus(): Byte = (-this).toByte() + public override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() + public override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() + public override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() +} + +/** + * A field for [Double] without boxing. Does not produce appropriate ring element. + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public object LongRing : Ring, Norm, NumericAlgebra { + public override val zero: Long + get() = 0L + + public override val one: Long + get() = 1L + + override fun number(value: Number): Long = value.toLong() + + public override inline fun add(a: Long, b: Long): Long = a + b + public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() + + public override inline fun multiply(a: Long, b: Long): Long = a * b + + public override fun norm(arg: Long): Long = abs(arg) + + public override inline fun Long.unaryMinus(): Long = (-this) + public override inline fun Long.plus(b: Long): Long = (this + b) + public override inline fun Long.minus(b: Long): Long = (this - b) + public override inline fun Long.times(b: Long): Long = (this * b) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt similarity index 56% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt index be71645d1..dc65b12c4 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDField.kt @@ -1,30 +1,31 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.FieldElement +import kscience.kmath.operations.Field +import kscience.kmath.operations.FieldElement -class BoxingNDField>( - override val shape: IntArray, - override val elementContext: F, - val bufferFactory: BufferFactory +public class BoxingNDField>( + public override val shape: IntArray, + public override val elementContext: F, + public val bufferFactory: BufferFactory ) : BufferedNDField { - override val zero: BufferedNDFieldElement by lazy { produce { zero } } - override val one: BufferedNDFieldElement by lazy { produce { one } } - override val strides: Strides = DefaultStrides(shape) + public override val zero: BufferedNDFieldElement by lazy { produce { zero } } + public override val one: BufferedNDFieldElement by lazy { produce { one } } + public override val strides: Strides = DefaultStrides(shape) - fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = + public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) - override fun check(vararg elements: NDBuffer) { - check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + public override fun check(vararg elements: NDBuffer): Array> { + require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + return elements } - override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement = + public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement = BufferedNDFieldElement( this, buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) - override fun map(arg: NDBuffer, transform: F.(T) -> T): BufferedNDFieldElement { + public override fun map(arg: NDBuffer, transform: F.(T) -> T): BufferedNDFieldElement { check(arg) return BufferedNDFieldElement( @@ -36,7 +37,7 @@ class BoxingNDField>( } - override fun mapIndexed( + public override fun mapIndexed( arg: NDBuffer, transform: F.(index: IntArray, T) -> T ): BufferedNDFieldElement { @@ -55,7 +56,7 @@ class BoxingNDField>( // return BufferedNDFieldElement(this, buffer) } - override fun combine( + public override fun combine( a: NDBuffer, b: NDBuffer, transform: F.(T, T) -> T @@ -66,15 +67,15 @@ class BoxingNDField>( buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) } - override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = + public override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = BufferedNDFieldElement(this@BoxingNDField, buffer) } -inline fun , R> F.nd( +public inline fun , R> F.nd( noinline bufferFactory: BufferFactory, vararg shape: Int, action: NDField.() -> R ): R { - val ndfield: BoxingNDField = NDField.boxing(this, *shape, bufferFactory = bufferFactory) + val ndfield = NDField.boxing(this, *shape, bufferFactory = bufferFactory) return ndfield.action() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDRing.kt similarity index 79% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDRing.kt index 91b945e79..b6794984c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BoxingNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BoxingNDRing.kt @@ -1,21 +1,22 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.RingElement +import kscience.kmath.operations.Ring +import kscience.kmath.operations.RingElement -class BoxingNDRing>( +public class BoxingNDRing>( override val shape: IntArray, override val elementContext: R, - val bufferFactory: BufferFactory + public val bufferFactory: BufferFactory ) : BufferedNDRing { override val strides: Strides = DefaultStrides(shape) override val zero: BufferedNDRingElement by lazy { produce { zero } } override val one: BufferedNDRingElement by lazy { produce { one } } - fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) + public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer = bufferFactory(size, initializer) - override fun check(vararg elements: NDBuffer) { - require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } + override fun check(vararg elements: NDBuffer): Array> { + if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides") + return elements } override fun produce(initializer: R.(IntArray) -> T): BufferedNDRingElement = @@ -59,6 +60,7 @@ class BoxingNDRing>( transform: R.(T, T) -> T ): BufferedNDRingElement { check(a, b) + return BufferedNDRingElement( this, buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferAccessor2D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferAccessor2D.kt new file mode 100644 index 000000000..5d7ba611f --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferAccessor2D.kt @@ -0,0 +1,48 @@ +package kscience.kmath.structures + +/** + * A context that allows to operate on a [MutableBuffer] as on 2d array + */ +internal class BufferAccessor2D( + public val rowNum: Int, + public val colNum: Int, + val factory: MutableBufferFactory, +) { + public operator fun Buffer.get(i: Int, j: Int): T = get(i + colNum * j) + + public operator fun MutableBuffer.set(i: Int, j: Int, value: T) { + set(i + colNum * j, value) + } + + public inline fun create(crossinline init: (i: Int, j: Int) -> T): MutableBuffer = + factory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } + + public fun create(mat: Structure2D): MutableBuffer = create { i, j -> mat[i, j] } + + //TODO optimize wrapper + public fun MutableBuffer.collect(): Structure2D = NDStructure.build( + DefaultStrides(intArrayOf(rowNum, colNum)), + factory + ) { (i, j) -> + get(i, j) + }.as2D() + + public inner class Row(public val buffer: MutableBuffer, public val rowIndex: Int) : MutableBuffer { + override val size: Int get() = colNum + + override operator fun get(index: Int): T = buffer[rowIndex, index] + + override operator fun set(index: Int, value: T) { + buffer[rowIndex, index] = value + } + + override fun copy(): MutableBuffer = factory(colNum) { get(it) } + override operator fun iterator(): Iterator = (0 until colNum).map(::get).iterator() + + } + + /** + * Get row + */ + public fun MutableBuffer.row(i: Int): Row = Row(this, i) +} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt new file mode 100644 index 000000000..3dcd0322c --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDAlgebra.kt @@ -0,0 +1,43 @@ +package kscience.kmath.structures + +import kscience.kmath.operations.* + +public interface BufferedNDAlgebra : NDAlgebra> { + public val strides: Strides + + public override fun check(vararg elements: NDBuffer): Array> { + require(elements.all { it.strides == strides }) { "Strides mismatch" } + return elements + } + + /** + * Convert any [NDStructure] to buffered structure using strides from this context. + * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over + * indices. + * + * If the argument is [NDBuffer] with different strides structure, the new element will be produced. + */ + public fun NDStructure.toBuffer(): NDBuffer = + if (this is NDBuffer && this.strides == this@BufferedNDAlgebra.strides) + this + else + produce { index -> this@toBuffer[index] } + + /** + * Convert a buffer to element of this algebra + */ + public fun NDBuffer.toElement(): MathElement> +} + + +public interface BufferedNDSpace> : NDSpace>, BufferedNDAlgebra { + public override fun NDBuffer.toElement(): SpaceElement, *, out BufferedNDSpace> +} + +public interface BufferedNDRing> : NDRing>, BufferedNDSpace { + override fun NDBuffer.toElement(): RingElement, *, out BufferedNDRing> +} + +public interface BufferedNDField> : NDField>, BufferedNDRing { + override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDElement.kt similarity index 70% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDElement.kt index 20e34fadd..d53702566 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDElement.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/BufferedNDElement.kt @@ -1,11 +1,11 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.* +import kscience.kmath.operations.* /** * Base class for an element with context, containing strides */ -abstract class BufferedNDElement : NDBuffer(), NDElement> { +public abstract class BufferedNDElement : NDBuffer(), NDElement> { abstract override val context: BufferedNDAlgebra override val strides: Strides get() = context.strides @@ -13,7 +13,7 @@ abstract class BufferedNDElement : NDBuffer(), NDElement>( +public class BufferedNDSpaceElement>( override val context: BufferedNDSpace, override val buffer: Buffer ) : BufferedNDElement(), SpaceElement, BufferedNDSpaceElement, BufferedNDSpace> { @@ -26,7 +26,7 @@ class BufferedNDSpaceElement>( } } -class BufferedNDRingElement>( +public class BufferedNDRingElement>( override val context: BufferedNDRing, override val buffer: Buffer ) : BufferedNDElement(), RingElement, BufferedNDRingElement, BufferedNDRing> { @@ -38,7 +38,7 @@ class BufferedNDRingElement>( } } -class BufferedNDFieldElement>( +public class BufferedNDFieldElement>( override val context: BufferedNDField, override val buffer: Buffer ) : BufferedNDElement(), FieldElement, BufferedNDFieldElement, BufferedNDField> { @@ -54,7 +54,7 @@ class BufferedNDFieldElement>( /** * Element by element application of any operation on elements to the whole array. Just like in numpy. */ -operator fun > Function1.invoke(ndElement: BufferedNDElement): MathElement> = +public operator fun > Function1.invoke(ndElement: BufferedNDElement): MathElement> = ndElement.context.run { map(ndElement) { invoke(it) }.toElement() } /* plus and minus */ @@ -62,13 +62,13 @@ operator fun > Function1.invoke(ndElement: BufferedN /** * Summation operation for [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.plus(arg: T): NDElement> = +public operator fun > BufferedNDElement.plus(arg: T): NDElement> = context.map(this) { it + arg }.wrap() /** * Subtraction operation between [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.minus(arg: T): NDElement> = +public operator fun > BufferedNDElement.minus(arg: T): NDElement> = context.map(this) { it - arg }.wrap() /* prod and div */ @@ -76,11 +76,11 @@ operator fun > BufferedNDElement.minus(arg: T): NDEl /** * Product operation for [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.times(arg: T): NDElement> = +public operator fun > BufferedNDElement.times(arg: T): NDElement> = context.map(this) { it * arg }.wrap() /** * Division operation between [BufferedNDElement] and single element */ -operator fun > BufferedNDElement.div(arg: T): NDElement> = +public operator fun > BufferedNDElement.div(arg: T): NDElement> = context.map(this) { it / arg }.wrap() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt new file mode 100644 index 000000000..bfec6f871 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt @@ -0,0 +1,299 @@ +package kscience.kmath.structures + +import kscience.kmath.operations.Complex +import kscience.kmath.operations.complex +import kotlin.reflect.KClass + +/** + * Function that produces [Buffer] from its size and function that supplies values. + * + * @param T the type of buffer. + */ +public typealias BufferFactory = (Int, (Int) -> T) -> Buffer + +/** + * Function that produces [MutableBuffer] from its size and function that supplies values. + * + * @param T the type of buffer. + */ +public typealias MutableBufferFactory = (Int, (Int) -> T) -> MutableBuffer + +/** + * A generic immutable random-access structure for both primitives and objects. + * + * @param T the type of elements contained in the buffer. + */ +public interface Buffer { + /** + * The size of this buffer. + */ + public val size: Int + + /** + * Gets element at given index. + */ + public operator fun get(index: Int): T + + /** + * Iterates over all elements. + */ + public operator fun iterator(): Iterator + + /** + * Checks content equality with another buffer. + */ + public fun contentEquals(other: Buffer<*>): Boolean = + asSequence().mapIndexed { index, value -> value == other[index] }.all { it } + + public companion object { + /** + * Creates a [RealBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer = + RealBuffer(size) { initializer(it) } + + /** + * Creates a [ListBuffer] of given type [T] with given [size]. Each element is calculated by calling the + * specified [initializer] function. + */ + public inline fun boxing(size: Int, initializer: (Int) -> T): Buffer = + ListBuffer(List(size, initializer)) + + // TODO add resolution based on Annotation or companion resolution + + /** + * Creates a [Buffer] of given [type]. If the type is primitive, specialized buffers are used ([IntBuffer], + * [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): Buffer = + when (type) { + Double::class -> RealBuffer(size) { initializer(it) as Double } as Buffer + Short::class -> ShortBuffer(size) { initializer(it) as Short } as Buffer + Int::class -> IntBuffer(size) { initializer(it) as Int } as Buffer + Long::class -> LongBuffer(size) { initializer(it) as Long } as Buffer + Float::class -> FloatBuffer(size) { initializer(it) as Float } as Buffer + Complex::class -> complex(size) { initializer(it) as Complex } as Buffer + else -> boxing(size, initializer) + } + + /** + * Creates a [Buffer] of given type [T]. If the type is primitive, specialized buffers are used ([IntBuffer], + * [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(size: Int, initializer: (Int) -> T): Buffer = + auto(T::class, size, initializer) + } +} + +/** + * Creates a sequence that returns all elements from this [Buffer]. + */ +public fun Buffer.asSequence(): Sequence = Sequence(::iterator) + +/** + * Creates an iterable that returns all elements from this [Buffer]. + */ +public fun Buffer.asIterable(): Iterable = Iterable(::iterator) + +/** + * Converts this [Buffer] to a new [List] + */ +public fun Buffer.toList(): List = asSequence().toList() + +/** + * Returns an [IntRange] of the valid indices for this [Buffer]. + */ +public val Buffer<*>.indices: IntRange get() = 0 until size + +/** + * A generic mutable random-access structure for both primitives and objects. + * + * @param T the type of elements contained in the buffer. + */ +public interface MutableBuffer : Buffer { + /** + * Sets the array element at the specified [index] to the specified [value]. + */ + public operator fun set(index: Int, value: T) + + /** + * Returns a shallow copy of the buffer. + */ + public fun copy(): MutableBuffer + + public companion object { + /** + * Create a boxing mutable buffer of given type + */ + public inline fun boxing(size: Int, initializer: (Int) -> T): MutableBuffer = + MutableListBuffer(MutableList(size, initializer)) + + /** + * Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used + * ([IntBuffer], [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer = + when (type) { + Double::class -> RealBuffer(size) { initializer(it) as Double } as MutableBuffer + Short::class -> ShortBuffer(size) { initializer(it) as Short } as MutableBuffer + Int::class -> IntBuffer(size) { initializer(it) as Int } as MutableBuffer + Float::class -> FloatBuffer(size) { initializer(it) as Float } as MutableBuffer + Long::class -> LongBuffer(size) { initializer(it) as Long } as MutableBuffer + Complex::class -> complex(size) { initializer(it) as Complex } as MutableBuffer + else -> boxing(size, initializer) + } + + /** + * Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used + * ([IntBuffer], [RealBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(size: Int, initializer: (Int) -> T): MutableBuffer = + auto(T::class, size, initializer) + + /** + * Creates a [RealBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer = + RealBuffer(size) { initializer(it) } + } +} + +/** + * [Buffer] implementation over [List]. + * + * @param T the type of elements contained in the buffer. + * @property list The underlying list. + */ +public inline class ListBuffer(public val list: List) : Buffer { + override val size: Int + get() = list.size + + override operator fun get(index: Int): T = list[index] + override operator fun iterator(): Iterator = list.iterator() +} + +/** + * Returns an [ListBuffer] that wraps the original list. + */ +public fun List.asBuffer(): ListBuffer = ListBuffer(this) + +/** + * Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an array element given its index. + */ +public inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer = List(size, init).asBuffer() + +/** + * [MutableBuffer] implementation over [MutableList]. + * + * @param T the type of elements contained in the buffer. + * @property list The underlying list. + */ +public inline class MutableListBuffer(public val list: MutableList) : MutableBuffer { + override val size: Int + get() = list.size + + override operator fun get(index: Int): T = list[index] + + override operator fun set(index: Int, value: T) { + list[index] = value + } + + override operator fun iterator(): Iterator = list.iterator() + override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) +} + +/** + * [MutableBuffer] implementation over [Array]. + * + * @param T the type of elements contained in the buffer. + * @property array The underlying array. + */ +public class ArrayBuffer(private val array: Array) : MutableBuffer { + // Can't inline because array is invariant + override val size: Int + get() = array.size + + override operator fun get(index: Int): T = array[index] + + override operator fun set(index: Int, value: T) { + array[index] = value + } + + override operator fun iterator(): Iterator = array.iterator() + override fun copy(): MutableBuffer = ArrayBuffer(array.copyOf()) +} + +/** + * Returns an [ArrayBuffer] that wraps the original array. + */ +public fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) + +/** + * Immutable wrapper for [MutableBuffer]. + * + * @param T the type of elements contained in the buffer. + * @property buffer The underlying buffer. + */ +public inline class ReadOnlyBuffer(public val buffer: MutableBuffer) : Buffer { + override val size: Int get() = buffer.size + + override operator fun get(index: Int): T = buffer[index] + + override operator fun iterator(): Iterator = buffer.iterator() +} + +/** + * A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call. + * Useful when one needs single element from the buffer. + * + * @param T the type of elements provided by the buffer. + */ +public class VirtualBuffer(override val size: Int, private val generator: (Int) -> T) : Buffer { + override operator fun get(index: Int): T { + if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") + return generator(index) + } + + override operator fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() + + override fun contentEquals(other: Buffer<*>): Boolean { + return if (other is VirtualBuffer) { + this.size == other.size && this.generator == other.generator + } else { + super.contentEquals(other) + } + } +} + +/** + * Convert this buffer to read-only buffer. + */ +public fun Buffer.asReadOnly(): Buffer = if (this is MutableBuffer) ReadOnlyBuffer(this) else this + +/** + * Typealias for buffer transformations. + */ +public typealias BufferTransform = (Buffer) -> Buffer + +/** + * Typealias for buffer transformations with suspend function. + */ +public typealias SuspendBufferTransform = suspend (Buffer) -> Buffer diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt similarity index 69% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt index 2c6e3a5c7..6de69cabe 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt @@ -1,28 +1,32 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.ComplexField -import scientifik.kmath.operations.FieldElement -import scientifik.kmath.operations.complex -import kotlin.contracts.ExperimentalContracts +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract -typealias ComplexNDElement = BufferedNDFieldElement +public typealias ComplexNDElement = BufferedNDFieldElement /** * An optimized nd-field for complex numbers */ -class ComplexNDField(override val shape: IntArray) : +@OptIn(UnstableKMathAPI::class) +public class ComplexNDField(override val shape: IntArray) : BufferedNDField, - ExtendedNDField> { + ExtendedNDField>, + RingWithNumbers>{ override val strides: Strides = DefaultStrides(shape) override val elementContext: ComplexField get() = ComplexField override val zero: ComplexNDElement by lazy { produce { zero } } override val one: ComplexNDElement by lazy { produce { one } } - inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer = + override fun number(value: Number): NDBuffer { + val c = value.toComplex() + return produce { c } + } + + public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer = Buffer.complex(size) { initializer(it) } /** @@ -30,7 +34,7 @@ class ComplexNDField(override val shape: IntArray) : */ override fun map( arg: NDBuffer, - transform: ComplexField.(Complex) -> Complex + transform: ComplexField.(Complex) -> Complex, ): ComplexNDElement { check(arg) val array = buildBuffer(arg.strides.linearSize) { offset -> ComplexField.transform(arg.buffer[offset]) } @@ -44,7 +48,7 @@ class ComplexNDField(override val shape: IntArray) : override fun mapIndexed( arg: NDBuffer, - transform: ComplexField.(index: IntArray, Complex) -> Complex + transform: ComplexField.(index: IntArray, Complex) -> Complex, ): ComplexNDElement { check(arg) @@ -61,7 +65,7 @@ class ComplexNDField(override val shape: IntArray) : override fun combine( a: NDBuffer, b: NDBuffer, - transform: ComplexField.(Complex, Complex) -> Complex + transform: ComplexField.(Complex, Complex) -> Complex, ): ComplexNDElement { check(a, b) @@ -98,7 +102,7 @@ class ComplexNDField(override val shape: IntArray) : /** * Fast element production using function inlining */ -inline fun BufferedNDField.produceInline(crossinline initializer: ComplexField.(Int) -> Complex): ComplexNDElement { +public inline fun BufferedNDField.produceInline(initializer: ComplexField.(Int) -> Complex): ComplexNDElement { val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) } return BufferedNDFieldElement(this, buffer) } @@ -106,14 +110,13 @@ inline fun BufferedNDField.produceInline(crossinline init /** * Map one [ComplexNDElement] using function with indices. */ -inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement = +public inline fun ComplexNDElement.mapIndexed(transform: ComplexField.(index: IntArray, Complex) -> Complex): ComplexNDElement = context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } /** * Map one [ComplexNDElement] using function without indices. */ -inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement { - contract { callsInPlace(transform) } +public inline fun ComplexNDElement.map(transform: ComplexField.(Complex) -> Complex): ComplexNDElement { val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) } return BufferedNDFieldElement(context, buffer) } @@ -121,38 +124,35 @@ inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> /** * Element by element application of any operation on elements to the whole array. Just like in numpy */ -operator fun Function1.invoke(ndElement: ComplexNDElement): ComplexNDElement = +public operator fun Function1.invoke(ndElement: ComplexNDElement): ComplexNDElement = ndElement.map { this@invoke(it) } - /* plus and minus */ /** * Summation operation for [BufferedNDElement] and single element */ -operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg } +public operator fun ComplexNDElement.plus(arg: Complex): ComplexNDElement = map { it + arg } /** * Subtraction operation between [BufferedNDElement] and single element */ -operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement = - map { it - arg } +public operator fun ComplexNDElement.minus(arg: Complex): ComplexNDElement = map { it - arg } -operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement = - map { it + arg } +public operator fun ComplexNDElement.plus(arg: Double): ComplexNDElement = map { it + arg } +public operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement = map { it - arg } -operator fun ComplexNDElement.minus(arg: Double): ComplexNDElement = - map { it - arg } +public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape) -fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape) - -fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement = - NDField.complex(*shape).produce(initializer) +public fun NDElement.Companion.complex( + vararg shape: Int, + initializer: ComplexField.(IntArray) -> Complex, +): ComplexNDElement = NDField.complex(*shape).produce(initializer) /** * Produce a context for n-dimensional operations inside this real field */ -inline fun ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { +public inline fun ComplexField.nd(vararg shape: Int, action: ComplexNDField.() -> R): R { contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } return NDField.complex(*shape).action() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ExtendedNDField.kt similarity index 86% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ExtendedNDField.kt index 24aa48c6b..a9fa2763b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ExtendedNDField.kt @@ -1,6 +1,6 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.ExtendedField +import kscience.kmath.operations.ExtendedField /** * [ExtendedField] over [NDStructure]. @@ -9,7 +9,7 @@ import scientifik.kmath.operations.ExtendedField * @param N the type of ND structure. * @param F the extended field of structure elements. */ -interface ExtendedNDField, N : NDStructure> : NDField, ExtendedField +public interface ExtendedNDField, N : NDStructure> : NDField, ExtendedField ///** // * NDField that supports [ExtendedField] operations on its elements diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt similarity index 64% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt index 9c32aa31b..4965e37cf 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FlaggedBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FlaggedBuffer.kt @@ -1,7 +1,5 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract import kotlin.experimental.and /** @@ -9,7 +7,7 @@ import kotlin.experimental.and * * @property mask bit mask value of this flag. */ -enum class ValueFlag(val mask: Byte) { +public enum class ValueFlag(public val mask: Byte) { /** * Reports the value is NaN. */ @@ -34,23 +32,24 @@ enum class ValueFlag(val mask: Byte) { /** * A buffer with flagged values. */ -interface FlaggedBuffer : Buffer { - fun getFlag(index: Int): Byte +public interface FlaggedBuffer : Buffer { + public fun getFlag(index: Int): Byte } /** * The value is valid if all flags are down */ -fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte() +public fun FlaggedBuffer<*>.isValid(index: Int): Boolean = getFlag(index) != 0.toByte() -fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte() +public fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (getFlag(index) and flag.mask) != 0.toByte() -fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING) +public fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING) /** * A real buffer which supports flags for each value like NaN or Missing */ -class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : FlaggedBuffer, Buffer { +public class FlaggedRealBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer, + Buffer { init { require(values.size == flags.size) { "Values and flags must have the same dimensions" } } @@ -66,9 +65,7 @@ class FlaggedRealBuffer(val values: DoubleArray, val flags: ByteArray) : Flagged }.iterator() } -inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { - contract { callsInPlace(block) } - +public inline fun FlaggedRealBuffer.forEachValid(block: (Double) -> Unit) { indices .asSequence() .filter(::isValid) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FloatBuffer.kt similarity index 68% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FloatBuffer.kt index 9e974c644..e96c45572 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/FloatBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/FloatBuffer.kt @@ -1,14 +1,12 @@ -package scientifik.kmath.structures - -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract +package kscience.kmath.structures /** * Specialized [MutableBuffer] implementation over [FloatArray]. * * @property array the underlying array. + * @author Iaroslav Postovalov */ -inline class FloatBuffer(val array: FloatArray) : MutableBuffer { +public inline class FloatBuffer(public val array: FloatArray) : MutableBuffer { override val size: Int get() = array.size override operator fun get(index: Int): Float = array[index] @@ -30,20 +28,17 @@ inline class FloatBuffer(val array: FloatArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer { - contract { callsInPlace(init) } - return FloatBuffer(FloatArray(size) { init(it) }) -} +public inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) }) /** * Returns a new [FloatBuffer] of given elements. */ -fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats) +public fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats) /** * Returns a [FloatArray] containing all of the elements of this [MutableBuffer]. */ -val MutableBuffer.array: FloatArray +public val MutableBuffer.array: FloatArray get() = (if (this is FloatBuffer) array else FloatArray(size) { get(it) }) /** @@ -52,4 +47,4 @@ val MutableBuffer.array: FloatArray * @receiver the array. * @return the new buffer. */ -fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this) +public fun FloatArray.asBuffer(): FloatBuffer = FloatBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/IntBuffer.kt similarity index 67% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/IntBuffer.kt index 95651c547..0fe68803b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/IntBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/IntBuffer.kt @@ -1,15 +1,11 @@ -package scientifik.kmath.structures - -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract +package kscience.kmath.structures /** * Specialized [MutableBuffer] implementation over [IntArray]. * * @property array the underlying array. */ -inline class IntBuffer(val array: IntArray) : MutableBuffer { +public inline class IntBuffer(public val array: IntArray) : MutableBuffer { override val size: Int get() = array.size override operator fun get(index: Int): Int = array[index] @@ -31,20 +27,17 @@ inline class IntBuffer(val array: IntArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer { - contract { callsInPlace(init) } - return IntBuffer(IntArray(size) { init(it) }) -} +public inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) }) /** * Returns a new [IntBuffer] of given elements. */ -fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints) +public fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints) /** * Returns a [IntArray] containing all of the elements of this [MutableBuffer]. */ -val MutableBuffer.array: IntArray +public val MutableBuffer.array: IntArray get() = (if (this is IntBuffer) array else IntArray(size) { get(it) }) /** @@ -53,4 +46,4 @@ val MutableBuffer.array: IntArray * @receiver the array. * @return the new buffer. */ -fun IntArray.asBuffer(): IntBuffer = IntBuffer(this) +public fun IntArray.asBuffer(): IntBuffer = IntBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/LongBuffer.kt similarity index 68% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/LongBuffer.kt index a44109f8a..87853c251 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/LongBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/LongBuffer.kt @@ -1,14 +1,11 @@ -package scientifik.kmath.structures - -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract +package kscience.kmath.structures /** * Specialized [MutableBuffer] implementation over [LongArray]. * * @property array the underlying array. */ -inline class LongBuffer(val array: LongArray) : MutableBuffer { +public inline class LongBuffer(public val array: LongArray) : MutableBuffer { override val size: Int get() = array.size override operator fun get(index: Int): Long = array[index] @@ -21,7 +18,6 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer { override fun copy(): MutableBuffer = LongBuffer(array.copyOf()) - } /** @@ -31,20 +27,17 @@ inline class LongBuffer(val array: LongArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer { - contract { callsInPlace(init) } - return LongBuffer(LongArray(size) { init(it) }) -} +public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) }) /** * Returns a new [LongBuffer] of given elements. */ -fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs) +public fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs) /** * Returns a [IntArray] containing all of the elements of this [MutableBuffer]. */ -val MutableBuffer.array: LongArray +public val MutableBuffer.array: LongArray get() = (if (this is LongBuffer) array else LongArray(size) { get(it) }) /** @@ -53,4 +46,4 @@ val MutableBuffer.array: LongArray * @receiver the array. * @return the new buffer. */ -fun LongArray.asBuffer(): LongBuffer = LongBuffer(this) +public fun LongArray.asBuffer(): LongBuffer = LongBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt similarity index 52% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt index 83c50b14b..66c9212cf 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/MemoryBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/MemoryBuffer.kt @@ -1,6 +1,6 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.memory.* +import kscience.kmath.memory.* /** * A non-boxing buffer over [Memory] object. @@ -9,7 +9,7 @@ import scientifik.memory.* * @property memory the underlying memory segment. * @property spec the spec of [T] type. */ -open class MemoryBuffer(protected val memory: Memory, protected val spec: MemorySpec) : Buffer { +public open class MemoryBuffer(protected val memory: Memory, protected val spec: MemorySpec) : Buffer { override val size: Int get() = memory.size / spec.objectSize private val reader: MemoryReader = memory.reader() @@ -17,20 +17,17 @@ open class MemoryBuffer(protected val memory: Memory, protected val spe override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index) override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() - companion object { - fun create(spec: MemorySpec, size: Int): MemoryBuffer = + public companion object { + public fun create(spec: MemorySpec, size: Int): MemoryBuffer = MemoryBuffer(Memory.allocate(size * spec.objectSize), spec) - inline fun create( + public inline fun create( spec: MemorySpec, size: Int, - crossinline initializer: (Int) -> T - ): MemoryBuffer = - MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> - (0 until size).forEach { - buffer[it] = initializer(it) - } - } + initializer: (Int) -> T + ): MemoryBuffer = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> + (0 until size).forEach { buffer[it] = initializer(it) } + } } } @@ -41,7 +38,7 @@ open class MemoryBuffer(protected val memory: Memory, protected val spe * @property memory the underlying memory segment. * @property spec the spec of [T] type. */ -class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : MemoryBuffer(memory, spec), +public class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : MemoryBuffer(memory, spec), MutableBuffer { private val writer: MemoryWriter = memory.writer() @@ -49,19 +46,16 @@ class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : Memory override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) override fun copy(): MutableBuffer = MutableMemoryBuffer(memory.copy(), spec) - companion object { - fun create(spec: MemorySpec, size: Int): MutableMemoryBuffer = + public companion object { + public fun create(spec: MemorySpec, size: Int): MutableMemoryBuffer = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec) - inline fun create( + public inline fun create( spec: MemorySpec, size: Int, - crossinline initializer: (Int) -> T - ): MutableMemoryBuffer = - MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> - (0 until size).forEach { - buffer[it] = initializer(it) - } - } + initializer: (Int) -> T + ): MutableMemoryBuffer = MutableMemoryBuffer(Memory.allocate(size * spec.objectSize), spec).also { buffer -> + (0 until size).forEach { buffer[it] = initializer(it) } + } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt new file mode 100644 index 000000000..d7b019c65 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDAlgebra.kt @@ -0,0 +1,259 @@ +package kscience.kmath.structures + +import kscience.kmath.operations.Complex +import kscience.kmath.operations.Field +import kscience.kmath.operations.Ring +import kscience.kmath.operations.Space +import kotlin.native.concurrent.ThreadLocal + +/** + * An exception is thrown when the expected ans actual shape of NDArray differs. + * + * @property expected the expected shape. + * @property actual the actual shape. + */ +public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : + RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") + +/** + * The base interface for all ND-algebra implementations. + * + * @param T the type of ND-structure element. + * @param C the type of the element context. + * @param N the type of the structure. + */ +public interface NDAlgebra> { + /** + * The shape of ND-structures this algebra operates on. + */ + public val shape: IntArray + + /** + * The algebra over elements of ND structure. + */ + public val elementContext: C + + /** + * Produces a new [N] structure using given initializer function. + */ + public fun produce(initializer: C.(IntArray) -> T): N + + /** + * Maps elements from one structure to another one by applying [transform] to them. + */ + public fun map(arg: N, transform: C.(T) -> T): N + + /** + * Maps elements from one structure to another one by applying [transform] to them alongside with their indices. + */ + public fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N + + /** + * Combines two structures into one. + */ + public fun combine(a: N, b: N, transform: C.(T, T) -> T): N + + /** + * Checks if given element is consistent with this context. + * + * @param element the structure to check. + * @return the valid structure. + */ + public fun check(element: N): N { + if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape) + return element + } + + /** + * Checks if given elements are consistent with this context. + * + * @param elements the structures to check. + * @return the array of valid structures. + */ + public fun check(vararg elements: N): Array = elements + .map(NDStructure::shape) + .singleOrNull { !shape.contentEquals(it) } + ?.let> { throw ShapeMismatchException(shape, it) } + ?: elements + + /** + * Element-wise invocation of function working on [T] on a [NDStructure]. + */ + public operator fun Function1.invoke(structure: N): N = map(structure) { value -> this@invoke(value) } + + public companion object +} + +/** + * Space of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param S the type of space of structure elements. + */ +public interface NDSpace, N : NDStructure> : Space, NDAlgebra { + /** + * Element-wise addition. + * + * @param a the addend. + * @param b the augend. + * @return the sum. + */ + public override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) } + + /** + * Element-wise multiplication by scalar. + * + * @param a the multiplicand. + * @param k the multiplier. + * @return the product. + */ + public override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) } + + // TODO move to extensions after KEEP-176 + + /** + * Adds an ND structure to an element of it. + * + * @receiver the addend. + * @param arg the augend. + * @return the sum. + */ + public operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) } + + /** + * Subtracts an element from ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ + public operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) } + + /** + * Adds an element to ND structure of it. + * + * @receiver the addend. + * @param arg the augend. + * @return the sum. + */ + public operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) } + + /** + * Subtracts an ND structure from an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ + public operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) } + + public companion object +} + +/** + * Ring of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param R the type of ring of structure elements. + */ +public interface NDRing, N : NDStructure> : Ring, NDSpace { + /** + * Element-wise multiplication. + * + * @param a the multiplicand. + * @param b the multiplier. + * @return the product. + */ + public override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } + + //TODO move to extensions after KEEP-176 + + /** + * Multiplies an ND structure by an element of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ + public operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) } + + /** + * Multiplies an element by a ND structure of it. + * + * @receiver the multiplicand. + * @param arg the multiplier. + * @return the product. + */ + public operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) } + + public companion object +} + +/** + * Field of [NDStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param F the type field of structure elements. + */ +public interface NDField, N : NDStructure> : Field, NDRing { + /** + * Element-wise division. + * + * @param a the dividend. + * @param b the divisor. + * @return the quotient. + */ + public override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } + + //TODO move to extensions after KEEP-176 + /** + * Divides an ND structure by an element of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ + public operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) } + + /** + * Divides an element by an ND structure of it. + * + * @receiver the dividend. + * @param arg the divisor. + * @return the quotient. + */ + public operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) } + + @ThreadLocal + public companion object { + private val realNDFieldCache: MutableMap = hashMapOf() + + /** + * Create a nd-field for [Double] values or pull it from cache if it was created previously. + */ + public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } + + /** + * Create an ND field with boxing generic buffer. + */ + public fun > boxing( + field: F, + vararg shape: Int, + bufferFactory: BufferFactory = Buffer.Companion::boxing + ): BoxingNDField = BoxingNDField(shape, field, bufferFactory) + + /** + * Create a most suitable implementation for nd-field using reified class. + */ + @Suppress("UNCHECKED_CAST") + public inline fun > auto(field: F, vararg shape: Int): BufferedNDField = + when { + T::class == Double::class -> real(*shape) as BufferedNDField + T::class == Complex::class -> complex(*shape) as BufferedNDField + else -> BoxingNDField(shape, field, Buffer.Companion::auto) + } + } +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDElement.kt similarity index 65% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDElement.kt index 6cc0a72c0..f2f565064 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDElement.kt @@ -1,9 +1,9 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import kscience.kmath.operations.Field +import kscience.kmath.operations.RealField +import kscience.kmath.operations.Ring +import kscience.kmath.operations.Space /** * The root for all [NDStructure] based algebra elements. Does not implement algebra element root because of problems with recursive self-types @@ -11,31 +11,30 @@ import scientifik.kmath.operations.Space * @param C the type of the context for the element * @param N the type of the underlying [NDStructure] */ -interface NDElement> : NDStructure { +public interface NDElement> : NDStructure { + public val context: NDAlgebra - val context: NDAlgebra + public fun unwrap(): N - fun unwrap(): N + public fun N.wrap(): NDElement - fun N.wrap(): NDElement - - companion object { + public companion object { /** * Create a optimized NDArray of doubles */ - fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement = + public fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }): RealNDElement = NDField.real(*shape).produce(initializer) - inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = + public inline fun real1D(dim: Int, crossinline initializer: (Int) -> Double = { _ -> 0.0 }): RealNDElement = real(intArrayOf(dim)) { initializer(it[0]) } - inline fun real2D( + public inline fun real2D( dim1: Int, dim2: Int, crossinline initializer: (Int, Int) -> Double = { _, _ -> 0.0 } ): RealNDElement = real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - inline fun real3D( + public inline fun real3D( dim1: Int, dim2: Int, dim3: Int, @@ -46,7 +45,7 @@ interface NDElement> : NDStructure { /** * Simple boxing NDArray */ - fun > boxing( + public fun > boxing( shape: IntArray, field: F, initializer: F.(IntArray) -> T @@ -55,7 +54,7 @@ interface NDElement> : NDStructure { return ndField.produce(initializer) } - inline fun > auto( + public inline fun > auto( shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T @@ -66,17 +65,16 @@ interface NDElement> : NDStructure { } } - -fun > NDElement.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement = +public fun > NDElement.mapIndexed(transform: C.(index: IntArray, T) -> T): NDElement = context.mapIndexed(unwrap(), transform).wrap() -fun > NDElement.map(transform: C.(T) -> T): NDElement = +public fun > NDElement.map(transform: C.(T) -> T): NDElement = context.map(unwrap(), transform).wrap() /** * Element by element application of any operation on elements to the whole [NDElement] */ -operator fun > Function1.invoke(ndElement: NDElement): NDElement = +public operator fun > Function1.invoke(ndElement: NDElement): NDElement = ndElement.map { value -> this@invoke(value) } /* plus and minus */ @@ -84,13 +82,13 @@ operator fun > Function1.invoke(ndElement: NDElem /** * Summation operation for [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.plus(arg: T): NDElement = +public operator fun , N : NDStructure> NDElement.plus(arg: T): NDElement = map { value -> arg + value } /** * Subtraction operation between [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.minus(arg: T): NDElement = +public operator fun , N : NDStructure> NDElement.minus(arg: T): NDElement = map { value -> arg - value } /* prod and div */ @@ -98,13 +96,13 @@ operator fun , N : NDStructure> NDElement.minus(arg: /** * Product operation for [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.times(arg: T): NDElement = +public operator fun , N : NDStructure> NDElement.times(arg: T): NDElement = map { value -> arg * value } /** * Division operation between [NDElement] and single element */ -operator fun , N : NDStructure> NDElement.div(arg: T): NDElement = +public operator fun , N : NDStructure> NDElement.div(arg: T): NDElement = map { value -> arg / value } // /** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt similarity index 64% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt index f4eb93b9e..e7d89ca7e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt @@ -1,8 +1,8 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract +import kscience.kmath.misc.UnstableKMathAPI import kotlin.jvm.JvmName +import kotlin.native.concurrent.ThreadLocal import kotlin.reflect.KClass /** @@ -12,17 +12,17 @@ import kotlin.reflect.KClass * * @param T the type of items. */ -interface NDStructure { +public interface NDStructure { /** * The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of * this structure. */ - val shape: IntArray + public val shape: IntArray /** * The count of dimensions in this structure. It should be equal to size of [shape]. */ - val dimension: Int get() = shape.size + public val dimension: Int get() = shape.size /** * Returns the value at the specified indices. @@ -30,34 +30,36 @@ interface NDStructure { * @param index the indices. * @return the value. */ - operator fun get(index: IntArray): T + public operator fun get(index: IntArray): T /** * Returns the sequence of all the elements associated by their indices. * * @return the lazy sequence of pairs of indices to values. */ - fun elements(): Sequence> + public fun elements(): Sequence> - override fun equals(other: Any?): Boolean + //force override equality and hash code + public override fun equals(other: Any?): Boolean + public override fun hashCode(): Int - override fun hashCode(): Int + /** + * Feature is additional property or hint that does not directly affect the structure, but could in some cases help + * optimize operations and performance. If the feature is not present, null is defined. + */ + @UnstableKMathAPI + public fun getFeature(type: KClass): T? = null - companion object { + public companion object { /** * Indicates whether some [NDStructure] is equal to another one. */ - fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { + public fun contentEquals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { if (st1 === st2) return true // fast comparison of buffers if possible - if ( - st1 is NDBuffer && - st2 is NDBuffer && - st1.strides == st2.strides - ) { + if (st1 is NDBuffer && st2 is NDBuffer && st1.strides == st2.strides) return st1.buffer.contentEquals(st2.buffer) - } //element by element comparison if it could not be avoided return st1.elements().all { (index, value) -> value == st2[index] } @@ -68,52 +70,52 @@ interface NDStructure { * * Strides should be reused if possible. */ - fun build( + public fun build( strides: Strides, bufferFactory: BufferFactory = Buffer.Companion::boxing, - initializer: (IntArray) -> T + initializer: (IntArray) -> T, ): BufferNDStructure = BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) /** * Inline create NDStructure with non-boxing buffer implementation if it is possible */ - inline fun auto( + public inline fun auto( strides: Strides, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) - inline fun auto( + public inline fun auto( type: KClass, strides: Strides, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = BufferNDStructure(strides, Buffer.auto(type, strides.linearSize) { i -> initializer(strides.index(i)) }) - fun build( + public fun build( shape: IntArray, bufferFactory: BufferFactory = Buffer.Companion::boxing, - initializer: (IntArray) -> T + initializer: (IntArray) -> T, ): BufferNDStructure = build(DefaultStrides(shape), bufferFactory, initializer) - inline fun auto( + public inline fun auto( shape: IntArray, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = auto(DefaultStrides(shape), initializer) @JvmName("autoVarArg") - inline fun auto( + public inline fun auto( vararg shape: Int, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = auto(DefaultStrides(shape), initializer) - inline fun auto( + public inline fun auto( type: KClass, vararg shape: Int, - crossinline initializer: (IntArray) -> T + crossinline initializer: (IntArray) -> T, ): BufferNDStructure = auto(type, DefaultStrides(shape), initializer) } @@ -125,68 +127,74 @@ interface NDStructure { * @param index the indices. * @return the value. */ -operator fun NDStructure.get(vararg index: Int): T = get(index) +public operator fun NDStructure.get(vararg index: Int): T = get(index) + +@UnstableKMathAPI +public inline fun NDStructure<*>.getFeature(): T? = getFeature(T::class) /** * Represents mutable [NDStructure]. */ -interface MutableNDStructure : NDStructure { +public interface MutableNDStructure : NDStructure { /** * Inserts an item at the specified indices. * * @param index the indices. * @param value the value. */ - operator fun set(index: IntArray, value: T) + public operator fun set(index: IntArray, value: T) } -inline fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T) { - contract { callsInPlace(action) } +/** + * Transform a structure element-by element in place. + */ +public inline fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T): Unit = elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } -} /** * A way to convert ND index to linear one and back. */ -interface Strides { +public interface Strides { /** * Shape of NDstructure */ - val shape: IntArray + public val shape: IntArray /** * Array strides */ - val strides: List + public val strides: List /** * Get linear index from multidimensional index */ - fun offset(index: IntArray): Int + public fun offset(index: IntArray): Int /** * Get multidimensional from linear */ - fun index(offset: Int): IntArray + public fun index(offset: Int): IntArray /** * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides */ - val linearSize: Int + public val linearSize: Int + + // TODO introduce a fast way to calculate index of the next element? /** * Iterate over ND indices in a natural order */ - fun indices(): Sequence { - //TODO introduce a fast way to calculate index of the next element? - return (0 until linearSize).asSequence().map { index(it) } - } + public fun indices(): Sequence = (0 until linearSize).asSequence().map { index(it) } } /** * Simple implementation of [Strides]. */ -class DefaultStrides private constructor(override val shape: IntArray) : Strides { +public class DefaultStrides private constructor(override val shape: IntArray) : Strides { + override val linearSize: Int + get() = strides[shape.size] + /** * Strides for memory access */ @@ -194,6 +202,7 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides sequence { var current = 1 yield(1) + shape.forEach { current *= it yield(current) @@ -212,17 +221,16 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides val res = IntArray(shape.size) var current = offset var strideIndex = strides.size - 2 + while (strideIndex >= 0) { res[strideIndex] = (current / strides[strideIndex]) current %= strides[strideIndex] strideIndex-- } + return res } - override val linearSize: Int - get() = strides[shape.size] - override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is DefaultStrides) return false @@ -232,13 +240,15 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides override fun hashCode(): Int = shape.contentHashCode() - companion object { + @ThreadLocal + public companion object { private val defaultStridesCache = HashMap() /** * Cached builder for default strides */ - operator fun invoke(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } + public operator fun invoke(shape: IntArray): Strides = + defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } } } @@ -247,16 +257,16 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides * * @param T the type of items. */ -abstract class NDBuffer : NDStructure { +public abstract class NDBuffer : NDStructure { /** * The underlying buffer. */ - abstract val buffer: Buffer + public abstract val buffer: Buffer /** * The strides to access elements of [Buffer] by linear indices. */ - abstract val strides: Strides + public abstract val strides: Strides override operator fun get(index: IntArray): T = buffer[strides.offset(index)] @@ -265,7 +275,7 @@ abstract class NDBuffer : NDStructure { override fun elements(): Sequence> = strides.indices().map { it to this[it] } override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false) } override fun hashCode(): Int { @@ -273,14 +283,30 @@ abstract class NDBuffer : NDStructure { result = 31 * result + buffer.hashCode() return result } + + override fun toString(): String { + val bufferRepr: String = when (shape.size) { + 1 -> buffer.asSequence().joinToString(prefix = "[", postfix = "]", separator = ", ") + 2 -> (0 until shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i -> + (0 until shape[1]).joinToString(prefix = "[", postfix = "]", separator = ", ") { j -> + val offset = strides.offset(intArrayOf(i, j)) + buffer[offset].toString() + } + } + else -> "..." + } + return "NDBuffer(shape=${shape.contentToString()}, buffer=$bufferRepr)" + } + + } /** * Boxing generic [NDStructure] */ -class BufferNDStructure( +public class BufferNDStructure( override val strides: Strides, - override val buffer: Buffer + override val buffer: Buffer, ) : NDBuffer() { init { if (strides.linearSize != buffer.size) { @@ -292,13 +318,13 @@ class BufferNDStructure( /** * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure] */ -inline fun NDStructure.mapToBuffer( +public inline fun NDStructure.mapToBuffer( factory: BufferFactory = Buffer.Companion::auto, - crossinline transform: (T) -> R + crossinline transform: (T) -> R, ): BufferNDStructure { - return if (this is BufferNDStructure) { + return if (this is BufferNDStructure) BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) - } else { + else { val strides = DefaultStrides(shape) BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) } @@ -307,9 +333,9 @@ inline fun NDStructure.mapToBuffer( /** * Mutable ND buffer based on linear [MutableBuffer]. */ -class MutableBufferNDStructure( +public class MutableBufferNDStructure( override val strides: Strides, - override val buffer: MutableBuffer + override val buffer: MutableBuffer, ) : NDBuffer(), MutableNDStructure { init { @@ -321,9 +347,9 @@ class MutableBufferNDStructure( override operator fun set(index: IntArray, value: T): Unit = buffer.set(strides.offset(index), value) } -inline fun NDStructure.combine( +public inline fun NDStructure.combine( struct: NDStructure, - crossinline block: (T, T) -> T + crossinline block: (T, T) -> T, ): NDStructure { require(shape.contentEquals(struct.shape)) { "Shape mismatch in structure combination" } return NDStructure.auto(shape) { block(this[it], struct[it]) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBuffer.kt similarity index 52% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBuffer.kt index cba8e9689..769c445d6 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBuffer.kt @@ -1,26 +1,23 @@ -package scientifik.kmath.structures - -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract +package kscience.kmath.structures /** * Specialized [MutableBuffer] implementation over [DoubleArray]. * * @property array the underlying array. */ -inline class RealBuffer(val array: DoubleArray) : MutableBuffer { +@Suppress("OVERRIDE_BY_INLINE") +public inline class RealBuffer(public val array: DoubleArray) : MutableBuffer { override val size: Int get() = array.size - override operator fun get(index: Int): Double = array[index] + override inline operator fun get(index: Int): Double = array[index] - override operator fun set(index: Int, value: Double) { + override inline operator fun set(index: Int, value: Double) { array[index] = value } override operator fun iterator(): DoubleIterator = array.iterator() - override fun copy(): MutableBuffer = - RealBuffer(array.copyOf()) + override fun copy(): RealBuffer = RealBuffer(array.copyOf()) } /** @@ -30,20 +27,22 @@ inline class RealBuffer(val array: DoubleArray) : MutableBuffer { * The function [init] is called for each array element sequentially starting from the first one. * It should return the value for an buffer element given its index. */ -inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer { - contract { callsInPlace(init) } - return RealBuffer(DoubleArray(size) { init(it) }) -} +public inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = RealBuffer(DoubleArray(size) { init(it) }) /** * Returns a new [RealBuffer] of given elements. */ -fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) +public fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles) + +/** + * Simplified [RealBuffer] to array comparison + */ +public fun RealBuffer.contentEquals(vararg doubles: Double): Boolean = array.contentEquals(doubles) /** * Returns a [DoubleArray] containing all of the elements of this [MutableBuffer]. */ -val MutableBuffer.array: DoubleArray +public val MutableBuffer.array: DoubleArray get() = (if (this is RealBuffer) array else DoubleArray(size) { get(it) }) /** @@ -52,4 +51,4 @@ val MutableBuffer.array: DoubleArray * @receiver the array. * @return the new buffer. */ -fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this) +public fun DoubleArray.asBuffer(): RealBuffer = RealBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt similarity index 61% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt index a11826e7e..3f4d15c4d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt @@ -1,15 +1,14 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.ExtendedFieldOperations +import kscience.kmath.operations.ExtendedField +import kscience.kmath.operations.ExtendedFieldOperations import kotlin.math.* - /** * [ExtendedFieldOperations] over [RealBuffer]. */ -object RealBufferFieldOperations : ExtendedFieldOperations> { - override fun add(a: Buffer, b: Buffer): RealBuffer { +public object RealBufferFieldOperations : ExtendedFieldOperations> { + public override fun add(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } @@ -21,7 +20,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } else RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } - override fun multiply(a: Buffer, k: Number): RealBuffer { + public override fun multiply(a: Buffer, k: Number): RealBuffer { val kValue = k.toDouble() return if (a is RealBuffer) { @@ -30,7 +29,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } else RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } - override fun multiply(a: Buffer, b: Buffer): RealBuffer { + public override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } @@ -43,7 +42,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } - override fun divide(a: Buffer, b: Buffer): RealBuffer { + public override fun divide(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } @@ -55,84 +54,91 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } else RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } - override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) } else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) } else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) - override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) } else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) - override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) - } else { + } else RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) - } - override fun acos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun acos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { acos(array[it]) }) } else RealBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - override fun atan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun atan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { atan(array[it]) }) } else RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - override fun sinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun sinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { sinh(array[it]) }) - } else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) - override fun cosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun cosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { cosh(array[it]) }) - } else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) - override fun tanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun tanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { tanh(array[it]) }) - } else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) - override fun asinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun asinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { asinh(array[it]) }) - } else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) - override fun acosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun acosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { acosh(array[it]) }) - } else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) - override fun atanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun atanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { atanh(array[it]) }) - } else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) - override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { + public override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + } else + RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) } else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + public override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } /** @@ -140,101 +146,103 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { * * @property size the size of buffers to operate on. */ -class RealBufferField(val size: Int) : ExtendedField> { - override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } - override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } +public class RealBufferField(public val size: Int) : ExtendedField> { + public override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } + public override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } - override fun add(a: Buffer, b: Buffer): RealBuffer { + override fun number(value: Number): Buffer = RealBuffer(size) { value.toDouble() } + + public override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) } - override fun multiply(a: Buffer, k: Number): RealBuffer { + public override fun multiply(a: Buffer, k: Number): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, k) } - override fun multiply(a: Buffer, b: Buffer): RealBuffer { + public override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, b) } - override fun divide(a: Buffer, b: Buffer): RealBuffer { + public override fun divide(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) } - override fun sin(arg: Buffer): RealBuffer { + public override fun sin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sin(arg) } - override fun cos(arg: Buffer): RealBuffer { + public override fun cos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cos(arg) } - override fun tan(arg: Buffer): RealBuffer { + public override fun tan(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.tan(arg) } - override fun asin(arg: Buffer): RealBuffer { + public override fun asin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.asin(arg) } - override fun acos(arg: Buffer): RealBuffer { + public override fun acos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.acos(arg) } - override fun atan(arg: Buffer): RealBuffer { + public override fun atan(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.atan(arg) } - override fun sinh(arg: Buffer): RealBuffer { + public override fun sinh(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sinh(arg) } - override fun cosh(arg: Buffer): RealBuffer { + public override fun cosh(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cosh(arg) } - override fun tanh(arg: Buffer): RealBuffer { + public override fun tanh(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.tanh(arg) } - override fun asinh(arg: Buffer): RealBuffer { + public override fun asinh(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.asinh(arg) } - override fun acosh(arg: Buffer): RealBuffer { + public override fun acosh(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.acosh(arg) } - override fun atanh(arg: Buffer): RealBuffer { + public override fun atanh(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.atanh(arg) } - override fun power(arg: Buffer, pow: Number): RealBuffer { + public override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) } - override fun exp(arg: Buffer): RealBuffer { + public override fun exp(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.exp(arg) } - override fun ln(arg: Buffer): RealBuffer { + public override fun ln(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt similarity index 59% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt index 6533f64be..60e6de440 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt @@ -1,13 +1,17 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.FieldElement -import scientifik.kmath.operations.RealField +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.FieldElement +import kscience.kmath.operations.RealField +import kscience.kmath.operations.RingWithNumbers -typealias RealNDElement = BufferedNDFieldElement +public typealias RealNDElement = BufferedNDFieldElement -class RealNDField(override val shape: IntArray) : +@OptIn(UnstableKMathAPI::class) +public class RealNDField(override val shape: IntArray) : BufferedNDField, - ExtendedNDField> { + ExtendedNDField>, + RingWithNumbers> { override val strides: Strides = DefaultStrides(shape) @@ -15,35 +19,36 @@ class RealNDField(override val shape: IntArray) : override val zero: RealNDElement by lazy { produce { zero } } override val one: RealNDElement by lazy { produce { one } } - inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - RealBuffer(DoubleArray(size) { initializer(it) }) + override fun number(value: Number): NDBuffer { + val d = value.toDouble() + return produce { d } + } - /** - * Inline transform an NDStructure to - */ - override fun map( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun map( arg: NDBuffer, - transform: RealField.(Double) -> Double + transform: RealField.(Double) -> Double, ): RealNDElement { check(arg) - val array = buildBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } + val array = RealBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } return BufferedNDFieldElement(this, array) } - override fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement { - val array = buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } + @Suppress("OVERRIDE_BY_INLINE") + override inline fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement { + val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } return BufferedNDFieldElement(this, array) } - override fun mapIndexed( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun mapIndexed( arg: NDBuffer, - transform: RealField.(index: IntArray, Double) -> Double + transform: RealField.(index: IntArray, Double) -> Double, ): RealNDElement { check(arg) - return BufferedNDFieldElement( this, - buildBuffer(arg.strides.linearSize) { offset -> + RealBuffer(arg.strides.linearSize) { offset -> elementContext.transform( arg.strides.index(offset), arg.buffer[offset] @@ -51,15 +56,17 @@ class RealNDField(override val shape: IntArray) : }) } - override fun combine( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun combine( a: NDBuffer, b: NDBuffer, - transform: RealField.(Double, Double) -> Double + transform: RealField.(Double, Double) -> Double, ): RealNDElement { check(a, b) - return BufferedNDFieldElement( - this, - buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) + val buffer = RealBuffer(strides.linearSize) { offset -> + elementContext.transform(a.buffer[offset], b.buffer[offset]) + } + return BufferedNDFieldElement(this, buffer) } override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = @@ -90,7 +97,7 @@ class RealNDField(override val shape: IntArray) : /** * Fast element production using function inlining */ -inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { +public inline fun BufferedNDField.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) } return BufferedNDFieldElement(this, RealBuffer(array)) } @@ -98,13 +105,13 @@ inline fun BufferedNDField.produceInline(crossinline initiali /** * Map one [RealNDElement] using function with indices. */ -inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double): RealNDElement = +public inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double): RealNDElement = context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } /** * Map one [RealNDElement] using function without indices. */ -inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { +public inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } return BufferedNDFieldElement(context, RealBuffer(array)) } @@ -112,26 +119,22 @@ inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double /** * Element by element application of any operation on elements to the whole array. Just like in numpy. */ -operator fun Function1.invoke(ndElement: RealNDElement): RealNDElement = +public operator fun Function1.invoke(ndElement: RealNDElement): RealNDElement = ndElement.map { this@invoke(it) } - /* plus and minus */ /** * Summation operation for [BufferedNDElement] and single element */ -operator fun RealNDElement.plus(arg: Double): RealNDElement = - map { it + arg } +public operator fun RealNDElement.plus(arg: Double): RealNDElement = map { it + arg } /** * Subtraction operation between [BufferedNDElement] and single element */ -operator fun RealNDElement.minus(arg: Double): RealNDElement = - map { it - arg } +public operator fun RealNDElement.minus(arg: Double): RealNDElement = map { it - arg } /** * Produce a context for n-dimensional operations inside this real field */ - -inline fun RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action) +public inline fun RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ShortBuffer.kt new file mode 100644 index 000000000..0d9222320 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ShortBuffer.kt @@ -0,0 +1,47 @@ +package kscience.kmath.structures + +/** + * Specialized [MutableBuffer] implementation over [ShortArray]. + * + * @property array the underlying array. + */ +public inline class ShortBuffer(public val array: ShortArray) : MutableBuffer { + public override val size: Int get() = array.size + + public override operator fun get(index: Int): Short = array[index] + + public override operator fun set(index: Int, value: Short) { + array[index] = value + } + + public override operator fun iterator(): ShortIterator = array.iterator() + public override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) +} + +/** + * Creates a new [ShortBuffer] with the specified [size], where each element is calculated by calling the specified + * [init] function. + * + * The function [init] is called for each array element sequentially starting from the first one. + * It should return the value for an buffer element given its index. + */ +public inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) }) + +/** + * Returns a new [ShortBuffer] of given elements. + */ +public fun ShortBuffer(vararg shorts: Short): ShortBuffer = ShortBuffer(shorts) + +/** + * Returns a [ShortArray] containing all of the elements of this [MutableBuffer]. + */ +public val MutableBuffer.array: ShortArray + get() = (if (this is ShortBuffer) array else ShortArray(size) { get(it) }) + +/** + * Returns [ShortBuffer] over this array. + * + * @receiver the array. + * @return the new buffer. + */ +public fun ShortArray.asBuffer(): ShortBuffer = ShortBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ShortNDRing.kt similarity index 74% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ShortNDRing.kt index f404a2a27..3b506a26a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ShortNDRing.kt @@ -1,21 +1,19 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.RingElement -import scientifik.kmath.operations.ShortRing +import kscience.kmath.operations.RingElement +import kscience.kmath.operations.ShortRing +public typealias ShortNDElement = BufferedNDRingElement -typealias ShortNDElement = BufferedNDRingElement - -class ShortNDRing(override val shape: IntArray) : +public class ShortNDRing(override val shape: IntArray) : BufferedNDRing { override val strides: Strides = DefaultStrides(shape) - override val elementContext: ShortRing get() = ShortRing override val zero: ShortNDElement by lazy { produce { zero } } override val one: ShortNDElement by lazy { produce { one } } - inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer = + public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer = ShortBuffer(ShortArray(size) { initializer(it) }) /** @@ -70,15 +68,13 @@ class ShortNDRing(override val shape: IntArray) : /** * Fast element production using function inlining. */ -inline fun BufferedNDRing.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement { - val array = ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) } - return BufferedNDRingElement(this, ShortBuffer(array)) -} +public inline fun BufferedNDRing.produceInline(crossinline initializer: ShortRing.(Int) -> Short): ShortNDElement = + BufferedNDRingElement(this, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) })) /** * Element by element application of any operation on elements to the whole array. */ -operator fun Function1.invoke(ndElement: ShortNDElement): ShortNDElement = +public operator fun Function1.invoke(ndElement: ShortNDElement): ShortNDElement = ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } @@ -87,11 +83,11 @@ operator fun Function1.invoke(ndElement: ShortNDElement): ShortNDE /** * Summation operation for [ShortNDElement] and single element. */ -operator fun ShortNDElement.plus(arg: Short): ShortNDElement = +public operator fun ShortNDElement.plus(arg: Short): ShortNDElement = context.produceInline { i -> (buffer[i] + arg).toShort() } /** * Subtraction operation between [ShortNDElement] and single element. */ -operator fun ShortNDElement.minus(arg: Short): ShortNDElement = +public operator fun ShortNDElement.minus(arg: Short): ShortNDElement = context.produceInline { i -> (buffer[i] - arg).toShort() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure1D.kt similarity index 69% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure1D.kt index a796c2037..95422ac60 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure1D.kt @@ -1,29 +1,27 @@ -package scientifik.kmath.structures +package kscience.kmath.structures /** * A structure that is guaranteed to be one-dimensional */ -interface Structure1D : NDStructure, Buffer { - override val dimension: Int get() = 1 +public interface Structure1D : NDStructure, Buffer { + public override val dimension: Int get() = 1 - override operator fun get(index: IntArray): T { + public override operator fun get(index: IntArray): T { require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" } return get(index[0]) } - override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() + public override operator fun iterator(): Iterator = (0 until size).asSequence().map(::get).iterator() } /** * A 1D wrapper for nd-structure */ private inline class Structure1DWrapper(val structure: NDStructure) : Structure1D { - override val shape: IntArray get() = structure.shape override val size: Int get() = structure.shape[0] override operator fun get(index: Int): T = structure[index] - override fun elements(): Sequence> = structure.elements() } @@ -33,7 +31,6 @@ private inline class Structure1DWrapper(val structure: NDStructure) : Stru */ private inline class Buffer1DWrapper(val buffer: Buffer) : Structure1D { override val shape: IntArray get() = intArrayOf(buffer.size) - override val size: Int get() = buffer.size override fun elements(): Sequence> = @@ -45,18 +42,12 @@ private inline class Buffer1DWrapper(val buffer: Buffer) : Structure1D /** * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch */ -fun NDStructure.as1D(): Structure1D = if (shape.size == 1) { - if (this is NDBuffer) { - Buffer1DWrapper(this.buffer) - } else { - Structure1DWrapper(this) - } -} else { +public fun NDStructure.as1D(): Structure1D = if (shape.size == 1) { + if (this is NDBuffer) Buffer1DWrapper(this.buffer) else Structure1DWrapper(this) +} else error("Can't create 1d-structure from ${shape.size}d-structure") -} - /** * Represent this buffer as 1D structure */ -fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) +public fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt new file mode 100644 index 000000000..d20e9e53b --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt @@ -0,0 +1,82 @@ +package kscience.kmath.structures + +/** + * A structure that is guaranteed to be two-dimensional. + * + * @param T the type of items. + */ +public interface Structure2D : NDStructure { + /** + * The number of rows in this structure. + */ + public val rowNum: Int + + /** + * The number of columns in this structure. + */ + public val colNum: Int + + public override val shape: IntArray get() = intArrayOf(rowNum, colNum) + + /** + * The buffer of rows of this structure. It gets elements from the structure dynamically. + */ + public val rows: Buffer> + get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } } + + /** + * The buffer of columns of this structure. It gets elements from the structure dynamically. + */ + public val columns: Buffer> + get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } + + /** + * Retrieves an element from the structure by two indices. + * + * @param i the first index. + * @param j the second index. + * @return an element. + */ + public operator fun get(i: Int, j: Int): T + + override operator fun get(index: IntArray): T { + require(index.size == 2) { "Index dimension mismatch. Expected 2 but found ${index.size}" } + return get(index[0], index[1]) + } + + override fun elements(): Sequence> = sequence { + for (i in 0 until rowNum) + for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) + } + + public companion object +} + +/** + * A 2D wrapper for nd-structure + */ +private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { + override val shape: IntArray get() = structure.shape + + override val rowNum: Int get() = shape[0] + override val colNum: Int get() = shape[1] + + override operator fun get(i: Int, j: Int): T = structure[i, j] + + override fun elements(): Sequence> = structure.elements() +} + +/** + * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch + */ +public fun NDStructure.as2D(): Structure2D = if (shape.size == 2) + Structure2DWrapper(this) +else + error("Can't create 2d-structure from ${shape.size}d-structure") + +/** + * Alias for [Structure2D] with more familiar name. + * + * @param T the type of items. + */ +public typealias Matrix = Structure2D diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt deleted file mode 100644 index 595a3dbe7..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnconstrainedDomain.kt +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2015 Alexander Nozik. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package scientifik.kmath.domains - -import scientifik.kmath.linear.Point - -class UnconstrainedDomain(override val dimension: Int) : RealDomain { - override operator fun contains(point: Point): Boolean = true - - override fun getLowerBound(num: Int, point: Point): Double? = Double.NEGATIVE_INFINITY - - override fun getLowerBound(num: Int): Double? = Double.NEGATIVE_INFINITY - - override fun getUpperBound(num: Int, point: Point): Double? = Double.POSITIVE_INFINITY - - override fun getUpperBound(num: Int): Double? = Double.POSITIVE_INFINITY - - override fun nearestInDomain(point: Point): Point = point - - override fun volume(): Double = Double.POSITIVE_INFINITY -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt deleted file mode 100644 index 280dc7d66..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/domains/UnivariateDomain.kt +++ /dev/null @@ -1,47 +0,0 @@ -package scientifik.kmath.domains - -import scientifik.kmath.linear.Point -import scientifik.kmath.structures.asBuffer - -inline class UnivariateDomain(val range: ClosedFloatingPointRange) : RealDomain { - operator fun contains(d: Double): Boolean = range.contains(d) - - override operator fun contains(point: Point): Boolean { - require(point.size == 0) - return contains(point[0]) - } - - override fun nearestInDomain(point: Point): Point { - require(point.size == 1) - val value = point[0] - return when { - value in range -> point - value >= range.endInclusive -> doubleArrayOf(range.endInclusive).asBuffer() - else -> doubleArrayOf(range.start).asBuffer() - } - } - - override fun getLowerBound(num: Int, point: Point): Double? { - require(num == 0) - return range.start - } - - override fun getUpperBound(num: Int, point: Point): Double? { - require(num == 0) - return range.endInclusive - } - - override fun getLowerBound(num: Int): Double? { - require(num == 0) - return range.start - } - - override fun getUpperBound(num: Int): Double? { - require(num == 0) - return range.endInclusive - } - - override fun volume(): Double = range.endInclusive - range.start - - override val dimension: Int get() = 1 -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt deleted file mode 100644 index fd11c246d..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ /dev/null @@ -1,49 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.Algebra - -/** - * An elementary function that could be invoked on a map of arguments - */ -interface Expression { - /** - * Calls this expression from arguments. - * - * @param arguments the map of arguments. - * @return the value. - */ - operator fun invoke(arguments: Map): T - - companion object -} - -/** - * Create simple lazily evaluated expression inside given algebra - */ -fun Algebra.expression(block: Algebra.(arguments: Map) -> T): Expression = - object : Expression { - override operator fun invoke(arguments: Map): T = block(arguments) - } - -/** - * Calls this expression from arguments. - * - * @param pairs the pair of arguments' names to values. - * @return the value. - */ -operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) - -/** - * A context for expression construction - */ -interface ExpressionAlgebra : Algebra { - /** - * Introduce a variable into expression context - */ - fun variable(name: String, default: T? = null): E - - /** - * A constant expression which does not depend on arguments - */ - fun const(value: T): E -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt deleted file mode 100644 index d36c31a0d..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt +++ /dev/null @@ -1,169 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.* - -internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : - Expression { - override operator fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) -} - -internal class FunctionalBinaryOperation( - val context: Algebra, - val name: String, - val first: Expression, - val second: Expression -) : Expression { - override operator fun invoke(arguments: Map): T = - context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) -} - -internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { - override operator fun invoke(arguments: Map): T = - arguments[name] ?: default ?: error("Parameter not found: $name") -} - -internal class FunctionalConstantExpression(val value: T) : Expression { - override operator fun invoke(arguments: Map): T = value -} - -internal class FunctionalConstProductExpression( - val context: Space, - private val expr: Expression, - val const: Number -) : Expression { - override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) -} - -/** - * A context class for [Expression] construction. - * - * @param algebra The algebra to provide for Expressions built. - */ -abstract class FunctionalExpressionAlgebra>(val algebra: A) : ExpressionAlgebra> { - /** - * Builds an Expression of constant expression which does not depend on arguments. - */ - override fun const(value: T): Expression = FunctionalConstantExpression(value) - - /** - * Builds an Expression to access a variable. - */ - override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) - - /** - * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. - */ - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - FunctionalBinaryOperation(algebra, operation, left, right) - - /** - * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. - */ - override fun unaryOperation(operation: String, arg: Expression): Expression = - FunctionalUnaryOperation(algebra, operation, arg) -} - -/** - * A context class for [Expression] construction for [Space] algebras. - */ -open class FunctionalExpressionSpace>(algebra: A) : - FunctionalExpressionAlgebra(algebra), Space> { - override val zero: Expression get() = const(algebra.zero) - - /** - * Builds an Expression of addition of two another expressions. - */ - override fun add(a: Expression, b: Expression): Expression = - binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) - - /** - * Builds an Expression of multiplication of expression by number. - */ - override fun multiply(a: Expression, k: Number): Expression = - FunctionalConstProductExpression(algebra, a, k) - - operator fun Expression.plus(arg: T): Expression = this + const(arg) - operator fun Expression.minus(arg: T): Expression = this - const(arg) - operator fun T.plus(arg: Expression): Expression = arg + this - operator fun T.minus(arg: Expression): Expression = arg - this - - override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) -} - -open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { - override val one: Expression - get() = const(algebra.one) - - /** - * Builds an Expression of multiplication of two expressions. - */ - override fun multiply(a: Expression, b: Expression): Expression = - binaryOperation(RingOperations.TIMES_OPERATION, a, b) - - operator fun Expression.times(arg: T): Expression = this * const(arg) - operator fun T.times(arg: Expression): Expression = arg * this - - override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) -} - -open class FunctionalExpressionField(algebra: A) : - FunctionalExpressionRing(algebra), - Field> where A : Field, A : NumericAlgebra { - /** - * Builds an Expression of division an expression by another one. - */ - override fun divide(a: Expression, b: Expression): Expression = - binaryOperation(FieldOperations.DIV_OPERATION, a, b) - - operator fun Expression.div(arg: T): Expression = this / const(arg) - operator fun T.div(arg: Expression): Expression = arg / this - - override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) -} - -open class FunctionalExpressionExtendedField(algebra: A) : - FunctionalExpressionField(algebra), - ExtendedField> where A : ExtendedField, A : NumericAlgebra { - override fun sin(arg: Expression): Expression = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) - override fun cos(arg: Expression): Expression = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) - override fun asin(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) - override fun acos(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) - override fun atan(arg: Expression): Expression = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) - - override fun power(arg: Expression, pow: Number): Expression = - binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) - - override fun exp(arg: Expression): Expression = unaryOperation(ExponentialOperations.EXP_OPERATION, arg) - override fun ln(arg: Expression): Expression = unaryOperation(ExponentialOperations.LN_OPERATION, arg) - - override fun unaryOperation(operation: String, arg: Expression): Expression = - super.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = - super.binaryOperation(operation, left, right) -} - -inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = - FunctionalExpressionSpace(this).block() - -inline fun > A.expressionInRing(block: FunctionalExpressionRing.() -> Expression): Expression = - FunctionalExpressionRing(this).block() - -inline fun > A.expressionInField(block: FunctionalExpressionField.() -> Expression): Expression = - FunctionalExpressionField(this).block() - -inline fun > A.expressionInExtendedField(block: FunctionalExpressionExtendedField.() -> Expression): Expression = - FunctionalExpressionExtendedField(this).block() diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt deleted file mode 100644 index 343b8287e..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ /dev/null @@ -1,118 +0,0 @@ -package scientifik.kmath.linear - -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.Ring -import scientifik.kmath.structures.* - -/** - * Basic implementation of Matrix space based on [NDStructure] - */ -class BufferMatrixContext>( - override val elementContext: R, - private val bufferFactory: BufferFactory -) : GenericMatrixContext { - - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix { - val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) } - return BufferMatrix(rows, columns, buffer) - } - - override fun point(size: Int, initializer: (Int) -> T): Point = bufferFactory(size, initializer) - - companion object -} - -@Suppress("OVERRIDE_BY_INLINE") -object RealMatrixContext : GenericMatrixContext { - - override val elementContext: RealField get() = RealField - - override inline fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix { - val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } - return BufferMatrix(rows, columns, buffer) - } - - override inline fun point(size: Int, initializer: (Int) -> Double): Point = RealBuffer(size, initializer) -} - -class BufferMatrix( - override val rowNum: Int, - override val colNum: Int, - val buffer: Buffer, - override val features: Set = emptySet() -) : FeaturedMatrix { - - init { - if (buffer.size != rowNum * colNum) { - error("Dimension mismatch for matrix structure") - } - } - - override val shape: IntArray get() = intArrayOf(rowNum, colNum) - - override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix = - BufferMatrix(rowNum, colNum, buffer, this.features + features) - - override operator fun get(index: IntArray): T = get(index[0], index[1]) - - override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] - - override fun elements(): Sequence> = sequence { - for (i in 0 until rowNum) for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) - } - - override fun equals(other: Any?): Boolean { - if (this === other) return true - return when (other) { - is NDStructure<*> -> return NDStructure.equals(this, other) - else -> false - } - } - - override fun hashCode(): Int { - var result = buffer.hashCode() - result = 31 * result + features.hashCode() - return result - } - - override fun toString(): String { - return if (rowNum <= 5 && colNum <= 5) { - "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + - rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer -> - buffer.asSequence().joinToString(separator = "\t") { it.toString() } - } - } else { - "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" - } - } -} - -/** - * Optimized dot product for real matrices - */ -infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { - require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } - - val array = DoubleArray(this.rowNum * other.colNum) - - //convert to array to insure there is not memory indirection - fun Buffer.unsafeArray(): DoubleArray = if (this is RealBuffer) { - array - } else { - DoubleArray(size) { get(it) } - } - - val a = this.buffer.unsafeArray() - val b = other.buffer.unsafeArray() - - for (i in (0 until rowNum)) { - for (j in (0 until other.colNum)) { - for (k in (0 until colNum)) { - array[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j] - } - } - } - - val buffer = RealBuffer(array) - return BufferMatrix(rowNum, other.colNum, buffer) -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt deleted file mode 100644 index 9b60bf719..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt +++ /dev/null @@ -1,88 +0,0 @@ -package scientifik.kmath.linear - -import scientifik.kmath.operations.Ring -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.Structure2D -import scientifik.kmath.structures.asBuffer -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract -import kotlin.math.sqrt - -/** - * A 2d structure plus optional matrix-specific features - */ -interface FeaturedMatrix : Matrix { - - override val shape: IntArray get() = intArrayOf(rowNum, colNum) - - val features: Set - - /** - * Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure. - * - * The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to - * add only those features that are valid. - */ - fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix - - companion object -} - -inline fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double): Matrix { - contract { callsInPlace(initializer) } - return MatrixContext.real.produce(rows, columns, initializer) -} - -/** - * Build a square matrix from given elements. - */ -fun Structure2D.Companion.square(vararg elements: T): FeaturedMatrix { - val size: Int = sqrt(elements.size.toDouble()).toInt() - require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } - val buffer = elements.asBuffer() - return BufferMatrix(size, size, buffer) -} - -val Matrix<*>.features: Set get() = (this as? FeaturedMatrix)?.features ?: emptySet() - -/** - * Check if matrix has the given feature class - */ -inline fun Matrix<*>.hasFeature(): Boolean = - features.find { it is T } != null - -/** - * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria - */ -inline fun Matrix<*>.getFeature(): T? = - features.filterIsInstance().firstOrNull() - -/** - * Diagonal matrix of ones. The matrix is virtual no actual matrix is created - */ -fun > GenericMatrixContext.one(rows: Int, columns: Int): FeaturedMatrix = - VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> - if (i == j) elementContext.one else elementContext.zero - } - - -/** - * A virtual matrix of zeroes - */ -fun > GenericMatrixContext.zero(rows: Int, columns: Int): FeaturedMatrix = - VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } - -class TransposedFeature(val original: Matrix) : MatrixFeature - -/** - * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` - */ -fun Matrix.transpose(): Matrix { - return this.getFeature>()?.original ?: VirtualMatrix( - this.colNum, - this.rowNum, - setOf(TransposedFeature(this)) - ) { i, j -> get(j, i) } -} - -infix fun Matrix.dot(other: Matrix): Matrix = with(MatrixContext.real) { dot(other) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgebra.kt deleted file mode 100644 index fb49d18ed..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgebra.kt +++ /dev/null @@ -1,28 +0,0 @@ -package scientifik.kmath.linear - -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.VirtualBuffer - -typealias Point = Buffer - -/** - * A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors - */ -interface LinearSolver { - fun solve(a: Matrix, b: Matrix): Matrix - fun solve(a: Matrix, b: Point): Point = solve(a, b.asMatrix()).asPoint() - fun inverse(a: Matrix): Matrix -} - -/** - * Convert matrix to vector if it is possible - */ -fun Matrix.asPoint(): Point = - if (this.colNum == 1) { - VirtualBuffer(rowNum) { get(it, 0) } - } else { - error("Can't convert matrix with more than one column to vector") - } - -fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(size, 1) { i, _ -> get(i) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt deleted file mode 100644 index 763bb1615..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt +++ /dev/null @@ -1,109 +0,0 @@ -package scientifik.kmath.linear - -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.SpaceOperations -import scientifik.kmath.operations.invoke -import scientifik.kmath.operations.sum -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.asSequence - -/** - * Basic operations on matrices. Operates on [Matrix] - */ -interface MatrixContext : SpaceOperations> { - /** - * Produce a matrix with this context and given dimensions - */ - fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix - - infix fun Matrix.dot(other: Matrix): Matrix - - infix fun Matrix.dot(vector: Point): Point - - operator fun Matrix.times(value: T): Matrix - - operator fun T.times(m: Matrix): Matrix = m * this - - companion object { - /** - * Non-boxing double matrix - */ - val real: RealMatrixContext = RealMatrixContext - - /** - * A structured matrix with custom buffer - */ - fun > buffered( - ring: R, - bufferFactory: BufferFactory = Buffer.Companion::boxing - ): GenericMatrixContext = BufferMatrixContext(ring, bufferFactory) - - /** - * Automatic buffered matrix, unboxed if it is possible - */ - inline fun > auto(ring: R): GenericMatrixContext = - buffered(ring, Buffer.Companion::auto) - } -} - -interface GenericMatrixContext> : MatrixContext { - /** - * The ring context for matrix elements - */ - val elementContext: R - - /** - * Produce a point compatible with matrix space - */ - fun point(size: Int, initializer: (Int) -> T): Point - - override infix fun Matrix.dot(other: Matrix): Matrix { - //TODO add typed error - require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } - - return produce(rowNum, other.colNum) { i, j -> - val row = rows[i] - val column = other.columns[j] - elementContext { sum(row.asSequence().zip(column.asSequence(), ::multiply)) } - } - } - - override infix fun Matrix.dot(vector: Point): Point { - //TODO add typed error - require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } - - return point(rowNum) { i -> - val row = rows[i] - elementContext { sum(row.asSequence().zip(vector.asSequence(), ::multiply)) } - } - } - - override operator fun Matrix.unaryMinus(): Matrix = - produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } } - - override fun add(a: Matrix, b: Matrix): Matrix { - require(a.rowNum == b.rowNum && a.colNum == b.colNum) { - "Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]" - } - - return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } } - } - - override operator fun Matrix.minus(b: Matrix): Matrix { - require(rowNum == b.rowNum && colNum == b.colNum) { - "Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]" - } - - return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } } - } - - override fun multiply(a: Matrix, k: Number): Matrix = - produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } - - operator fun Number.times(matrix: FeaturedMatrix): Matrix = matrix * this - - override operator fun Matrix.times(value: T): Matrix = - produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt deleted file mode 100644 index 87cfe21b0..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt +++ /dev/null @@ -1,62 +0,0 @@ -package scientifik.kmath.linear - -/** - * A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix - * operations performance in some cases. - */ -interface MatrixFeature - -/** - * The matrix with this feature is considered to have only diagonal non-null elements - */ -object DiagonalFeature : MatrixFeature - -/** - * Matrix with this feature has all zero elements - */ -object ZeroFeature : MatrixFeature - -/** - * Matrix with this feature have unit elements on diagonal and zero elements in all other places - */ -object UnitFeature : MatrixFeature - -/** - * Inverted matrix feature - */ -interface InverseMatrixFeature : MatrixFeature { - val inverse: FeaturedMatrix -} - -/** - * A determinant container - */ -interface DeterminantFeature : MatrixFeature { - val determinant: T -} - -@Suppress("FunctionName") -fun DeterminantFeature(determinant: T): DeterminantFeature = object : DeterminantFeature { - override val determinant: T = determinant -} - -/** - * Lower triangular matrix - */ -object LFeature : MatrixFeature - -/** - * Upper triangular feature - */ -object UFeature : MatrixFeature - -/** - * TODO add documentation - */ -interface LUPDecompositionFeature : MatrixFeature { - val l: FeaturedMatrix - val u: FeaturedMatrix - val p: FeaturedMatrix -} - -//TODO add sparse matrix feature diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt deleted file mode 100644 index 5266dc884..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ /dev/null @@ -1,59 +0,0 @@ -package scientifik.kmath.linear - -import scientifik.kmath.structures.Matrix - -class VirtualMatrix( - override val rowNum: Int, - override val colNum: Int, - override val features: Set = emptySet(), - val generator: (i: Int, j: Int) -> T -) : FeaturedMatrix { - - constructor(rowNum: Int, colNum: Int, vararg features: MatrixFeature, generator: (i: Int, j: Int) -> T) : this( - rowNum, - colNum, - setOf(*features), - generator - ) - - override val shape: IntArray get() = intArrayOf(rowNum, colNum) - - override operator fun get(i: Int, j: Int): T = generator(i, j) - - override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix = - VirtualMatrix(rowNum, colNum, this.features + features, generator) - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is FeaturedMatrix<*>) return false - - if (rowNum != other.rowNum) return false - if (colNum != other.colNum) return false - - return elements().all { (index, value) -> value == other[index] } - } - - override fun hashCode(): Int { - var result = rowNum - result = 31 * result + colNum - result = 31 * result + features.hashCode() - result = 31 * result + generator.hashCode() - return result - } - - - companion object { - /** - * Wrap a matrix adding additional features to it - */ - fun wrap(matrix: Matrix, vararg features: MatrixFeature): FeaturedMatrix { - return if (matrix is VirtualMatrix) { - VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator) - } else { - VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> - matrix[i, j] - } - } - } - } -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt deleted file mode 100644 index be222783e..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt +++ /dev/null @@ -1,211 +0,0 @@ -package scientifik.kmath.misc - -import scientifik.kmath.linear.Point -import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.invoke -import scientifik.kmath.operations.sum -import scientifik.kmath.structures.asBuffer -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract - -/* - * Implementation of backward-mode automatic differentiation. - * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d - */ - -/** - * Differentiable variable with value and derivative of differentiation ([deriv]) result - * with respect to this variable. - */ -open class Variable(val value: T) - -class DerivationResult( - value: T, - val deriv: Map, T>, - val context: Field -) : Variable(value) { - fun deriv(variable: Variable): T = deriv[variable] ?: context.zero - - /** - * compute divergence - */ - fun div(): T = context { sum(deriv.values) } - - /** - * Compute a gradient for variables in given order - */ - fun grad(vararg variables: Variable): Point { - check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" } - return variables.map(::deriv).asBuffer() - } -} - -/** - * Runs differentiation and establishes [AutoDiffField] context inside the block of code. - * - * The partial derivatives are placed in argument `d` variable - * - * Example: - * ``` - * val x = Variable(2) // define variable(s) and their values - * val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context - * assertEquals(17.0, y.x) // the value of result (y) - * assertEquals(9.0, x.d) // dy/dx - * ``` - */ -inline fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult { - contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } - - return (AutoDiffContext(this)) { - val result = body() - result.d = context.one // computing derivative w.r.t result - runBackwardPass() - DerivationResult(result.value, derivatives, this@deriv) - } -} - - -abstract class AutoDiffField> : Field> { - abstract val context: F - - /** - * A variable accessing inner state of derivatives. - * Use this function in inner builders to avoid creating additional derivative bindings - */ - abstract var Variable.d: T - - /** - * Performs update of derivative after the rest of the formula in the back-pass. - * - * For example, implementation of `sin` function is: - * - * ``` - * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result - * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function - * } - * ``` - */ - abstract fun derive(value: R, block: F.(R) -> Unit): R - - abstract fun variable(value: T): Variable - - inline fun variable(block: F.() -> T): Variable = variable(context.block()) - - // Overloads for Double constants - - override operator fun Number.plus(b: Variable): Variable = - derive(variable { this@plus.toDouble() * one + b.value }) { z -> - b.d += z.d - } - - override operator fun Variable.plus(b: Number): Variable = b.plus(this) - - override operator fun Number.minus(b: Variable): Variable = - derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } - - override operator fun Variable.minus(b: Number): Variable = - derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } -} - -/** - * Automatic Differentiation context class. - */ -@PublishedApi -internal class AutoDiffContext>(override val context: F) : AutoDiffField() { - // this stack contains pairs of blocks and values to apply them to - private var stack: Array = arrayOfNulls(8) - private var sp: Int = 0 - val derivatives: MutableMap, T> = hashMapOf() - override val zero: Variable get() = Variable(context.zero) - override val one: Variable get() = Variable(context.one) - - /** - * A variable coupled with its derivative. For internal use only - */ - private class VariableWithDeriv(x: T, var d: T) : Variable(x) - - override fun variable(value: T): Variable = - VariableWithDeriv(value, context.zero) - - override var Variable.d: T - get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero - set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value - - @Suppress("UNCHECKED_CAST") - override fun derive(value: R, block: F.(R) -> Unit): R { - // save block to stack for backward pass - if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) - stack[sp++] = block - stack[sp++] = value - return value - } - - @Suppress("UNCHECKED_CAST") - fun runBackwardPass() { - while (sp > 0) { - val value = stack[--sp] - val block = stack[--sp] as F.(Any?) -> Unit - context.block(value) - } - } - - // Basic math (+, -, *, /) - - - override fun add(a: Variable, b: Variable): Variable = derive(variable { a.value + b.value }) { z -> - a.d += z.d - b.d += z.d - } - - override fun multiply(a: Variable, b: Variable): Variable = derive(variable { a.value * b.value }) { z -> - a.d += z.d * b.value - b.d += z.d * a.value - } - - override fun divide(a: Variable, b: Variable): Variable = derive(variable { a.value / b.value }) { z -> - a.d += z.d / b.value - b.d -= z.d * a.value / (b.value * b.value) - } - - override fun multiply(a: Variable, k: Number): Variable = derive(variable { k.toDouble() * a.value }) { z -> - a.d += z.d * k.toDouble() - } -} - -// Extensions for differentiation of various basic mathematical functions - -// x ^ 2 -fun > AutoDiffField.sqr(x: Variable): Variable = - derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } - -// x ^ 1/2 -fun > AutoDiffField.sqrt(x: Variable): Variable = - derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } - -// x ^ y (const) -fun > AutoDiffField.pow(x: Variable, y: Double): Variable = - derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } - -fun > AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) - -// exp(x) -fun > AutoDiffField.exp(x: Variable): Variable = - derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value } - -// ln(x) -fun > AutoDiffField.ln(x: Variable): Variable = - derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value } - -// x ^ y (any) -fun > AutoDiffField.pow(x: Variable, y: Variable): Variable = - exp(y * ln(x)) - -// sin(x) -fun > AutoDiffField.sin(x: Variable): Variable = - derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } - -// cos(x) -fun > AutoDiffField.cos(x: Variable): Variable = - derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt deleted file mode 100644 index e11adc135..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/cumulative.kt +++ /dev/null @@ -1,80 +0,0 @@ -package scientifik.kmath.misc - -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract -import kotlin.jvm.JvmName - -/** - * Generic cumulative operation on iterator. - * - * @param T the type of initial iterable. - * @param R the type of resulting iterable. - * @param initial lazy evaluated. - */ -inline fun Iterator.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterator { - contract { callsInPlace(operation) } - - return object : Iterator { - var state: R = initial - - override fun hasNext(): Boolean = this@cumulative.hasNext() - - override fun next(): R { - state = operation(state, this@cumulative.next()) - return state - } - } -} - -inline fun Iterable.cumulative(initial: R, crossinline operation: (R, T) -> R): Iterable = - Iterable { this@cumulative.iterator().cumulative(initial, operation) } - -inline fun Sequence.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence = Sequence { - this@cumulative.iterator().cumulative(initial, operation) -} - -fun List.cumulative(initial: R, operation: (R, T) -> R): List = - iterator().cumulative(initial, operation).asSequence().toList() - -//Cumulative sum - -/** - * Cumulative sum with custom space - */ -fun Iterable.cumulativeSum(space: Space): Iterable = - space { cumulative(zero) { element: T, sum: T -> sum + element } } - -@JvmName("cumulativeSumOfDouble") -fun Iterable.cumulativeSum(): Iterable = cumulative(0.0) { element, sum -> sum + element } - -@JvmName("cumulativeSumOfInt") -fun Iterable.cumulativeSum(): Iterable = cumulative(0) { element, sum -> sum + element } - -@JvmName("cumulativeSumOfLong") -fun Iterable.cumulativeSum(): Iterable = cumulative(0L) { element, sum -> sum + element } - -fun Sequence.cumulativeSum(space: Space): Sequence = - space { cumulative(zero) { element: T, sum: T -> sum + element } } - -@JvmName("cumulativeSumOfDouble") -fun Sequence.cumulativeSum(): Sequence = cumulative(0.0) { element, sum -> sum + element } - -@JvmName("cumulativeSumOfInt") -fun Sequence.cumulativeSum(): Sequence = cumulative(0) { element, sum -> sum + element } - -@JvmName("cumulativeSumOfLong") -fun Sequence.cumulativeSum(): Sequence = cumulative(0L) { element, sum -> sum + element } - -fun List.cumulativeSum(space: Space): List = - space { cumulative(zero) { element: T, sum: T -> sum + element } } - -@JvmName("cumulativeSumOfDouble") -fun List.cumulativeSum(): List = cumulative(0.0) { element, sum -> sum + element } - -@JvmName("cumulativeSumOfInt") -fun List.cumulativeSum(): List = cumulative(0) { element, sum -> sum + element } - -@JvmName("cumulativeSumOfLong") -fun List.cumulativeSum(): List = cumulative(0L) { element, sum -> sum + element } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt deleted file mode 100644 index f18bde597..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ /dev/null @@ -1,340 +0,0 @@ -package scientifik.kmath.operations - -/** - * Stub for DSL the [Algebra] is. - */ -@DslMarker -annotation class KMathContext - -/** - * Represents an algebraic structure. - * - * @param T the type of element of this structure. - */ -interface Algebra { - /** - * Wrap raw string or variable - */ - fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") - - /** - * Dynamic call of unary operation with name [operation] on [arg] - */ - fun unaryOperation(operation: String, arg: T): T - - /** - * Dynamic call of binary operation [operation] on [left] and [right] - */ - fun binaryOperation(operation: String, left: T, right: T): T -} - -/** - * An algebraic structure where elements can have numeric representation. - * - * @param T the type of element of this structure. - */ -interface NumericAlgebra : Algebra { - /** - * Wraps a number. - */ - fun number(value: Number): T - - /** - * Dynamic call of binary operation [operation] on [left] and [right] where left element is [Number]. - */ - fun leftSideNumberOperation(operation: String, left: Number, right: T): T = - binaryOperation(operation, number(left), right) - - /** - * Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. - */ - fun rightSideNumberOperation(operation: String, left: T, right: Number): T = - leftSideNumberOperation(operation, right, left) -} - -/** - * Call a block with an [Algebra] as receiver. - */ -inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) - -/** - * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as - * multiplication by scalars. - * - * @param T the type of element of this semispace. - */ -interface SpaceOperations : Algebra { - /** - * Addition of two elements. - * - * @param a the addend. - * @param b the augend. - * @return the sum. - */ - fun add(a: T, b: T): T - - /** - * Multiplication of element by scalar. - * - * @param a the multiplier. - * @param k the multiplicand. - * @return the produce. - */ - fun multiply(a: T, k: Number): T - - // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176 - - /** - * The negation of this element. - * - * @receiver this value. - * @return the additive inverse of this value. - */ - operator fun T.unaryMinus(): T = multiply(this, -1.0) - - /** - * Returns this value. - * - * @receiver this value. - * @return this value. - */ - operator fun T.unaryPlus(): T = this - - /** - * Addition of two elements. - * - * @receiver the addend. - * @param b the augend. - * @return the sum. - */ - operator fun T.plus(b: T): T = add(this, b) - - /** - * Subtraction of two elements. - * - * @receiver the minuend. - * @param b the subtrahend. - * @return the difference. - */ - operator fun T.minus(b: T): T = add(this, -b) - - /** - * Multiplication of this element by a scalar. - * - * @receiver the multiplier. - * @param k the multiplicand. - * @return the product. - */ - operator fun T.times(k: Number): T = multiply(this, k.toDouble()) - - /** - * Division of this element by scalar. - * - * @receiver the dividend. - * @param k the divisor. - * @return the quotient. - */ - operator fun T.div(k: Number): T = multiply(this, 1.0 / k.toDouble()) - - /** - * Multiplication of this number by element. - * - * @receiver the multiplier. - * @param b the multiplicand. - * @return the product. - */ - operator fun Number.times(b: T): T = b * this - - override fun unaryOperation(operation: String, arg: T): T = when (operation) { - PLUS_OPERATION -> arg - MINUS_OPERATION -> -arg - else -> error("Unary operation $operation not defined in $this") - } - - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - PLUS_OPERATION -> add(left, right) - MINUS_OPERATION -> left - right - else -> error("Binary operation $operation not defined in $this") - } - - companion object { - /** - * The identifier of addition. - */ - const val PLUS_OPERATION: String = "+" - - /** - * The identifier of subtraction (and negation). - */ - const val MINUS_OPERATION: String = "-" - - const val NOT_OPERATION: String = "!" - } -} - -/** - * Represents linear space, i.e. algebraic structure with associative binary operation called "addition" and its neutral - * element as well as multiplication by scalars. - * - * @param T the type of element of this group. - */ -interface Space : SpaceOperations { - /** - * The neutral element of addition. - */ - val zero: T -} - -/** - * Represents semiring, i.e. algebraic structure with two associative binary operations called "addition" and - * "multiplication". - * - * @param T the type of element of this semiring. - */ -interface RingOperations : SpaceOperations { - /** - * Multiplies two elements. - * - * @param a the multiplier. - * @param b the multiplicand. - */ - fun multiply(a: T, b: T): T - - /** - * Multiplies this element by scalar. - * - * @receiver the multiplier. - * @param b the multiplicand. - */ - operator fun T.times(b: T): T = multiply(this, b) - - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - TIMES_OPERATION -> multiply(left, right) - else -> super.binaryOperation(operation, left, right) - } - - companion object { - /** - * The identifier of multiplication. - */ - const val TIMES_OPERATION: String = "*" - } -} - -/** - * Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and - * "multiplication" and their neutral elements. - * - * @param T the type of element of this ring. - */ -interface Ring : Space, RingOperations, NumericAlgebra { - /** - * neutral operation for multiplication - */ - val one: T - - override fun number(value: Number): T = one * value.toDouble() - - override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.leftSideNumberOperation(operation, left, right) - } - - override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - SpaceOperations.PLUS_OPERATION -> left + right - SpaceOperations.MINUS_OPERATION -> left - right - RingOperations.TIMES_OPERATION -> left * right - else -> super.rightSideNumberOperation(operation, left, right) - } - - /** - * Addition of element and scalar. - * - * @receiver the addend. - * @param b the augend. - */ - operator fun T.plus(b: Number): T = this + number(b) - - /** - * Addition of scalar and element. - * - * @receiver the addend. - * @param b the augend. - */ - operator fun Number.plus(b: T): T = b + this - - /** - * Subtraction of element from number. - * - * @receiver the minuend. - * @param b the subtrahend. - * @receiver the difference. - */ - operator fun T.minus(b: Number): T = this - number(b) - - /** - * Subtraction of number from element. - * - * @receiver the minuend. - * @param b the subtrahend. - * @receiver the difference. - */ - operator fun Number.minus(b: T): T = -b + this -} - -/** - * Represents semifield, i.e. algebraic structure with three operations: associative "addition" and "multiplication", - * and "division". - * - * @param T the type of element of this semifield. - */ -interface FieldOperations : RingOperations { - /** - * Division of two elements. - * - * @param a the dividend. - * @param b the divisor. - * @return the quotient. - */ - fun divide(a: T, b: T): T - - /** - * Division of two elements. - * - * @receiver the dividend. - * @param b the divisor. - * @return the quotient. - */ - operator fun T.div(b: T): T = divide(this, b) - - override fun binaryOperation(operation: String, left: T, right: T): T = when (operation) { - DIV_OPERATION -> divide(left, right) - else -> super.binaryOperation(operation, left, right) - } - - companion object { - /** - * The identifier of division. - */ - const val DIV_OPERATION: String = "/" - } -} - -/** - * Represents field, i.e. algebraic structure with three operations: associative "addition" and "multiplication", - * and "division" and their neutral elements. - * - * @param T the type of element of this semifield. - */ -interface Field : Ring, FieldOperations { - /** - * Division of element by scalar. - * - * @receiver the dividend. - * @param b the divisor. - * @return the quotient. - */ - operator fun Number.div(b: T): T = this * divide(one, b) -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt deleted file mode 100644 index 0735a96da..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ /dev/null @@ -1,270 +0,0 @@ -package scientifik.kmath.operations - -import scientifik.kmath.operations.RealField.pow -import kotlin.math.abs -import kotlin.math.pow as kpow - -/** - * Advanced Number-like semifield that implements basic operations. - */ -interface ExtendedFieldOperations : - FieldOperations, - TrigonometricOperations, - HyperbolicOperations, - PowerOperations, - ExponentialOperations { - - override fun tan(arg: T): T = sin(arg) / cos(arg) - override fun tanh(arg: T): T = sinh(arg) / cosh(arg) - - override fun unaryOperation(operation: String, arg: T): T = when (operation) { - TrigonometricOperations.COS_OPERATION -> cos(arg) - TrigonometricOperations.SIN_OPERATION -> sin(arg) - TrigonometricOperations.TAN_OPERATION -> tan(arg) - TrigonometricOperations.ACOS_OPERATION -> acos(arg) - TrigonometricOperations.ASIN_OPERATION -> asin(arg) - TrigonometricOperations.ATAN_OPERATION -> atan(arg) - HyperbolicOperations.COSH_OPERATION -> cosh(arg) - HyperbolicOperations.SINH_OPERATION -> sinh(arg) - HyperbolicOperations.TANH_OPERATION -> tanh(arg) - HyperbolicOperations.ACOSH_OPERATION -> acosh(arg) - HyperbolicOperations.ASINH_OPERATION -> asinh(arg) - HyperbolicOperations.ATANH_OPERATION -> atanh(arg) - PowerOperations.SQRT_OPERATION -> sqrt(arg) - ExponentialOperations.EXP_OPERATION -> exp(arg) - ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) - } -} - - -/** - * Advanced Number-like field that implements basic operations. - */ -interface ExtendedField : ExtendedFieldOperations, Field { - override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2 - override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2 - override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) - override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg) - override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) - override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 - - override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { - PowerOperations.POW_OPERATION -> power(left, right) - else -> super.rightSideNumberOperation(operation, left, right) - } -} - -/** - * Real field element wrapping double. - * - * @property value the [Double] value wrapped by this [Real]. - * - * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 - */ -inline class Real(val value: Double) : FieldElement { - override val context: RealField - get() = RealField - - override fun unwrap(): Double = value - - override fun Double.wrap(): Real = Real(value) - - companion object -} - -/** - * A field for [Double] without boxing. Does not produce appropriate field element. - */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -object RealField : ExtendedField, Norm { - override val zero: Double - get() = 0.0 - - override val one: Double - get() = 1.0 - - override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) - } - - override inline fun add(a: Double, b: Double): Double = a + b - override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() - - override inline fun multiply(a: Double, b: Double): Double = a * b - - override inline fun divide(a: Double, b: Double): Double = a / b - - override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) - override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) - override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) - override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) - override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) - override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) - - override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg) - override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg) - override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg) - override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg) - override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) - override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) - - override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) - override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) - override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) - - override inline fun norm(arg: Double): Double = abs(arg) - - override inline fun Double.unaryMinus(): Double = -this - override inline fun Double.plus(b: Double): Double = this + b - override inline fun Double.minus(b: Double): Double = this - b - override inline fun Double.times(b: Double): Double = this * b - override inline fun Double.div(b: Double): Double = this / b -} - -/** - * A field for [Float] without boxing. Does not produce appropriate field element. - */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -object FloatField : ExtendedField, Norm { - override val zero: Float - get() = 0.0f - - override val one: Float - get() = 1.0f - - override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { - PowerOperations.POW_OPERATION -> left pow right - else -> super.binaryOperation(operation, left, right) - } - - override inline fun add(a: Float, b: Float): Float = a + b - override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() - - override inline fun multiply(a: Float, b: Float): Float = a * b - - override inline fun divide(a: Float, b: Float): Float = a / b - - override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) - override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) - override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) - override inline fun acos(arg: Float): Float = kotlin.math.acos(arg) - override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) - override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) - - override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg) - override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg) - override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg) - override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg) - override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg) - override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg) - - override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) - override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) - override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) - - override inline fun norm(arg: Float): Float = abs(arg) - - override inline fun Float.unaryMinus(): Float = -this - override inline fun Float.plus(b: Float): Float = this + b - override inline fun Float.minus(b: Float): Float = this - b - override inline fun Float.times(b: Float): Float = this * b - override inline fun Float.div(b: Float): Float = this / b -} - -/** - * A field for [Int] without boxing. Does not produce corresponding ring element. - */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -object IntRing : Ring, Norm { - override val zero: Int - get() = 0 - - override val one: Int - get() = 1 - - override inline fun add(a: Int, b: Int): Int = a + b - override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a - - override inline fun multiply(a: Int, b: Int): Int = a * b - - override inline fun norm(arg: Int): Int = abs(arg) - - override inline fun Int.unaryMinus(): Int = -this - override inline fun Int.plus(b: Int): Int = this + b - override inline fun Int.minus(b: Int): Int = this - b - override inline fun Int.times(b: Int): Int = this * b -} - -/** - * A field for [Short] without boxing. Does not produce appropriate ring element. - */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -object ShortRing : Ring, Norm { - override val zero: Short - get() = 0 - - override val one: Short - get() = 1 - - override inline fun add(a: Short, b: Short): Short = (a + b).toShort() - override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() - - override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() - - override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() - - override inline fun Short.unaryMinus(): Short = (-this).toShort() - override inline fun Short.plus(b: Short): Short = (this + b).toShort() - override inline fun Short.minus(b: Short): Short = (this - b).toShort() - override inline fun Short.times(b: Short): Short = (this * b).toShort() -} - -/** - * A field for [Byte] without boxing. Does not produce appropriate ring element. - */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -object ByteRing : Ring, Norm { - override val zero: Byte - get() = 0 - - override val one: Byte - get() = 1 - - override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() - override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() - - override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() - - override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() - - override inline fun Byte.unaryMinus(): Byte = (-this).toByte() - override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() - override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() - override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() -} - -/** - * A field for [Double] without boxing. Does not produce appropriate ring element. - */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -object LongRing : Ring, Norm { - override val zero: Long - get() = 0 - - override val one: Long - get() = 1 - - override inline fun add(a: Long, b: Long): Long = a + b - override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() - - override inline fun multiply(a: Long, b: Long): Long = a * b - - override fun norm(arg: Long): Long = abs(arg) - - override inline fun Long.unaryMinus(): Long = (-this) - override inline fun Long.plus(b: Long): Long = (this + b) - override inline fun Long.minus(b: Long): Long = (this - b) - override inline fun Long.times(b: Long): Long = (this * b) -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt deleted file mode 100644 index 2c3d69094..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferAccessor2D.kt +++ /dev/null @@ -1,43 +0,0 @@ -package scientifik.kmath.structures - -import kotlin.reflect.KClass - -/** - * A context that allows to operate on a [MutableBuffer] as on 2d array - */ -class BufferAccessor2D(val type: KClass, val rowNum: Int, val colNum: Int) { - operator fun Buffer.get(i: Int, j: Int): T = get(i + colNum * j) - - operator fun MutableBuffer.set(i: Int, j: Int, value: T) { - set(i + colNum * j, value) - } - - inline fun create(init: (i: Int, j: Int) -> T): MutableBuffer = - MutableBuffer.auto(type, rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } - - fun create(mat: Structure2D): MutableBuffer = create { i, j -> mat[i, j] } - - //TODO optimize wrapper - fun MutableBuffer.collect(): Structure2D = - NDStructure.auto(type, rowNum, colNum) { (i, j) -> get(i, j) }.as2D() - - - inner class Row(val buffer: MutableBuffer, val rowIndex: Int) : MutableBuffer { - override val size: Int get() = colNum - - override operator fun get(index: Int): T = buffer[rowIndex, index] - - override operator fun set(index: Int, value: T) { - buffer[rowIndex, index] = value - } - - override fun copy(): MutableBuffer = MutableBuffer.auto(type, colNum) { get(it) } - override operator fun iterator(): Iterator = (0 until colNum).map(::get).iterator() - - } - - /** - * Get row - */ - fun MutableBuffer.row(i: Int): Row = Row(this, i) -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt deleted file mode 100644 index 2c0c2021f..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferedNDAlgebra.kt +++ /dev/null @@ -1,43 +0,0 @@ -package scientifik.kmath.structures - -import scientifik.kmath.operations.* - -interface BufferedNDAlgebra : NDAlgebra> { - val strides: Strides - - override fun check(vararg elements: NDBuffer): Unit = - require(elements.all { it.strides == strides }) { ("Strides mismatch") } - - /** - * Convert any [NDStructure] to buffered structure using strides from this context. - * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over - * indices. - * - * If the argument is [NDBuffer] with different strides structure, the new element will be produced. - */ - fun NDStructure.toBuffer(): NDBuffer { - return if (this is NDBuffer && this.strides == this@BufferedNDAlgebra.strides) { - this - } else { - produce { index -> get(index) } - } - } - - /** - * Convert a buffer to element of this algebra - */ - fun NDBuffer.toElement(): MathElement> -} - - -interface BufferedNDSpace> : NDSpace>, BufferedNDAlgebra { - override fun NDBuffer.toElement(): SpaceElement, *, out BufferedNDSpace> -} - -interface BufferedNDRing> : NDRing>, BufferedNDSpace { - override fun NDBuffer.toElement(): RingElement, *, out BufferedNDRing> -} - -interface BufferedNDField> : NDField>, BufferedNDRing { - override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt deleted file mode 100644 index 4afaa63ab..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ /dev/null @@ -1,271 +0,0 @@ -package scientifik.kmath.structures - -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.complex -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract -import kotlin.reflect.KClass - -/** - * Function that produces [Buffer] from its size and function that supplies values. - * - * @param T the type of buffer. - */ -typealias BufferFactory = (Int, (Int) -> T) -> Buffer - -/** - * Function that produces [MutableBuffer] from its size and function that supplies values. - * - * @param T the type of buffer. - */ -typealias MutableBufferFactory = (Int, (Int) -> T) -> MutableBuffer - -/** - * A generic immutable random-access structure for both primitives and objects. - * - * @param T the type of elements contained in the buffer. - */ -interface Buffer { - /** - * The size of this buffer. - */ - val size: Int - - /** - * Gets element at given index. - */ - operator fun get(index: Int): T - - /** - * Iterates over all elements. - */ - operator fun iterator(): Iterator - - /** - * Checks content equality with another buffer. - */ - fun contentEquals(other: Buffer<*>): Boolean = - asSequence().mapIndexed { index, value -> value == other[index] }.all { it } - - companion object { - inline fun real(size: Int, initializer: (Int) -> Double): RealBuffer { - val array = DoubleArray(size) { initializer(it) } - return RealBuffer(array) - } - - /** - * Create a boxing buffer of given type - */ - inline fun boxing(size: Int, initializer: (Int) -> T): Buffer = ListBuffer(List(size, initializer)) - - @Suppress("UNCHECKED_CAST") - inline fun auto(type: KClass, size: Int, crossinline initializer: (Int) -> T): Buffer { - //TODO add resolution based on Annotation or companion resolution - return when (type) { - Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer - Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as Buffer - Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer - Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer - Complex::class -> complex(size) { initializer(it) as Complex } as Buffer - else -> boxing(size, initializer) - } - } - - /** - * Create most appropriate immutable buffer for given type avoiding boxing wherever possible - */ - @Suppress("UNCHECKED_CAST") - inline fun auto(size: Int, crossinline initializer: (Int) -> T): Buffer = - auto(T::class, size, initializer) - } -} - -/** - * Creates a sequence that returns all elements from this [Buffer]. - */ -fun Buffer.asSequence(): Sequence = Sequence(::iterator) - -/** - * Creates an iterable that returns all elements from this [Buffer]. - */ -fun Buffer.asIterable(): Iterable = Iterable(::iterator) - -/** - * Returns an [IntRange] of the valid indices for this [Buffer]. - */ -val Buffer<*>.indices: IntRange get() = 0 until size - -/** - * A generic mutable random-access structure for both primitives and objects. - * - * @param T the type of elements contained in the buffer. - */ -interface MutableBuffer : Buffer { - /** - * Sets the array element at the specified [index] to the specified [value]. - */ - operator fun set(index: Int, value: T) - - /** - * Returns a shallow copy of the buffer. - */ - fun copy(): MutableBuffer - - companion object { - /** - * Create a boxing mutable buffer of given type - */ - inline fun boxing(size: Int, initializer: (Int) -> T): MutableBuffer = - MutableListBuffer(MutableList(size, initializer)) - - @Suppress("UNCHECKED_CAST") - inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer = - when (type) { - Double::class -> RealBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer - Short::class -> ShortBuffer(ShortArray(size) { initializer(it) as Short }) as MutableBuffer - Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer - Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer - else -> boxing(size, initializer) - } - - /** - * Create most appropriate mutable buffer for given type avoiding boxing wherever possible - */ - @Suppress("UNCHECKED_CAST") - inline fun auto(size: Int, initializer: (Int) -> T): MutableBuffer = - auto(T::class, size, initializer) - - val real: MutableBufferFactory = { size: Int, initializer: (Int) -> Double -> - RealBuffer(DoubleArray(size) { initializer(it) }) - } - } -} - -/** - * [Buffer] implementation over [List]. - * - * @param T the type of elements contained in the buffer. - * @property list The underlying list. - */ -inline class ListBuffer(val list: List) : Buffer { - override val size: Int - get() = list.size - - override operator fun get(index: Int): T = list[index] - override operator fun iterator(): Iterator = list.iterator() -} - -/** - * Returns an [ListBuffer] that wraps the original list. - */ -fun List.asBuffer(): ListBuffer = ListBuffer(this) - -/** - * Creates a new [ListBuffer] with the specified [size], where each element is calculated by calling the specified - * [init] function. - * - * The function [init] is called for each array element sequentially starting from the first one. - * It should return the value for an array element given its index. - */ -inline fun ListBuffer(size: Int, init: (Int) -> T): ListBuffer { - contract { callsInPlace(init) } - return List(size, init).asBuffer() -} - -/** - * [MutableBuffer] implementation over [MutableList]. - * - * @param T the type of elements contained in the buffer. - * @property list The underlying list. - */ -inline class MutableListBuffer(val list: MutableList) : MutableBuffer { - override val size: Int - get() = list.size - - override operator fun get(index: Int): T = list[index] - - override operator fun set(index: Int, value: T) { - list[index] = value - } - - override operator fun iterator(): Iterator = list.iterator() - override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) -} - -/** - * [MutableBuffer] implementation over [Array]. - * - * @param T the type of elements contained in the buffer. - * @property array The underlying array. - */ -class ArrayBuffer(private val array: Array) : MutableBuffer { - // Can't inline because array is invariant - override val size: Int - get() = array.size - - override operator fun get(index: Int): T = array[index] - - override operator fun set(index: Int, value: T) { - array[index] = value - } - - override operator fun iterator(): Iterator = array.iterator() - override fun copy(): MutableBuffer = ArrayBuffer(array.copyOf()) -} - -/** - * Returns an [ArrayBuffer] that wraps the original array. - */ -fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) - -/** - * Immutable wrapper for [MutableBuffer]. - * - * @param T the type of elements contained in the buffer. - * @property buffer The underlying buffer. - */ -inline class ReadOnlyBuffer(val buffer: MutableBuffer) : Buffer { - override val size: Int get() = buffer.size - - override operator fun get(index: Int): T = buffer[index] - - override operator fun iterator(): Iterator = buffer.iterator() -} - -/** - * A buffer with content calculated on-demand. The calculated content is not stored, so it is recalculated on each call. - * Useful when one needs single element from the buffer. - * - * @param T the type of elements provided by the buffer. - */ -class VirtualBuffer(override val size: Int, private val generator: (Int) -> T) : Buffer { - override operator fun get(index: Int): T { - if (index < 0 || index >= size) throw IndexOutOfBoundsException("Expected index from 0 to ${size - 1}, but found $index") - return generator(index) - } - - override operator fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() - - override fun contentEquals(other: Buffer<*>): Boolean { - return if (other is VirtualBuffer) { - this.size == other.size && this.generator == other.generator - } else { - super.contentEquals(other) - } - } -} - -/** - * Convert this buffer to read-only buffer. - */ -fun Buffer.asReadOnly(): Buffer = if (this is MutableBuffer) ReadOnlyBuffer(this) else this - -/** - * Typealias for buffer transformations. - */ -typealias BufferTransform = (Buffer) -> Buffer - -/** - * Typealias for buffer transformations with suspend function. - */ -typealias SuspendBufferTransform = suspend (Buffer) -> Buffer diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt deleted file mode 100644 index f09db3c72..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt +++ /dev/null @@ -1,155 +0,0 @@ -package scientifik.kmath.structures - -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space - - -/** - * An exception is thrown when the expected ans actual shape of NDArray differs - */ -class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : RuntimeException() - - -/** - * The base interface for all nd-algebra implementations - * @param T the type of nd-structure element - * @param C the type of the element context - * @param N the type of the structure - */ -interface NDAlgebra> { - val shape: IntArray - val elementContext: C - - /** - * Produce a new [N] structure using given initializer function - */ - fun produce(initializer: C.(IntArray) -> T): N - - /** - * Map elements from one structure to another one - */ - fun map(arg: N, transform: C.(T) -> T): N - - /** - * Map indexed elements - */ - fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N - - /** - * Combine two structures into one - */ - fun combine(a: N, b: N, transform: C.(T, T) -> T): N - - /** - * Check if given elements are consistent with this context - */ - fun check(vararg elements: N) { - elements.forEach { - if (!shape.contentEquals(it.shape)) { - throw ShapeMismatchException(shape, it.shape) - } - } - } - - /** - * element-by-element invoke a function working on [T] on a [NDStructure] - */ - operator fun Function1.invoke(structure: N): N = map(structure) { value -> this@invoke(value) } - - companion object -} - -/** - * An nd-space over element space - */ -interface NDSpace, N : NDStructure> : Space, NDAlgebra { - /** - * Element-by-element addition - */ - override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) } - - /** - * Multiply all elements by constant - */ - override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) } - - //TODO move to extensions after KEEP-176 - operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) } - - operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) } - - operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) } - operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) } - - companion object -} - -/** - * An nd-ring over element ring - */ -interface NDRing, N : NDStructure> : Ring, NDSpace { - - /** - * Element-by-element multiplication - */ - override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } - - //TODO move to extensions after KEEP-176 - operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) } - - operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) } - - companion object -} - -/** - * Field of [NDStructure]. - * - * @param T the type of the element contained in ND structure. - * @param N the type of ND structure. - * @param F field of structure elements. - */ -interface NDField, N : NDStructure> : Field, NDRing { - - /** - * Element-by-element division - */ - override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) } - - //TODO move to extensions after KEEP-176 - operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) } - - operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) } - - companion object { - - private val realNDFieldCache = HashMap() - - /** - * Create a nd-field for [Double] values or pull it from cache if it was created previously - */ - fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } - - /** - * Create a nd-field with boxing generic buffer - */ - fun > boxing( - field: F, - vararg shape: Int, - bufferFactory: BufferFactory = Buffer.Companion::boxing - ): BoxingNDField = BoxingNDField(shape, field, bufferFactory) - - /** - * Create a most suitable implementation for nd-field using reified class. - */ - @Suppress("UNCHECKED_CAST") - inline fun > auto(field: F, vararg shape: Int): BufferedNDField = - when { - T::class == Double::class -> real(*shape) as BufferedNDField - T::class == Complex::class -> complex(*shape) as BufferedNDField - else -> BoxingNDField(shape, field, Buffer.Companion::auto) - } - } -} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt deleted file mode 100644 index 9aa674177..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortBuffer.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.structures - -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract - -/** - * Specialized [MutableBuffer] implementation over [ShortArray]. - * - * @property array the underlying array. - */ -inline class ShortBuffer(val array: ShortArray) : MutableBuffer { - override val size: Int get() = array.size - - override operator fun get(index: Int): Short = array[index] - - override operator fun set(index: Int, value: Short) { - array[index] = value - } - - override operator fun iterator(): ShortIterator = array.iterator() - - override fun copy(): MutableBuffer = - ShortBuffer(array.copyOf()) -} - -/** - * Creates a new [ShortBuffer] with the specified [size], where each element is calculated by calling the specified - * [init] function. - * - * The function [init] is called for each array element sequentially starting from the first one. - * It should return the value for an buffer element given its index. - */ -inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer { - contract { callsInPlace(init) } - return ShortBuffer(ShortArray(size) { init(it) }) -} - -/** - * Returns a new [ShortBuffer] of given elements. - */ -fun ShortBuffer(vararg shorts: Short): ShortBuffer = ShortBuffer(shorts) - -/** - * Returns a [ShortArray] containing all of the elements of this [MutableBuffer]. - */ -val MutableBuffer.array: ShortArray - get() = (if (this is ShortBuffer) array else ShortArray(size) { get(it) }) - -/** - * Returns [ShortBuffer] over this array. - * - * @receiver the array. - * @return the new buffer. - */ -fun ShortArray.asBuffer(): ShortBuffer = ShortBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt deleted file mode 100644 index eeb6bd3dc..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt +++ /dev/null @@ -1,58 +0,0 @@ -package scientifik.kmath.structures - -/** - * A structure that is guaranteed to be two-dimensional - */ -interface Structure2D : NDStructure { - val rowNum: Int get() = shape[0] - val colNum: Int get() = shape[1] - - operator fun get(i: Int, j: Int): T - - override operator fun get(index: IntArray): T { - require(index.size == 2) { "Index dimension mismatch. Expected 2 but found ${index.size}" } - return get(index[0], index[1]) - } - - val rows: Buffer> - get() = VirtualBuffer(rowNum) { i -> - VirtualBuffer(colNum) { j -> get(i, j) } - } - - val columns: Buffer> - get() = VirtualBuffer(colNum) { j -> - VirtualBuffer(rowNum) { i -> get(i, j) } - } - - override fun elements(): Sequence> = sequence { - for (i in (0 until rowNum)) { - for (j in (0 until colNum)) { - yield(intArrayOf(i, j) to get(i, j)) - } - } - } - - companion object -} - -/** - * A 2D wrapper for nd-structure - */ -private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { - override val shape: IntArray get() = structure.shape - - override operator fun get(i: Int, j: Int): T = structure[i, j] - - override fun elements(): Sequence> = structure.elements() -} - -/** - * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch - */ -fun NDStructure.as2D(): Structure2D = if (shape.size == 2) { - Structure2DWrapper(this) -} else { - error("Can't create 2d-structure from ${shape.size}d-structure") -} - -typealias Matrix = Structure2D diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/ExpressionFieldTest.kt similarity index 58% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/ExpressionFieldTest.kt index 485de08b4..484993eef 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/ExpressionFieldTest.kt @@ -1,24 +1,26 @@ -package scientifik.kmath.expressions +package kscience.kmath.expressions -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.ComplexField -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.invoke +import kscience.kmath.operations.Complex +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.RealField +import kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFails class ExpressionFieldTest { + val x by symbol @Test fun testExpression() { val context = FunctionalExpressionField(RealField) val expression = context { - val x = variable("x", 2.0) + val x by binding() x * x + 2 * x + one } - assertEquals(expression("x" to 1.0), 4.0) - assertEquals(expression(), 9.0) + assertEquals(expression(x to 1.0), 4.0) + assertFails { expression()} } @Test @@ -26,33 +28,33 @@ class ExpressionFieldTest { val context = FunctionalExpressionField(ComplexField) val expression = context { - val x = variable("x", Complex(2.0, 0.0)) + val x = bind(x) x * x + 2 * x + one } - assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) - assertEquals(expression(), Complex(9.0, 0.0)) + assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0)) + //assertEquals(expression(), Complex(9.0, 0.0)) } @Test fun separateContext() { fun FunctionalExpressionField.expression(): Expression { - val x = variable("x") + val x by binding() return x * x + 2 * x + one } val expression = FunctionalExpressionField(RealField).expression() - assertEquals(expression("x" to 1.0), 4.0) + assertEquals(expression(x to 1.0), 4.0) } @Test fun valueExpression() { val expressionBuilder: FunctionalExpressionField.() -> Expression = { - val x = variable("x") + val x by binding() x * x + 2 * x + one } val expression = FunctionalExpressionField(RealField).expressionBuilder() - assertEquals(expression("x" to 1.0), 4.0) + assertEquals(expression(x to 1.0), 4.0) } } diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt new file mode 100644 index 000000000..510ed23a9 --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -0,0 +1,285 @@ +package kscience.kmath.expressions + +import kscience.kmath.operations.RealField +import kscience.kmath.structures.asBuffer +import kotlin.math.E +import kotlin.math.PI +import kotlin.math.pow +import kotlin.math.sqrt +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class SimpleAutoDiffTest { + + fun dx( + xBinding: Pair, + body: SimpleAutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, + ): DerivationResult = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) } + + fun dxy( + xBinding: Pair, + yBinding: Pair, + body: SimpleAutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, + ): DerivationResult = RealField.simpleAutoDiff(xBinding, yBinding) { + body(bind(xBinding.first), bind(yBinding.first)) + } + + fun diff(block: SimpleAutoDiffField.() -> AutoDiffValue): SimpleAutoDiffExpression { + return SimpleAutoDiffExpression(RealField, block) + } + + val x by symbol + val y by symbol + val z by symbol + + @Test + fun testPlusX2() { + val y = RealField.simpleAutoDiff(x to 3.0) { + // diff w.r.t this x at 3 + val x = bind(x) + x + x + } + assertEquals(6.0, y.value) // y = x + x = 6 + assertEquals(2.0, y.derivative(x)) // dy/dx = 2 + } + + @Test + fun testPlusX2Expr() { + val expr = diff { + val x = bind(x) + x + x + } + assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6 + assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2 + } + + + @Test + fun testPlus() { + // two variables + val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) { + val x = bind(x) + val y = bind(y) + x + y + } + assertEquals(5.0, z.value) // z = x + y = 5 + assertEquals(1.0, z.derivative(x)) // dz/dx = 1 + assertEquals(1.0, z.derivative(y)) // dz/dy = 1 + } + + @Test + fun testMinus() { + // two variables + val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) { + val x = bind(x) + val y = bind(y) + + x - y + } + assertEquals(4.0, z.value) // z = x - y = 4 + assertEquals(1.0, z.derivative(x)) // dz/dx = 1 + assertEquals(-1.0, z.derivative(y)) // dz/dy = -1 + } + + @Test + fun testMulX2() { + val y = dx(x to 3.0) { x -> + // diff w.r.t this x at 3 + x * x + } + assertEquals(9.0, y.value) // y = x * x = 9 + assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7 + } + + @Test + fun testSqr() { + val y = dx(x to 3.0) { x -> sqr(x) } + assertEquals(9.0, y.value) // y = x ^ 2 = 9 + assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7 + } + + @Test + fun testSqrSqr() { + val y = dx(x to 2.0) { x -> sqr(sqr(x)) } + assertEquals(16.0, y.value) // y = x ^ 4 = 16 + assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32 + } + + @Test + fun testX3() { + val y = dx(x to 2.0) { x -> + // diff w.r.t this x at 2 + x * x * x + } + assertEquals(8.0, y.value) // y = x * x * x = 8 + assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12 + } + + @Test + fun testDiv() { + val z = dxy(x to 5.0, y to 2.0) { x, y -> + x / y + } + assertEquals(2.5, z.value) // z = x / y = 2.5 + assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5 + assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25 + } + + @Test + fun testPow3() { + val y = dx(x to 2.0) { x -> + // diff w.r.t this x at 2 + pow(x, 3) + } + assertEquals(8.0, y.value) // y = x ^ 3 = 8 + assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12 + } + + @Test + fun testPowFull() { + val z = dxy(x to 2.0, y to 3.0) { x, y -> + pow(x, y) + } + assertApprox(8.0, z.value) // z = x ^ y = 8 + assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12 + assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x) + } + + @Test + fun testFromPaper() { + val y = dx(x to 3.0) { x -> 2 * x + x * x * x } + assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33 + assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29 + } + + @Test + fun testInnerVariable() { + val y = dx(x to 1.0) { x -> + const(1.0) * x + } + assertEquals(1.0, y.value) // y = x ^ n = 1 + assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1 + } + + @Test + fun testLongChain() { + val n = 10_000 + val y = dx(x to 1.0) { x -> + var res = const(1.0) + for (i in 1..n) res *= x + res + } + assertEquals(1.0, y.value) // y = x ^ n = 1 + assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1 + } + + @Test + fun testExample() { + val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 } + assertEquals(17.0, y.value) // the value of result (y) + assertEquals(9.0, y.derivative(x)) // dy/dx + } + + @Test + fun testSqrt() { + val y = dx(x to 16.0) { x -> sqrt(x) } + assertEquals(4.0, y.value) // y = x ^ 1/2 = 4 + assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8 + } + + @Test + fun testSin() { + val y = dx(x to PI / 6.0) { x -> sin(x) } + assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5 + assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2 + } + + @Test + fun testCos() { + val y = dx(x to PI / 6) { x -> cos(x) } + assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2 + assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5 + } + + @Test + fun testTan() { + val y = dx(x to PI / 6) { x -> tan(x) } + assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3) + assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3 + } + + @Test + fun testAsin() { + val y = dx(x to PI / 6) { x -> asin(x) } + assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6) + assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2) + } + + @Test + fun testAcos() { + val y = dx(x to PI / 6) { x -> acos(x) } + assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6) + assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2) + } + + @Test + fun testAtan() { + val y = dx(x to PI / 6) { x -> atan(x) } + assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6) + assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2) + } + + @Test + fun testSinh() { + val y = dx(x to 0.0) { x -> sinh(x) } + assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0) + assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0) + } + + @Test + fun testCosh() { + val y = dx(x to 0.0) { x -> cosh(x) } + assertApprox(1.0, y.value) //y = cosh(0) + assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0) + } + + @Test + fun testTanh() { + val y = dx(x to 1.0) { x -> tanh(x) } + assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6) + assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2 + } + + @Test + fun testAsinh() { + val y = dx(x to PI / 6) { x -> asinh(x) } + assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6) + assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36) + } + + @Test + fun testAcosh() { + val y = dx(x to PI / 6) { x -> acosh(x) } + assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6) + assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2) + } + + @Test + fun testAtanh() { + val y = dx(x to PI / 6) { x -> atanh(x) } + assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6) + assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36) + } + + @Test + fun testDivGrad() { + val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y } + assertEquals(6.0, res.div()) + assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer())) + } + + private fun assertApprox(a: Double, b: Double) { + if ((a - b) > 1e-10) assertEquals(a, b) + } +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt similarity index 85% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt index 52a2f80a6..d7755dcb5 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt @@ -1,11 +1,13 @@ -package scientifik.kmath.linear +package kscience.kmath.linear -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.NDStructure -import scientifik.kmath.structures.as2D +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.NDStructure +import kscience.kmath.structures.as2D import kotlin.test.Test import kotlin.test.assertEquals +@Suppress("UNUSED_VARIABLE") class MatrixTest { @Test fun testTranspose() { @@ -38,7 +40,7 @@ class MatrixTest { infix fun Matrix.pow(power: Int): Matrix { var res = this repeat(power - 1) { - res = res dot this + res = RealMatrixContext.invoke { res dot this@pow } } return res } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/RealLUSolverTest.kt similarity index 76% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/linear/RealLUSolverTest.kt index 34bd8a0e3..28dfe46ec 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/RealLUSolverTest.kt @@ -1,17 +1,15 @@ -package scientifik.kmath.linear +package kscience.kmath.linear -import scientifik.kmath.structures.Matrix -import kotlin.contracts.ExperimentalContracts +import kscience.kmath.structures.Matrix import kotlin.test.Test import kotlin.test.assertEquals -@ExperimentalContracts class RealLUSolverTest { @Test fun testInvertOne() { val matrix = MatrixContext.real.one(2, 2) - val inverted = MatrixContext.real.inverse(matrix) + val inverted = MatrixContext.real.inverseWithLUP(matrix) assertEquals(matrix, inverted) } @@ -39,7 +37,7 @@ class RealLUSolverTest { 1.0, 3.0 ) - val inverted = MatrixContext.real.inverse(matrix) + val inverted = MatrixContext.real.inverseWithLUP(matrix) val expected = Matrix.square( 0.375, -0.125, diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/VectorSpaceTest.kt similarity index 100% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/VectorSpaceTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/linear/VectorSpaceTest.kt diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/CumulativeKtTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/misc/CumulativeKtTest.kt similarity index 90% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/CumulativeKtTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/misc/CumulativeKtTest.kt index 82ea5318f..1e6d2fd5d 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/CumulativeKtTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/misc/CumulativeKtTest.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.misc +package kscience.kmath.misc import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntAlgebraTest.kt similarity index 94% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntAlgebraTest.kt index d140f1017..78611e5d2 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntAlgebraTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntAlgebraTest.kt @@ -1,6 +1,6 @@ -package scientifik.kmath.operations +package kscience.kmath.operations -import scientifik.kmath.operations.internal.RingVerifier +import kscience.kmath.operations.internal.RingVerifier import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConstructorTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntConstructorTest.kt similarity index 93% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConstructorTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntConstructorTest.kt index 5e3f6d1b0..ba2582bbf 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConstructorTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntConstructorTest.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.operations +package kscience.kmath.operations import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConversionsTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntConversionsTest.kt similarity index 96% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConversionsTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntConversionsTest.kt index 41df1968d..0b433c436 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntConversionsTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntConversionsTest.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.operations +package kscience.kmath.operations import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntOperationsTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntOperationsTest.kt similarity index 99% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntOperationsTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntOperationsTest.kt index b7f4cf43b..a3ed85c7b 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/BigIntOperationsTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/BigIntOperationsTest.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.operations +package kscience.kmath.operations import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/ComplexFieldTest.kt similarity index 96% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexFieldTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/ComplexFieldTest.kt index 2c480ebea..c0b4853f4 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/ComplexFieldTest.kt @@ -1,6 +1,6 @@ -package scientifik.kmath.operations +package kscience.kmath.operations -import scientifik.kmath.operations.internal.FieldVerifier +import kscience.kmath.operations.internal.FieldVerifier import kotlin.math.PI import kotlin.math.abs import kotlin.test.Test diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/ComplexTest.kt similarity index 85% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/ComplexTest.kt index e8d698c70..456e41467 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/ComplexTest.kt @@ -1,7 +1,8 @@ -package scientifik.kmath.operations +package kscience.kmath.operations import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertTrue internal class ComplexTest { @Test @@ -13,7 +14,7 @@ internal class ComplexTest { @Test fun reciprocal() { - assertEquals(Complex(0.5, -0.0), 2.toComplex().reciprocal) + assertTrue { (Complex(0.5, -0.0) - 2.toComplex().reciprocal).r < 1e-10} } @Test diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/RealFieldTest.kt similarity index 75% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/RealFieldTest.kt index a168b4afd..5705733cf 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/RealFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/RealFieldTest.kt @@ -1,6 +1,6 @@ -package scientifik.kmath.operations +package kscience.kmath.operations -import scientifik.kmath.operations.internal.FieldVerifier +import kscience.kmath.operations.internal.FieldVerifier import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/AlgebraicVerifier.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/AlgebraicVerifier.kt similarity index 55% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/AlgebraicVerifier.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/AlgebraicVerifier.kt index cb097d46e..7334c13a3 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/AlgebraicVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/AlgebraicVerifier.kt @@ -1,6 +1,6 @@ -package scientifik.kmath.operations.internal +package kscience.kmath.operations.internal -import scientifik.kmath.operations.Algebra +import kscience.kmath.operations.Algebra internal interface AlgebraicVerifier where A : Algebra { val algebra: A diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/FieldVerifier.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/FieldVerifier.kt similarity index 77% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/FieldVerifier.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/FieldVerifier.kt index 973fd00b1..89f31c75b 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/FieldVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/FieldVerifier.kt @@ -1,7 +1,7 @@ -package scientifik.kmath.operations.internal +package kscience.kmath.operations.internal -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.invoke +import kscience.kmath.operations.Field +import kscience.kmath.operations.invoke import kotlin.test.assertEquals import kotlin.test.assertNotEquals @@ -12,6 +12,8 @@ internal class FieldVerifier(override val algebra: Field, a: T, b: T, c: T super.verify() algebra { + assertEquals(a + b, b + a, "Addition in $algebra is not commutative.") + assertEquals(a * b, b * a, "Multiplication in $algebra is not commutative.") assertNotEquals(a / b, b / a, "Division in $algebra is not anti-commutative.") assertNotEquals((a / b) / c, a / (b / c), "Division in $algebra is associative.") assertEquals((a + b) / c, (a / c) + (b / c), "Division in $algebra is not right-distributive.") diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/RingVerifier.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/RingVerifier.kt similarity index 87% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/RingVerifier.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/RingVerifier.kt index 047a213e9..359ba1701 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/RingVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/RingVerifier.kt @@ -1,7 +1,7 @@ -package scientifik.kmath.operations.internal +package kscience.kmath.operations.internal -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.invoke +import kscience.kmath.operations.Ring +import kscience.kmath.operations.invoke import kotlin.test.assertEquals internal open class RingVerifier(override val algebra: Ring, a: T, b: T, c: T, x: Number) : @@ -10,7 +10,7 @@ internal open class RingVerifier(override val algebra: Ring, a: T, b: T, c super.verify() algebra { - assertEquals(a * b, a * b, "Multiplication in $algebra is not commutative.") + assertEquals(a + b, b + a, "Addition in $algebra is not commutative.") assertEquals(a * b * c, a * (b * c), "Multiplication in $algebra is not associative.") assertEquals(c * (a + b), (c * a) + (c * b), "Multiplication in $algebra is not distributive.") assertEquals(a * one, one * a, "$one in $algebra is not a neutral multiplication element.") diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/SpaceVerifier.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/SpaceVerifier.kt similarity index 87% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/SpaceVerifier.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/SpaceVerifier.kt index bc241c97d..045abb71f 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/internal/SpaceVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/operations/internal/SpaceVerifier.kt @@ -1,7 +1,7 @@ -package scientifik.kmath.operations.internal +package kscience.kmath.operations.internal -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke +import kscience.kmath.operations.Space +import kscience.kmath.operations.invoke import kotlin.test.assertEquals import kotlin.test.assertNotEquals @@ -15,7 +15,6 @@ internal open class SpaceVerifier( AlgebraicVerifier> { override fun verify() { algebra { - assertEquals(a + b, b + a, "Addition in $algebra is not commutative.") assertEquals(a + b + c, a + (b + c), "Addition in $algebra is not associative.") assertEquals(x * (a + b), x * a + x * b, "Addition in $algebra is not distributive.") assertEquals((a + b) * x, a * x + b * x, "Addition in $algebra is not distributive.") diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/ComplexBufferSpecTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/ComplexBufferSpecTest.kt similarity index 68% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/ComplexBufferSpecTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/structures/ComplexBufferSpecTest.kt index cbbe6f0f4..4837236db 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/ComplexBufferSpecTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/ComplexBufferSpecTest.kt @@ -1,7 +1,7 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.Complex -import scientifik.kmath.operations.complex +import kscience.kmath.operations.Complex +import kscience.kmath.operations.complex import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt new file mode 100644 index 000000000..1129a8a36 --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt @@ -0,0 +1,18 @@ +package kscience.kmath.structures + +import kscience.kmath.operations.internal.FieldVerifier +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class NDFieldTest { + @Test + fun verify() { + NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } + } + + @Test + fun testStrides() { + val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } + assertEquals(ndArray[5, 5], 10.0) + } +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt similarity index 89% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt rename to kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt index b7e2594ec..22a0d3629 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt @@ -1,13 +1,14 @@ -package scientifik.kmath.structures +package kscience.kmath.structures -import scientifik.kmath.operations.Norm -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.NDElement.Companion.real2D +import kscience.kmath.operations.Norm +import kscience.kmath.operations.invoke +import kscience.kmath.structures.NDElement.Companion.real2D import kotlin.math.abs import kotlin.math.pow import kotlin.test.Test import kotlin.test.assertEquals +@Suppress("UNUSED_VARIABLE") class NumberNDFieldTest { val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() } val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt deleted file mode 100644 index c08a63ccb..000000000 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt +++ /dev/null @@ -1,181 +0,0 @@ -package scientifik.kmath.misc - -import scientifik.kmath.operations.RealField -import scientifik.kmath.structures.asBuffer -import kotlin.math.PI -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertTrue - -class AutoDiffTest { - fun Variable(int: Int): Variable = Variable(int.toDouble()) - - fun deriv(body: AutoDiffField.() -> Variable): DerivationResult = - RealField.deriv(body) - - @Test - fun testPlusX2() { - val x = Variable(3) // diff w.r.t this x at 3 - val y = deriv { x + x } - assertEquals(6.0, y.value) // y = x + x = 6 - assertEquals(2.0, y.deriv(x)) // dy/dx = 2 - } - - @Test - fun testPlus() { - // two variables - val x = Variable(2) - val y = Variable(3) - val z = deriv { x + y } - assertEquals(5.0, z.value) // z = x + y = 5 - assertEquals(1.0, z.deriv(x)) // dz/dx = 1 - assertEquals(1.0, z.deriv(y)) // dz/dy = 1 - } - - @Test - fun testMinus() { - // two variables - val x = Variable(7) - val y = Variable(3) - val z = deriv { x - y } - assertEquals(4.0, z.value) // z = x - y = 4 - assertEquals(1.0, z.deriv(x)) // dz/dx = 1 - assertEquals(-1.0, z.deriv(y)) // dz/dy = -1 - } - - @Test - fun testMulX2() { - val x = Variable(3) // diff w.r.t this x at 3 - val y = deriv { x * x } - assertEquals(9.0, y.value) // y = x * x = 9 - assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7 - } - - @Test - fun testSqr() { - val x = Variable(3) - val y = deriv { sqr(x) } - assertEquals(9.0, y.value) // y = x ^ 2 = 9 - assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7 - } - - @Test - fun testSqrSqr() { - val x = Variable(2) - val y = deriv { sqr(sqr(x)) } - assertEquals(16.0, y.value) // y = x ^ 4 = 16 - assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32 - } - - @Test - fun testX3() { - val x = Variable(2) // diff w.r.t this x at 2 - val y = deriv { x * x * x } - assertEquals(8.0, y.value) // y = x * x * x = 8 - assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12 - } - - @Test - fun testDiv() { - val x = Variable(5) - val y = Variable(2) - val z = deriv { x / y } - assertEquals(2.5, z.value) // z = x / y = 2.5 - assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5 - assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25 - } - - @Test - fun testPow3() { - val x = Variable(2) // diff w.r.t this x at 2 - val y = deriv { pow(x, 3) } - assertEquals(8.0, y.value) // y = x ^ 3 = 8 - assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12 - } - - @Test - fun testPowFull() { - val x = Variable(2) - val y = Variable(3) - val z = deriv { pow(x, y) } - assertApprox(8.0, z.value) // z = x ^ y = 8 - assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12 - assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x) - } - - @Test - fun testFromPaper() { - val x = Variable(3) - val y = deriv { 2 * x + x * x * x } - assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33 - assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29 - } - - @Test - fun testInnerVariable() { - val x = Variable(1) - val y = deriv { - Variable(1) * x - } - assertEquals(1.0, y.value) // y = x ^ n = 1 - assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1 - } - - @Test - fun testLongChain() { - val n = 10_000 - val x = Variable(1) - val y = deriv { - var res = Variable(1) - for (i in 1..n) res *= x - res - } - assertEquals(1.0, y.value) // y = x ^ n = 1 - assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1 - } - - @Test - fun testExample() { - val x = Variable(2) - val y = deriv { sqr(x) + 5 * x + 3 } - assertEquals(17.0, y.value) // the value of result (y) - assertEquals(9.0, y.deriv(x)) // dy/dx - } - - @Test - fun testSqrt() { - val x = Variable(16) - val y = deriv { sqrt(x) } - assertEquals(4.0, y.value) // y = x ^ 1/2 = 4 - assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8 - } - - @Test - fun testSin() { - val x = Variable(PI / 6) - val y = deriv { sin(x) } - assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5 - assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2 - } - - @Test - fun testCos() { - val x = Variable(PI / 6) - val y = deriv { cos(x) } - assertApprox(kotlin.math.sqrt(3.0) / 2, y.value) // y = cos(PI/6) = sqrt(3)/2 - assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5 - } - - @Test - fun testDivGrad() { - val x = Variable(1.0) - val y = Variable(2.0) - val res = deriv { x * x + y * y } - assertEquals(6.0, res.div()) - assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer())) - } - - private fun assertApprox(a: Double, b: Double) { - if ((a - b) > 1e-10) assertEquals(a, b) - } -} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt deleted file mode 100644 index 7abeefca6..000000000 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/structures/NDFieldTest.kt +++ /dev/null @@ -1,13 +0,0 @@ -package scientifik.kmath.structures - -import kotlin.test.Test -import kotlin.test.assertEquals - - -class NDFieldTest { - @Test - fun testStrides() { - val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } - assertEquals(ndArray[5, 5], 10.0) - } -} diff --git a/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt new file mode 100644 index 000000000..9bd6a9fc4 --- /dev/null +++ b/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt @@ -0,0 +1,59 @@ +package kscience.kmath.operations + +import java.math.BigDecimal +import java.math.BigInteger +import java.math.MathContext + +/** + * A field over [BigInteger]. + */ +public object JBigIntegerField : Field, NumericAlgebra { + public override val zero: BigInteger + get() = BigInteger.ZERO + + public override val one: BigInteger + get() = BigInteger.ONE + + public override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) + public override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) + public override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) + public override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) + public override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) + public override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) + public override operator fun BigInteger.unaryMinus(): BigInteger = negate() +} + +/** + * An abstract field over [BigDecimal]. + * + * @property mathContext the [MathContext] to use. + */ +public abstract class JBigDecimalFieldBase internal constructor( + private val mathContext: MathContext = MathContext.DECIMAL64, +) : Field, PowerOperations, NumericAlgebra { + public override val zero: BigDecimal + get() = BigDecimal.ZERO + + public override val one: BigDecimal + get() = BigDecimal.ONE + + public override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) + public override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) + public override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) + + public override fun multiply(a: BigDecimal, k: Number): BigDecimal = + a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext) + + public override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) + public override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) + public override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) + public override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) + public override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) +} + +/** + * A field over [BigDecimal]. + */ +public class JBigDecimalField(mathContext: MathContext = MathContext.DECIMAL64) : JBigDecimalFieldBase(mathContext) { + public companion object : JBigDecimalFieldBase() +} diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt deleted file mode 100644 index f10ef24da..000000000 --- a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt +++ /dev/null @@ -1,59 +0,0 @@ -package scientifik.kmath.operations - -import java.math.BigDecimal -import java.math.BigInteger -import java.math.MathContext - -/** - * A field over [BigInteger]. - */ -object JBigIntegerField : Field { - override val zero: BigInteger - get() = BigInteger.ZERO - - override val one: BigInteger - get() = BigInteger.ONE - - override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) - override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) - override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) - override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) - override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) - override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) - override operator fun BigInteger.unaryMinus(): BigInteger = negate() -} - -/** - * An abstract field over [BigDecimal]. - * - * @property mathContext the [MathContext] to use. - */ -abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathContext = MathContext.DECIMAL64) : - Field, - PowerOperations { - override val zero: BigDecimal - get() = BigDecimal.ZERO - - override val one: BigDecimal - get() = BigDecimal.ONE - - override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) - override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) - override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) - - override fun multiply(a: BigDecimal, k: Number): BigDecimal = - a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext) - - override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) - override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) - override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) - override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) - override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) -} - -/** - * A field over [BigDecimal]. - */ -class JBigDecimalField(mathContext: MathContext = MathContext.DECIMAL64) : JBigDecimalFieldBase(mathContext) { - companion object : JBigDecimalFieldBase() -} diff --git a/kmath-coroutines/build.gradle.kts b/kmath-coroutines/build.gradle.kts index 4469a9ef6..e108c2755 100644 --- a/kmath-coroutines/build.gradle.kts +++ b/kmath-coroutines/build.gradle.kts @@ -1,12 +1,8 @@ -plugins { - id("scientifik.mpp") - //id("scientifik.atomic") -} +plugins { id("ru.mipt.npm.mpp") } kotlin.sourceSets { all { with(languageSettings) { - useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") useExperimentalAnnotation("kotlinx.coroutines.InternalCoroutinesApi") useExperimentalAnnotation("kotlinx.coroutines.ExperimentalCoroutinesApi") useExperimentalAnnotation("kotlinx.coroutines.FlowPreview") @@ -16,15 +12,7 @@ kotlin.sourceSets { commonMain { dependencies { api(project(":kmath-core")) - api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Scientifik.coroutinesVersion}") + api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${ru.mipt.npm.gradle.KScienceVersions.coroutinesVersion}") } } - - jvmMain { - dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Scientifik.coroutinesVersion}") } - } - - jsMain { - dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Scientifik.coroutinesVersion}") } - } } diff --git a/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/BlockingIntChain.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/BlockingIntChain.kt new file mode 100644 index 000000000..6088267a2 --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/BlockingIntChain.kt @@ -0,0 +1,12 @@ +package kscience.kmath.chains + +/** + * Performance optimized chain for integer values + */ +public abstract class BlockingIntChain : Chain { + public abstract fun nextInt(): Int + + override suspend fun next(): Int = nextInt() + + public fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() } +} diff --git a/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/BlockingRealChain.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/BlockingRealChain.kt new file mode 100644 index 000000000..718b3a18b --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/BlockingRealChain.kt @@ -0,0 +1,12 @@ +package kscience.kmath.chains + +/** + * Performance optimized chain for real values + */ +public abstract class BlockingRealChain : Chain { + public abstract fun nextDouble(): Double + + override suspend fun next(): Double = nextDouble() + + public fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() } +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/Chain.kt similarity index 61% rename from kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt rename to kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/Chain.kt index f0ffd13cd..7ff7b7aae 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/Chain.kt @@ -14,7 +14,7 @@ * limitations under the License. */ -package scientifik.kmath.chains +package kscience.kmath.chains import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector @@ -26,47 +26,44 @@ import kotlinx.coroutines.sync.withLock * A not-necessary-Markov chain of some type * @param R - the chain element type */ -interface Chain : Flow { +public interface Chain : Flow { /** * Generate next value, changing state if needed */ - suspend fun next(): R + public suspend fun next(): R /** * Create a copy of current chain state. Consuming resulting chain does not affect initial chain */ - fun fork(): Chain + public fun fork(): Chain override suspend fun collect(collector: FlowCollector): Unit = flow { while (true) emit(next()) }.collect(collector) - companion object + public companion object } - -fun Iterator.asChain(): Chain = SimpleChain { next() } -fun Sequence.asChain(): Chain = iterator().asChain() +public fun Iterator.asChain(): Chain = SimpleChain { next() } +public fun Sequence.asChain(): Chain = iterator().asChain() /** - * A simple chain of independent tokens + * A simple chain of independent tokens. [fork] returns the same chain. */ -class SimpleChain(private val gen: suspend () -> R) : Chain { - override suspend fun next(): R = gen() - override fun fork(): Chain = this +public class SimpleChain(private val gen: suspend () -> R) : Chain { + public override suspend fun next(): R = gen() + public override fun fork(): Chain = this } /** * A stateless Markov chain */ -class MarkovChain(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain { - - private val mutex = Mutex() - +public class MarkovChain(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain { + private val mutex: Mutex = Mutex() private var value: R? = null - fun value(): R? = value + public fun value(): R? = value - override suspend fun next(): R { + public override suspend fun next(): R { mutex.withLock { val newValue = gen(value ?: seed()) value = newValue @@ -74,9 +71,7 @@ class MarkovChain(private val seed: suspend () -> R, private val ge } } - override fun fork(): Chain { - return MarkovChain(seed = { value ?: seed() }, gen = gen) - } + public override fun fork(): Chain = MarkovChain(seed = { value ?: seed() }, gen = gen) } /** @@ -84,19 +79,18 @@ class MarkovChain(private val seed: suspend () -> R, private val ge * @param S - the state of the chain * @param forkState - the function to copy current state without modifying it */ -class StatefulChain( +public class StatefulChain( private val state: S, private val seed: S.() -> R, private val forkState: ((S) -> S), private val gen: suspend S.(R) -> R ) : Chain { private val mutex: Mutex = Mutex() - private var value: R? = null - fun value(): R? = value + public fun value(): R? = value - override suspend fun next(): R { + public override suspend fun next(): R { mutex.withLock { val newValue = state.gen(value ?: state.seed()) value = newValue @@ -104,25 +98,22 @@ class StatefulChain( } } - override fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) + public override fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) } /** * A chain that repeats the same value */ -class ConstantChain(val value: T) : Chain { - override suspend fun next(): T = value - - override fun fork(): Chain { - return this - } +public class ConstantChain(public val value: T) : Chain { + public override suspend fun next(): T = value + public override fun fork(): Chain = this } /** * Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed * since mapped chain consumes tokens. Accepts regular transformation function */ -fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { +public fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { override suspend fun next(): R = func(this@map.next()) override fun fork(): Chain = this@map.fork().map(func) } @@ -130,7 +121,7 @@ fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { /** * [block] must be a pure function or at least not use external random variables, otherwise fork could be broken */ -fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain { +public fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain { override suspend fun next(): T { var next: T @@ -146,23 +137,26 @@ fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain { /** * Map the whole chain */ -fun Chain.collect(mapper: suspend (Chain) -> R): Chain = object : Chain { +public fun Chain.collect(mapper: suspend (Chain) -> R): Chain = object : Chain { override suspend fun next(): R = mapper(this@collect) override fun fork(): Chain = this@collect.fork().collect(mapper) } -fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain) -> R): Chain = - object : Chain { - override suspend fun next(): R = state.mapper(this@collectWithState) +public fun Chain.collectWithState( + state: S, + stateFork: (S) -> S, + mapper: suspend S.(Chain) -> R +): Chain = object : Chain { + override suspend fun next(): R = state.mapper(this@collectWithState) - override fun fork(): Chain = - this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) - } + override fun fork(): Chain = + this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) +} /** * Zip two chains together using given transformation */ -fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = object : Chain { +public fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = object : Chain { override suspend fun next(): R = block(this@zip.next(), other.next()) override fun fork(): Chain = this@zip.fork().zip(other.fork(), block) } diff --git a/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/flowExtra.kt new file mode 100644 index 000000000..6b14057fe --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/chains/flowExtra.kt @@ -0,0 +1,26 @@ +package kscience.kmath.chains + +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.runningReduce +import kotlinx.coroutines.flow.scan +import kscience.kmath.operations.Space +import kscience.kmath.operations.SpaceOperations +import kscience.kmath.operations.invoke + +@ExperimentalCoroutinesApi +public fun Flow.cumulativeSum(space: SpaceOperations): Flow = + space { runningReduce { sum, element -> sum + element } } + +@ExperimentalCoroutinesApi +public fun Flow.mean(space: Space): Flow = space { + data class Accumulator(var sum: T, var num: Int) + + scan(Accumulator(zero, 0)) { sum, element -> + sum.apply { + this.sum += element + this.num += 1 + } + }.map { it.sum / it.num } +} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/coroutines/coroutinesExtra.kt similarity index 62% rename from kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt rename to kmath-coroutines/src/commonMain/kotlin/kscience/kmath/coroutines/coroutinesExtra.kt index 692f89589..7dcdc0d62 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/coroutines/coroutinesExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/coroutines/coroutinesExtra.kt @@ -1,11 +1,10 @@ -package scientifik.kmath.coroutines +package kscience.kmath.coroutines import kotlinx.coroutines.* import kotlinx.coroutines.channels.produce import kotlinx.coroutines.flow.* -import kotlin.contracts.contract -val Dispatchers.Math: CoroutineDispatcher +public val Dispatchers.Math: CoroutineDispatcher get() = Default /** @@ -15,31 +14,26 @@ internal class LazyDeferred(val dispatcher: CoroutineDispatcher, val block: s private var deferred: Deferred? = null internal fun start(scope: CoroutineScope) { - if (deferred == null) { - deferred = scope.async(dispatcher, block = block) - } + if (deferred == null) deferred = scope.async(dispatcher, block = block) } suspend fun await(): T = deferred?.await() ?: error("Coroutine not started") } -class AsyncFlow internal constructor(internal val deferredFlow: Flow>) : Flow { - override suspend fun collect(collector: FlowCollector) { +public class AsyncFlow internal constructor(internal val deferredFlow: Flow>) : Flow { + override suspend fun collect(collector: FlowCollector): Unit = deferredFlow.collect { collector.emit((it.await())) } - } } -fun Flow.async( +public fun Flow.async( dispatcher: CoroutineDispatcher = Dispatchers.Default, block: suspend CoroutineScope.(T) -> R ): AsyncFlow { - val flow = map { - LazyDeferred(dispatcher) { block(it) } - } + val flow = map { LazyDeferred(dispatcher) { block(it) } } return AsyncFlow(flow) } -fun AsyncFlow.map(action: (T) -> R): AsyncFlow = +public fun AsyncFlow.map(action: (T) -> R): AsyncFlow = AsyncFlow(deferredFlow.map { input -> //TODO add function composition LazyDeferred(input.dispatcher) { @@ -48,7 +42,7 @@ fun AsyncFlow.map(action: (T) -> R): AsyncFlow = } }) -suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector) { +public suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector) { require(concurrency >= 1) { "Buffer size should be more than 1, but was $concurrency" } coroutineScope { @@ -76,18 +70,14 @@ suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector< } } -suspend inline fun AsyncFlow.collect(concurrency: Int, crossinline action: suspend (value: T) -> Unit) { - contract { callsInPlace(action) } +public suspend inline fun AsyncFlow.collect( + concurrency: Int, + crossinline action: suspend (value: T) -> Unit +): Unit = collect(concurrency, object : FlowCollector { + override suspend fun emit(value: T): Unit = action(value) +}) - collect(concurrency, object : FlowCollector { - override suspend fun emit(value: T): Unit = action(value) - }) -} - -inline fun Flow.mapParallel( +public inline fun Flow.mapParallel( dispatcher: CoroutineDispatcher = Dispatchers.Default, crossinline transform: suspend (T) -> R -): Flow { - contract { callsInPlace(transform) } - return flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher) -} +): Flow = flatMapMerge { value -> flow { emit(transform(value)) } }.flowOn(dispatcher) diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/streaming/BufferFlow.kt similarity index 62% rename from kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt rename to kmath-coroutines/src/commonMain/kotlin/kscience/kmath/streaming/BufferFlow.kt index 9b7e82da5..328a7807c 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/streaming/BufferFlow.kt @@ -1,28 +1,28 @@ -package scientifik.kmath.streaming +package kscience.kmath.streaming import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.* -import scientifik.kmath.chains.BlockingRealChain -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.RealBuffer -import scientifik.kmath.structures.asBuffer +import kscience.kmath.chains.BlockingRealChain +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.BufferFactory +import kscience.kmath.structures.RealBuffer +import kscience.kmath.structures.asBuffer /** * Create a [Flow] from buffer */ -fun Buffer.asFlow(): Flow = iterator().asFlow() +public fun Buffer.asFlow(): Flow = iterator().asFlow() /** * Flat map a [Flow] of [Buffer] into continuous [Flow] of elements */ @FlowPreview -fun Flow>.spread(): Flow = flatMapConcat { it.asFlow() } +public fun Flow>.spread(): Flow = flatMapConcat { it.asFlow() } /** * Collect incoming flow into fixed size chunks */ -fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow> = flow { +public fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow> = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } val list = ArrayList(bufferSize) var counter = 0 @@ -30,6 +30,7 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< this@chunked.collect { element -> list.add(element) counter++ + if (counter == bufferSize) { val buffer = bufferFactory(bufferSize) { list[it] } emit(buffer) @@ -37,22 +38,19 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< counter = 0 } } - if (counter > 0) { - emit(bufferFactory(counter) { list[it] }) - } + + if (counter > 0) emit(bufferFactory(counter) { list[it] }) } /** * Specialized flow chunker for real buffer */ -fun Flow.chunked(bufferSize: Int): Flow = flow { +public fun Flow.chunked(bufferSize: Int): Flow = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } if (this@chunked is BlockingRealChain) { - //performance optimization for blocking primitive chain - while (true) { - emit(nextBlock(bufferSize).asBuffer()) - } + // performance optimization for blocking primitive chain + while (true) emit(nextBlock(bufferSize).asBuffer()) } else { val array = DoubleArray(bufferSize) var counter = 0 @@ -60,15 +58,15 @@ fun Flow.chunked(bufferSize: Int): Flow = flow { this@chunked.collect { element -> array[counter] = element counter++ + if (counter == bufferSize) { val buffer = RealBuffer(array) emit(buffer) counter = 0 } } - if (counter > 0) { - emit(RealBuffer(counter) { array[it] }) - } + + if (counter > 0) emit(RealBuffer(counter) { array[it] }) } } @@ -76,9 +74,10 @@ fun Flow.chunked(bufferSize: Int): Flow = flow { * Map a flow to a moving window buffer. The window step is one. * In order to get different steps, one could use skip operation. */ -fun Flow.windowed(window: Int): Flow> = flow { +public fun Flow.windowed(window: Int): Flow> = flow { require(window > 1) { "Window size must be more than one" } val ringBuffer = RingBuffer.boxing(window) + this@windowed.collect { element -> ringBuffer.push(element) emit(ringBuffer.snapshot()) diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/streaming/RingBuffer.kt similarity index 65% rename from kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt rename to kmath-coroutines/src/commonMain/kotlin/kscience/kmath/streaming/RingBuffer.kt index f1c0bfc6a..385bbaae2 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/RingBuffer.kt +++ b/kmath-coroutines/src/commonMain/kotlin/kscience/kmath/streaming/RingBuffer.kt @@ -1,37 +1,37 @@ -package scientifik.kmath.streaming +package kscience.kmath.streaming import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.MutableBuffer -import scientifik.kmath.structures.VirtualBuffer +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.MutableBuffer +import kscience.kmath.structures.VirtualBuffer /** * Thread-safe ring buffer */ @Suppress("UNCHECKED_CAST") -class RingBuffer( +public class RingBuffer( private val buffer: MutableBuffer, private var startIndex: Int = 0, size: Int = 0 ) : Buffer { private val mutex: Mutex = Mutex() - override var size: Int = size + public override var size: Int = size private set - override operator fun get(index: Int): T { + public override operator fun get(index: Int): T { require(index >= 0) { "Index must be positive" } require(index < size) { "Index $index is out of circular buffer size $size" } return buffer[startIndex.forward(index)] as T } - fun isFull(): Boolean = size == buffer.size + public fun isFull(): Boolean = size == buffer.size /** * Iterator could provide wrong results if buffer is changed in initialization (iteration is safe) */ - override operator fun iterator(): Iterator = object : AbstractIterator() { + public override operator fun iterator(): Iterator = object : AbstractIterator() { private var count = size private var index = startIndex val copy = buffer.copy() @@ -48,23 +48,17 @@ class RingBuffer( /** * A safe snapshot operation */ - suspend fun snapshot(): Buffer { + public suspend fun snapshot(): Buffer { mutex.withLock { val copy = buffer.copy() - return VirtualBuffer(size) { i -> - copy[startIndex.forward(i)] as T - } + return VirtualBuffer(size) { i -> copy[startIndex.forward(i)] as T } } } - suspend fun push(element: T) { + public suspend fun push(element: T) { mutex.withLock { buffer[startIndex.forward(size)] = element - if (isFull()) { - startIndex++ - } else { - size++ - } + if (isFull()) startIndex++ else size++ } } @@ -72,8 +66,8 @@ class RingBuffer( @Suppress("NOTHING_TO_INLINE") private inline fun Int.forward(n: Int): Int = (this + n) % (buffer.size) - companion object { - inline fun build(size: Int, empty: T): RingBuffer { + public companion object { + public inline fun build(size: Int, empty: T): RingBuffer { val buffer = MutableBuffer.auto(size) { empty } as MutableBuffer return RingBuffer(buffer) } @@ -81,7 +75,7 @@ class RingBuffer( /** * Slow yet universal buffer */ - fun boxing(size: Int): RingBuffer { + public fun boxing(size: Int): RingBuffer { val buffer: MutableBuffer = MutableBuffer.boxing(size) { null } return RingBuffer(buffer) } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt deleted file mode 100644 index e9b499d71..000000000 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt +++ /dev/null @@ -1,12 +0,0 @@ -package scientifik.kmath.chains - -/** - * Performance optimized chain for integer values - */ -abstract class BlockingIntChain : Chain { - abstract fun nextInt(): Int - - override suspend fun next(): Int = nextInt() - - fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() } -} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt deleted file mode 100644 index ab819d327..000000000 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt +++ /dev/null @@ -1,12 +0,0 @@ -package scientifik.kmath.chains - -/** - * Performance optimized chain for real values - */ -abstract class BlockingRealChain : Chain { - abstract fun nextDouble(): Double - - override suspend fun next(): Double = nextDouble() - - fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() } -} diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt deleted file mode 100644 index 5db660c39..000000000 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/flowExtra.kt +++ /dev/null @@ -1,27 +0,0 @@ -package scientifik.kmath.chains - -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map -import kotlinx.coroutines.flow.scan -import kotlinx.coroutines.flow.scanReduce -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.SpaceOperations -import scientifik.kmath.operations.invoke - -@ExperimentalCoroutinesApi -fun Flow.cumulativeSum(space: SpaceOperations): Flow = space { - scanReduce { sum: T, element: T -> sum + element } -} - -@ExperimentalCoroutinesApi -fun Flow.mean(space: Space): Flow = space { - class Accumulator(var sum: T, var num: Int) - - scan(Accumulator(zero, 0)) { sum, element -> - sum.apply { - this.sum += element - this.num += 1 - } - }.map { it.sum / it.num } -} diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/chains/ChainExt.kt similarity index 55% rename from kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt rename to kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/chains/ChainExt.kt index 5686b0ac0..3dfeddbac 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/chains/ChainExt.kt @@ -1,17 +1,16 @@ -package scientifik.kmath.chains +package kscience.kmath.chains import kotlinx.coroutines.runBlocking /** * Represent a chain as regular iterator (uses blocking calls) */ -operator fun Chain.iterator(): Iterator = object : Iterator { +public operator fun Chain.iterator(): Iterator = object : Iterator { override fun hasNext(): Boolean = true - override fun next(): R = runBlocking { next() } } /** * Represent a chain as a sequence */ -fun Chain.asSequence(): Sequence = Sequence { this@asSequence.iterator() } \ No newline at end of file +public fun Chain.asSequence(): Sequence = Sequence { this@asSequence.iterator() } diff --git a/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt new file mode 100644 index 000000000..7aa746797 --- /dev/null +++ b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt @@ -0,0 +1,56 @@ +package kscience.kmath.structures + +import kotlinx.coroutines.* +import kscience.kmath.coroutines.Math + +public class LazyNDStructure( + public val scope: CoroutineScope, + public override val shape: IntArray, + public val function: suspend (IntArray) -> T +) : NDStructure { + private val cache: MutableMap> = hashMapOf() + + public fun deferred(index: IntArray): Deferred = cache.getOrPut(index) { + scope.async(context = Dispatchers.Math) { function(index) } + } + + public suspend fun await(index: IntArray): T = deferred(index).await() + public override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() } + + public override fun elements(): Sequence> { + val strides = DefaultStrides(shape) + val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } } + return res.asSequence() + } + + public override fun equals(other: Any?): Boolean { + return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false) + } + + public override fun hashCode(): Int { + var result = scope.hashCode() + result = 31 * result + shape.contentHashCode() + result = 31 * result + function.hashCode() + result = 31 * result + cache.hashCode() + return result + } +} + +public fun NDStructure.deferred(index: IntArray): Deferred = + if (this is LazyNDStructure) deferred(index) else CompletableDeferred(get(index)) + +public suspend fun NDStructure.await(index: IntArray): T = + if (this is LazyNDStructure) await(index) else get(index) + +/** + * PENDING would benefit from KEEP-176 + */ +public inline fun NDStructure.mapAsyncIndexed( + scope: CoroutineScope, + crossinline function: suspend (T, index: IntArray) -> R +): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index), index) } + +public inline fun NDStructure.mapAsync( + scope: CoroutineScope, + crossinline function: suspend (T) -> R +): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index)) } diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt deleted file mode 100644 index ff732a06b..000000000 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/structures/LazyNDStructure.kt +++ /dev/null @@ -1,63 +0,0 @@ -package scientifik.kmath.structures - -import kotlinx.coroutines.* -import scientifik.kmath.coroutines.Math - -class LazyNDStructure( - val scope: CoroutineScope, - override val shape: IntArray, - val function: suspend (IntArray) -> T -) : NDStructure { - private val cache: MutableMap> = hashMapOf() - - fun deferred(index: IntArray): Deferred = cache.getOrPut(index) { - scope.async(context = Dispatchers.Math) { - function(index) - } - } - - suspend fun await(index: IntArray): T = deferred(index).await() - - override operator fun get(index: IntArray): T = runBlocking { - deferred(index).await() - } - - override fun elements(): Sequence> { - val strides = DefaultStrides(shape) - val res = runBlocking { - strides.indices().toList().map { index -> index to await(index) } - } - return res.asSequence() - } - - override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) - } - - override fun hashCode(): Int { - var result = scope.hashCode() - result = 31 * result + shape.contentHashCode() - result = 31 * result + function.hashCode() - result = 31 * result + cache.hashCode() - return result - } -} - -fun NDStructure.deferred(index: IntArray): Deferred = - if (this is LazyNDStructure) this.deferred(index) else CompletableDeferred(get(index)) - -suspend fun NDStructure.await(index: IntArray): T = - if (this is LazyNDStructure) this.await(index) else get(index) - -/** - * PENDING would benefit from KEEP-176 - */ -inline fun NDStructure.mapAsyncIndexed( - scope: CoroutineScope, - crossinline function: suspend (T, index: IntArray) -> R -): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index), index) } - -inline fun NDStructure.mapAsync( - scope: CoroutineScope, - crossinline function: suspend (T) -> R -): LazyNDStructure = LazyNDStructure(scope, shape) { index -> function(get(index)) } diff --git a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt b/kmath-coroutines/src/jvmTest/kotlin/kscience/kmath/streaming/BufferFlowTest.kt similarity index 86% rename from kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt rename to kmath-coroutines/src/jvmTest/kotlin/kscience/kmath/streaming/BufferFlowTest.kt index 427349072..a9bf38c12 100644 --- a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt +++ b/kmath-coroutines/src/jvmTest/kotlin/kscience/kmath/streaming/BufferFlowTest.kt @@ -1,12 +1,12 @@ -package scientifik.kmath.streaming +package kscience.kmath.streaming import kotlinx.coroutines.* import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.collect +import kscience.kmath.coroutines.async +import kscience.kmath.coroutines.collect +import kscience.kmath.coroutines.mapParallel import org.junit.jupiter.api.Timeout -import scientifik.kmath.coroutines.async -import scientifik.kmath.coroutines.collect -import scientifik.kmath.coroutines.mapParallel import java.util.concurrent.Executors import kotlin.test.Test @@ -14,7 +14,7 @@ import kotlin.test.Test @ExperimentalCoroutinesApi @InternalCoroutinesApi @FlowPreview -class BufferFlowTest { +internal class BufferFlowTest { val dispatcher: CoroutineDispatcher = Executors.newFixedThreadPool(4).asCoroutineDispatcher() @Test diff --git a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/RingBufferTest.kt b/kmath-coroutines/src/jvmTest/kotlin/kscience/kmath/streaming/RingBufferTest.kt similarity index 88% rename from kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/RingBufferTest.kt rename to kmath-coroutines/src/jvmTest/kotlin/kscience/kmath/streaming/RingBufferTest.kt index c84ef89ef..5bb0c1d40 100644 --- a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/RingBufferTest.kt +++ b/kmath-coroutines/src/jvmTest/kotlin/kscience/kmath/streaming/RingBufferTest.kt @@ -1,12 +1,12 @@ -package scientifik.kmath.streaming +package kscience.kmath.streaming import kotlinx.coroutines.flow.* import kotlinx.coroutines.runBlocking -import scientifik.kmath.structures.asSequence +import kscience.kmath.structures.asSequence import kotlin.test.Test import kotlin.test.assertEquals -class RingBufferTest { +internal class RingBufferTest { @Test fun push() { val buffer = RingBuffer.build(20, Double.NaN) diff --git a/kmath-dimensions/build.gradle.kts b/kmath-dimensions/build.gradle.kts index dda6cd2f0..9bf89fc43 100644 --- a/kmath-dimensions/build.gradle.kts +++ b/kmath-dimensions/build.gradle.kts @@ -1,8 +1,9 @@ plugins { - id("scientifik.mpp") + id("ru.mipt.npm.mpp") + id("ru.mipt.npm.native") } -description = "A proof of concept module for adding typ-safe dimensions to structures" +description = "A proof of concept module for adding type-safe dimensions to structures" kotlin.sourceSets { commonMain { @@ -11,9 +12,13 @@ kotlin.sourceSets { } } - jvmMain{ - dependencies{ + jvmMain { + dependencies { api(kotlin("reflect")) } } -} \ No newline at end of file +} + +readme{ + maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE +} diff --git a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Dimensions.kt b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Dimensions.kt new file mode 100644 index 000000000..9450f9174 --- /dev/null +++ b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Dimensions.kt @@ -0,0 +1,49 @@ +package kscience.kmath.dimensions + +import kotlin.reflect.KClass + +/** + * Represents a quantity of dimensions in certain structure. + * + * @property dim The number of dimensions. + */ +public interface Dimension { + public val dim: UInt + + public companion object +} + +public fun KClass.dim(): UInt = Dimension.resolve(this).dim + +public expect fun Dimension.Companion.resolve(type: KClass): D + +/** + * Finds or creates [Dimension] with [Dimension.dim] equal to [dim]. + */ +public expect fun Dimension.Companion.of(dim: UInt): Dimension + +/** + * Finds [Dimension.dim] of given type [D]. + */ +public inline fun Dimension.Companion.dim(): UInt = D::class.dim() + +/** + * Type representing 1 dimension. + */ +public object D1 : Dimension { + override val dim: UInt get() = 1U +} + +/** + * Type representing 2 dimensions. + */ +public object D2 : Dimension { + override val dim: UInt get() = 2U +} + +/** + * Type representing 3 dimensions. + */ +public object D3 : Dimension { + override val dim: UInt get() = 3U +} diff --git a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt new file mode 100644 index 000000000..0244eae7f --- /dev/null +++ b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt @@ -0,0 +1,155 @@ +package kscience.kmath.dimensions + +import kscience.kmath.linear.* +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.Structure2D + +/** + * A matrix with compile-time controlled dimension + */ +public interface DMatrix : Structure2D { + public companion object { + /** + * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed + */ + public inline fun coerce(structure: Structure2D): DMatrix { + require(structure.rowNum == Dimension.dim().toInt()) { + "Row number mismatch: expected ${Dimension.dim()} but found ${structure.rowNum}" + } + + require(structure.colNum == Dimension.dim().toInt()) { + "Column number mismatch: expected ${Dimension.dim()} but found ${structure.colNum}" + } + + return DMatrixWrapper(structure) + } + + /** + * The same as [DMatrix.coerce] but without dimension checks. Use with caution + */ + public fun coerceUnsafe(structure: Structure2D): DMatrix = + DMatrixWrapper(structure) + } +} + +/** + * An inline wrapper for a Matrix + */ +public inline class DMatrixWrapper( + private val structure: Structure2D, +) : DMatrix { + override val shape: IntArray get() = structure.shape + override val rowNum: Int get() = shape[0] + override val colNum: Int get() = shape[1] + override operator fun get(i: Int, j: Int): T = structure[i, j] +} + +/** + * Dimension-safe point + */ +public interface DPoint : Point { + public companion object { + public inline fun coerce(point: Point): DPoint { + require(point.size == Dimension.dim().toInt()) { + "Vector dimension mismatch: expected ${Dimension.dim()}, but found ${point.size}" + } + + return DPointWrapper(point) + } + + public fun coerceUnsafe(point: Point): DPoint = DPointWrapper(point) + } +} + +/** + * Dimension-safe point wrapper + */ +public inline class DPointWrapper(public val point: Point) : + DPoint { + override val size: Int get() = point.size + + override operator fun get(index: Int): T = point[index] + + override operator fun iterator(): Iterator = point.iterator() +} + + +/** + * Basic operations on dimension-safe matrices. Operates on [Matrix] + */ +public inline class DMatrixContext(public val context: MatrixContext>) { + public inline fun Matrix.coerce(): DMatrix { + require(rowNum == Dimension.dim().toInt()) { + "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" + } + + require(colNum == Dimension.dim().toInt()) { + "Column number mismatch: expected ${Dimension.dim()} but found $colNum" + } + + return DMatrix.coerceUnsafe(this) + } + + /** + * Produce a matrix with this context and given dimensions + */ + public inline fun produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix { + val rows = Dimension.dim() + val cols = Dimension.dim() + return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() + } + + public inline fun point(noinline initializer: (Int) -> T): DPoint { + val size = Dimension.dim() + + return DPoint.coerceUnsafe( + context.point( + size.toInt(), + initializer + ) + ) + } + + public inline infix fun DMatrix.dot( + other: DMatrix, + ): DMatrix = context { this@dot dot other }.coerce() + + public inline infix fun DMatrix.dot(vector: DPoint): DPoint = + DPoint.coerceUnsafe(context { this@dot dot vector }) + + public inline operator fun DMatrix.times(value: T): DMatrix = + context { this@times.times(value) }.coerce() + + public inline operator fun T.times(m: DMatrix): DMatrix = + m * this + + public inline operator fun DMatrix.plus(other: DMatrix): DMatrix = + context { this@plus + other }.coerce() + + public inline operator fun DMatrix.minus(other: DMatrix): DMatrix = + context { this@minus + other }.coerce() + + public inline operator fun DMatrix.unaryMinus(): DMatrix = + context { this@unaryMinus.unaryMinus() }.coerce() + + public inline fun DMatrix.transpose(): DMatrix = + context { (this@transpose as Matrix).transpose() }.coerce() + + public companion object { + public val real: DMatrixContext = DMatrixContext(MatrixContext.real) + } +} + + +/** + * A square unit matrix + */ +public inline fun DMatrixContext.one(): DMatrix = produce { i, j -> + if (i == j) 1.0 else 0.0 +} + +public inline fun DMatrixContext.zero(): DMatrix = + produce { _, _ -> + 0.0 + } \ No newline at end of file diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt deleted file mode 100644 index f40483cfd..000000000 --- a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Dimensions.kt +++ /dev/null @@ -1,35 +0,0 @@ -package scientifik.kmath.dimensions - -import kotlin.reflect.KClass - -/** - * An abstract class which is not used in runtime. Designates a size of some structure. - * Could be replaced later by fully inline constructs - */ -interface Dimension { - - val dim: UInt - companion object { - - } -} - -fun KClass.dim(): UInt = Dimension.resolve(this).dim - -expect fun Dimension.Companion.resolve(type: KClass): D - -expect fun Dimension.Companion.of(dim: UInt): Dimension - -inline fun Dimension.Companion.dim(): UInt = D::class.dim() - -object D1 : Dimension { - override val dim: UInt get() = 1U -} - -object D2 : Dimension { - override val dim: UInt get() = 2U -} - -object D3 : Dimension { - override val dim: UInt get() = 3U -} diff --git a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt deleted file mode 100644 index 7b0244bdf..000000000 --- a/kmath-dimensions/src/commonMain/kotlin/scientifik/kmath/dimensions/Wrappers.kt +++ /dev/null @@ -1,157 +0,0 @@ -package scientifik.kmath.dimensions - -import scientifik.kmath.linear.GenericMatrixContext -import scientifik.kmath.linear.MatrixContext -import scientifik.kmath.linear.Point -import scientifik.kmath.linear.transpose -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.Structure2D - -/** - * A matrix with compile-time controlled dimension - */ -interface DMatrix : Structure2D { - companion object { - /** - * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed - */ - inline fun coerce(structure: Structure2D): DMatrix { - if (structure.rowNum != Dimension.dim().toInt()) { - error("Row number mismatch: expected ${Dimension.dim()} but found ${structure.rowNum}") - } - if (structure.colNum != Dimension.dim().toInt()) { - error("Column number mismatch: expected ${Dimension.dim()} but found ${structure.colNum}") - } - return DMatrixWrapper(structure) - } - - /** - * The same as [coerce] but without dimension checks. Use with caution - */ - fun coerceUnsafe(structure: Structure2D): DMatrix { - return DMatrixWrapper(structure) - } - } -} - -/** - * An inline wrapper for a Matrix - */ -inline class DMatrixWrapper( - val structure: Structure2D -) : DMatrix { - override val shape: IntArray get() = structure.shape - override operator fun get(i: Int, j: Int): T = structure[i, j] -} - -/** - * Dimension-safe point - */ -interface DPoint : Point { - companion object { - inline fun coerce(point: Point): DPoint { - if (point.size != Dimension.dim().toInt()) { - error("Vector dimension mismatch: expected ${Dimension.dim()}, but found ${point.size}") - } - return DPointWrapper(point) - } - - fun coerceUnsafe(point: Point): DPoint { - return DPointWrapper(point) - } - } -} - -/** - * Dimension-safe point wrapper - */ -inline class DPointWrapper(val point: Point) : - DPoint { - override val size: Int get() = point.size - - override operator fun get(index: Int): T = point[index] - - override operator fun iterator(): Iterator = point.iterator() -} - - -/** - * Basic operations on dimension-safe matrices. Operates on [Matrix] - */ -inline class DMatrixContext>(val context: GenericMatrixContext) { - - inline fun Matrix.coerce(): DMatrix { - check( - rowNum == Dimension.dim().toInt() - ) { "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" } - - check( - colNum == Dimension.dim().toInt() - ) { "Column number mismatch: expected ${Dimension.dim()} but found $colNum" } - - return DMatrix.coerceUnsafe(this) - } - - /** - * Produce a matrix with this context and given dimensions - */ - inline fun produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix { - val rows = Dimension.dim() - val cols = Dimension.dim() - return context.produce(rows.toInt(), cols.toInt(), initializer).coerce() - } - - inline fun point(noinline initializer: (Int) -> T): DPoint { - val size = Dimension.dim() - - return DPoint.coerceUnsafe( - context.point( - size.toInt(), - initializer - ) - ) - } - - inline infix fun DMatrix.dot( - other: DMatrix - ): DMatrix = context { this@dot dot other }.coerce() - - inline infix fun DMatrix.dot(vector: DPoint): DPoint = - DPoint.coerceUnsafe(context { this@dot dot vector }) - - inline operator fun DMatrix.times(value: T): DMatrix = - context { this@times.times(value) }.coerce() - - inline operator fun T.times(m: DMatrix): DMatrix = - m * this - - inline operator fun DMatrix.plus(other: DMatrix): DMatrix = - context { this@plus + other }.coerce() - - inline operator fun DMatrix.minus(other: DMatrix): DMatrix = - context { this@minus + other }.coerce() - - inline operator fun DMatrix.unaryMinus(): DMatrix = - context { this@unaryMinus.unaryMinus() }.coerce() - - inline fun DMatrix.transpose(): DMatrix = - context { (this@transpose as Matrix).transpose() }.coerce() - - /** - * A square unit matrix - */ - inline fun one(): DMatrix = produce { i, j -> - if (i == j) context.elementContext.one else context.elementContext.zero - } - - inline fun zero(): DMatrix = produce { _, _ -> - context.elementContext.zero - } - - companion object { - val real: DMatrixContext = DMatrixContext(MatrixContext.real) - } -} diff --git a/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt similarity index 70% rename from kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt rename to kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt index 8dabdeeac..b9193d4dd 100644 --- a/kmath-dimensions/src/commonTest/kotlin/scientifik/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt @@ -1,11 +1,13 @@ -package scientifik.dimensions +package kscience.dimensions -import scientifik.kmath.dimensions.D2 -import scientifik.kmath.dimensions.D3 -import scientifik.kmath.dimensions.DMatrixContext +import kscience.kmath.dimensions.D2 +import kscience.kmath.dimensions.D3 +import kscience.kmath.dimensions.DMatrixContext +import kscience.kmath.dimensions.one import kotlin.test.Test -class DMatrixContextTest { +@Suppress("UNUSED_VARIABLE") +internal class DMatrixContextTest { @Test fun testDimensionSafeMatrix() { val res = with(DMatrixContext.real) { @@ -26,4 +28,4 @@ class DMatrixContextTest { m1.transpose() + m2 } } -} \ No newline at end of file +} diff --git a/kmath-dimensions/src/jsMain/kotlin/kscience/kmath/dimensions/dimJs.kt b/kmath-dimensions/src/jsMain/kotlin/kscience/kmath/dimensions/dimJs.kt new file mode 100644 index 000000000..4230da156 --- /dev/null +++ b/kmath-dimensions/src/jsMain/kotlin/kscience/kmath/dimensions/dimJs.kt @@ -0,0 +1,18 @@ +package kscience.kmath.dimensions + +import kotlin.reflect.KClass + +private val dimensionMap: MutableMap = hashMapOf(1u to D1, 2u to D2, 3u to D3) + +@Suppress("UNCHECKED_CAST") +public actual fun Dimension.Companion.resolve(type: KClass): D = dimensionMap + .entries + .map(MutableMap.MutableEntry::value) + .find { it::class == type } as? D + ?: error("Can't resolve dimension $type") + +public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) { + object : Dimension { + override val dim: UInt get() = dim + } +} diff --git a/kmath-dimensions/src/jsMain/kotlin/scientifik/kmath/dimensions/dim.kt b/kmath-dimensions/src/jsMain/kotlin/scientifik/kmath/dimensions/dim.kt deleted file mode 100644 index bbd580629..000000000 --- a/kmath-dimensions/src/jsMain/kotlin/scientifik/kmath/dimensions/dim.kt +++ /dev/null @@ -1,22 +0,0 @@ -package scientifik.kmath.dimensions - -import kotlin.reflect.KClass - -private val dimensionMap = hashMapOf( - 1u to D1, - 2u to D2, - 3u to D3 -) - -@Suppress("UNCHECKED_CAST") -actual fun Dimension.Companion.resolve(type: KClass): D { - return dimensionMap.entries.find { it.value::class == type }?.value as? D ?: error("Can't resolve dimension $type") -} - -actual fun Dimension.Companion.of(dim: UInt): Dimension { - return dimensionMap.getOrPut(dim) { - object : Dimension { - override val dim: UInt get() = dim - } - } -} \ No newline at end of file diff --git a/kmath-dimensions/src/jvmMain/kotlin/kscience/kmath/dimensions/dimJvm.kt b/kmath-dimensions/src/jvmMain/kotlin/kscience/kmath/dimensions/dimJvm.kt new file mode 100644 index 000000000..dec3979ef --- /dev/null +++ b/kmath-dimensions/src/jvmMain/kotlin/kscience/kmath/dimensions/dimJvm.kt @@ -0,0 +1,16 @@ +package kscience.kmath.dimensions + +import kotlin.reflect.KClass + +public actual fun Dimension.Companion.resolve(type: KClass): D = + type.objectInstance ?: error("No object instance for dimension class") + +public actual fun Dimension.Companion.of(dim: UInt): Dimension = when (dim) { + 1u -> D1 + 2u -> D2 + 3u -> D3 + + else -> object : Dimension { + override val dim: UInt get() = dim + } +} \ No newline at end of file diff --git a/kmath-dimensions/src/jvmMain/kotlin/scientifik/kmath/dimensions/dim.kt b/kmath-dimensions/src/jvmMain/kotlin/scientifik/kmath/dimensions/dim.kt deleted file mode 100644 index e8fe8f59b..000000000 --- a/kmath-dimensions/src/jvmMain/kotlin/scientifik/kmath/dimensions/dim.kt +++ /dev/null @@ -1,18 +0,0 @@ -package scientifik.kmath.dimensions - -import kotlin.reflect.KClass - -actual fun Dimension.Companion.resolve(type: KClass): D{ - return type.objectInstance ?: error("No object instance for dimension class") -} - -actual fun Dimension.Companion.of(dim: UInt): Dimension{ - return when(dim){ - 1u -> D1 - 2u -> D2 - 3u -> D3 - else -> object : Dimension { - override val dim: UInt get() = dim - } - } -} \ No newline at end of file diff --git a/kmath-dimensions/src/nativeMain/kotlin/kscience/kmath/dimensions/dimNative.kt b/kmath-dimensions/src/nativeMain/kotlin/kscience/kmath/dimensions/dimNative.kt new file mode 100644 index 000000000..aeaeaf759 --- /dev/null +++ b/kmath-dimensions/src/nativeMain/kotlin/kscience/kmath/dimensions/dimNative.kt @@ -0,0 +1,20 @@ +package kscience.kmath.dimensions + +import kotlin.native.concurrent.ThreadLocal +import kotlin.reflect.KClass + +@ThreadLocal +private val dimensionMap: MutableMap = hashMapOf(1u to D1, 2u to D2, 3u to D3) + +@Suppress("UNCHECKED_CAST") +public actual fun Dimension.Companion.resolve(type: KClass): D = dimensionMap + .entries + .map(MutableMap.MutableEntry::value) + .find { it::class == type } as? D + ?: error("Can't resolve dimension $type") + +public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) { + object : Dimension { + override val dim: UInt get() = dim + } +} diff --git a/kmath-ejml/build.gradle.kts b/kmath-ejml/build.gradle.kts new file mode 100644 index 000000000..fa4aa3e39 --- /dev/null +++ b/kmath-ejml/build.gradle.kts @@ -0,0 +1,8 @@ +plugins { + id("ru.mipt.npm.jvm") +} + +dependencies { + implementation("org.ejml:ejml-simple:0.39") + implementation(project(":kmath-core")) +} diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt new file mode 100644 index 000000000..82a5399fd --- /dev/null +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt @@ -0,0 +1,90 @@ +package kscience.kmath.ejml + +import kscience.kmath.linear.* +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.NDStructure +import kscience.kmath.structures.RealBuffer +import org.ejml.dense.row.factory.DecompositionFactory_DDRM +import org.ejml.simple.SimpleMatrix +import kotlin.reflect.KClass +import kotlin.reflect.cast + +/** + * Represents featured matrix over EJML [SimpleMatrix]. + * + * @property origin the underlying [SimpleMatrix]. + * @author Iaroslav Postovalov + */ +public class EjmlMatrix( + public val origin: SimpleMatrix, +) : Matrix { + public override val rowNum: Int get() = origin.numRows() + + public override val colNum: Int get() = origin.numCols() + + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = when (type) { + InverseMatrixFeature::class -> object : InverseMatrixFeature { + override val inverse: Matrix by lazy { EjmlMatrix(origin.invert()) } + } + DeterminantFeature::class -> object : DeterminantFeature { + override val determinant: Double by lazy(origin::determinant) + } + SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { + private val svd by lazy { + DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false) + .apply { decompose(origin.ddrm.copy()) } + } + + override val u: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) } + override val s: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) } + override val v: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) } + override val singularValues: Point by lazy { RealBuffer(svd.singularValues) } + } + QRDecompositionFeature::class -> object : QRDecompositionFeature { + private val qr by lazy { + DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) } + } + + override val q: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } + override val r: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } + } + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + override val l: Matrix by lazy { + val cholesky = + DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) } + + EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature + } + } + LupDecompositionFeature::class -> object : LupDecompositionFeature { + private val lup by lazy { + DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) } + } + + override val l: Matrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature + } + + override val u: Matrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature + } + + override val p: Matrix by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } + } + else -> null + }?.let { type.cast(it) } + + public override operator fun get(i: Int, j: Int): Double = origin[i, j] + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is Matrix<*>) return false + return NDStructure.contentEquals(this, other) + } + + override fun hashCode(): Int = origin.hashCode() + + +} diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt new file mode 100644 index 000000000..8184d0110 --- /dev/null +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -0,0 +1,92 @@ +package kscience.kmath.ejml + +import kscience.kmath.linear.InverseMatrixFeature +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.Point +import kscience.kmath.linear.origin +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.getFeature +import org.ejml.simple.SimpleMatrix + +/** + * Represents context of basic operations operating with [EjmlMatrix]. + * + * @author Iaroslav Postovalov + */ +public object EjmlMatrixContext : MatrixContext { + + /** + * Converts this matrix to EJML one. + */ + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toEjml(): EjmlMatrix = when (val matrix = origin) { + is EjmlMatrix -> matrix + else -> produce(rowNum, colNum) { i, j -> get(i, j) } + } + + /** + * Converts this vector to EJML one. + */ + public fun Point.toEjml(): EjmlVector = + if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also { + (0 until it.numRows()).forEach { row -> it[row, 0] = get(row) } + }) + + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix = + EjmlMatrix(SimpleMatrix(rows, columns).also { + (0 until it.numRows()).forEach { row -> + (0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) } + } + }) + + override fun point(size: Int, initializer: (Int) -> Double): Point = + EjmlVector(SimpleMatrix(size, 1).also { + (0 until it.numRows()).forEach { row -> it[row, 0] = initializer(row) } + }) + + public override fun Matrix.dot(other: Matrix): EjmlMatrix = + EjmlMatrix(toEjml().origin.mult(other.toEjml().origin)) + + public override fun Matrix.dot(vector: Point): EjmlVector = + EjmlVector(toEjml().origin.mult(vector.toEjml().origin)) + + public override fun add(a: Matrix, b: Matrix): EjmlMatrix = + EjmlMatrix(a.toEjml().origin + b.toEjml().origin) + + public override operator fun Matrix.minus(b: Matrix): EjmlMatrix = + EjmlMatrix(toEjml().origin - b.toEjml().origin) + + public override fun multiply(a: Matrix, k: Number): EjmlMatrix = + produce(a.rowNum, a.colNum) { i, j -> a[i, j] * k.toDouble() } + + public override operator fun Matrix.times(value: Double): EjmlMatrix = + EjmlMatrix(toEjml().origin.scale(value)) +} + +/** + * Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix. + * + * @param a the base matrix. + * @param b n by p matrix. + * @return the solution for 'x' that is n by p. + * @author Iaroslav Postovalov + */ +public fun EjmlMatrixContext.solve(a: Matrix, b: Matrix): EjmlMatrix = + EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin)) + +/** + * Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix. + * + * @param a the base matrix. + * @param b n by p vector. + * @return the solution for 'x' that is n by p. + * @author Iaroslav Postovalov + */ +public fun EjmlMatrixContext.solve(a: Matrix, b: Point): EjmlVector = + EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) + +@OptIn(UnstableKMathAPI::class) +public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature>()!!.inverse as EjmlMatrix + +public fun EjmlMatrixContext.inverse(matrix: Matrix): Matrix = matrix.toEjml().inverted() \ No newline at end of file diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt new file mode 100644 index 000000000..f7cd1b66d --- /dev/null +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt @@ -0,0 +1,40 @@ +package kscience.kmath.ejml + +import org.ejml.simple.SimpleMatrix +import kscience.kmath.linear.Point +import kscience.kmath.structures.Buffer + +/** + * Represents point over EJML [SimpleMatrix]. + * + * @property origin the underlying [SimpleMatrix]. + * @author Iaroslav Postovalov + */ +public class EjmlVector internal constructor(public val origin: SimpleMatrix) : Point { + public override val size: Int + get() = origin.numRows() + + init { + require(origin.numCols() == 1) { "Only single column matrices are allowed" } + } + + public override operator fun get(index: Int): Double = origin[index] + + public override operator fun iterator(): Iterator = object : Iterator { + private var cursor: Int = 0 + + override fun next(): Double { + cursor += 1 + return origin[cursor - 1] + } + + override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows() + } + + public override fun contentEquals(other: Buffer<*>): Boolean { + if (other is EjmlVector) return origin.isIdentical(other.origin, 0.0) + return super.contentEquals(other) + } + + public override fun toString(): String = "EjmlVector(origin=$origin)" +} diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt new file mode 100644 index 000000000..455b52d9d --- /dev/null +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -0,0 +1,79 @@ +package kscience.kmath.ejml + +import kscience.kmath.linear.DeterminantFeature +import kscience.kmath.linear.LupDecompositionFeature +import kscience.kmath.linear.MatrixFeature +import kscience.kmath.linear.plus +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.getFeature +import org.ejml.dense.row.factory.DecompositionFactory_DDRM +import org.ejml.simple.SimpleMatrix +import kotlin.random.Random +import kotlin.random.asJavaRandom +import kotlin.test.* + +internal class EjmlMatrixTest { + private val random = Random(0) + + private val randomMatrix: SimpleMatrix + get() { + val s = random.nextInt(2, 100) + return SimpleMatrix.random_DDRM(s, s, 0.0, 10.0, random.asJavaRandom()) + } + + @Test + fun rowNum() { + val m = randomMatrix + assertEquals(m.numRows(), EjmlMatrix(m).rowNum) + } + + @Test + fun colNum() { + val m = randomMatrix + assertEquals(m.numCols(), EjmlMatrix(m).rowNum) + } + + @Test + fun shape() { + val m = randomMatrix + val w = EjmlMatrix(m) + assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList()) + } + + @OptIn(UnstableKMathAPI::class) + @Test + fun features() { + val m = randomMatrix + val w = EjmlMatrix(m) + val det = w.getFeature>() ?: fail() + assertEquals(m.determinant(), det.determinant) + val lup = w.getFeature>() ?: fail() + + val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols()) + .also { it.decompose(m.ddrm.copy()) } + + assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), lup.l) + assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), lup.u) + assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), lup.p) + } + + private object SomeFeature : MatrixFeature {} + + @OptIn(UnstableKMathAPI::class) + @Test + fun suggestFeature() { + assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature()) + } + + @Test + fun get() { + val m = randomMatrix + assertEquals(m[0, 0], EjmlMatrix(m)[0, 0]) + } + + @Test + fun origin() { + val m = randomMatrix + assertSame(m, EjmlMatrix(m).origin) + } +} diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlVectorTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlVectorTest.kt new file mode 100644 index 000000000..e27f977d2 --- /dev/null +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlVectorTest.kt @@ -0,0 +1,47 @@ +package kscience.kmath.ejml + +import org.ejml.simple.SimpleMatrix +import kotlin.random.Random +import kotlin.random.asJavaRandom +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertSame + +internal class EjmlVectorTest { + private val random = Random(0) + + private val randomMatrix: SimpleMatrix + get() = SimpleMatrix.random_DDRM(random.nextInt(2, 100), 1, 0.0, 10.0, random.asJavaRandom()) + + @Test + fun size() { + val m = randomMatrix + val w = EjmlVector(m) + assertEquals(m.numRows(), w.size) + } + + @Test + fun get() { + val m = randomMatrix + val w = EjmlVector(m) + assertEquals(m[0, 0], w[0]) + } + + @Test + fun iterator() { + val m = randomMatrix + val w = EjmlVector(m) + + assertEquals( + m.iterator(true, 0, 0, m.numRows() - 1, 0).asSequence().toList(), + w.iterator().asSequence().toList() + ) + } + + @Test + fun origin() { + val m = randomMatrix + val w = EjmlVector(m) + assertSame(m, w.origin) + } +} diff --git a/kmath-for-real/README.md b/kmath-for-real/README.md new file mode 100644 index 000000000..d6b66b7da --- /dev/null +++ b/kmath-for-real/README.md @@ -0,0 +1,44 @@ +# Real number specialization module (`kmath-for-real`) + + - [RealVector](src/commonMain/kotlin/kscience/kmath/real/RealVector.kt) : Numpy-like operations for Buffers/Points + - [RealMatrix](src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt) : Numpy-like operations for 2d real structures + - [grids](src/commonMain/kotlin/kscience/kmath/structures/grids.kt) : Uniform grid generators + + +> #### Artifact: +> +> This module artifact: `kscience.kmath:kmath-for-real:0.2.0-dev-4`. +> +> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-for-real/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-for-real/_latestVersion) +> +> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-for-real/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-for-real/_latestVersion) +> +> **Gradle:** +> +> ```gradle +> repositories { +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } +> maven { url 'https://dl.bintray.com/mipt-npm/kscience' } +> maven { url 'https://dl.bintray.com/mipt-npm/dev' } +> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } +> +> } +> +> dependencies { +> implementation 'kscience.kmath:kmath-for-real:0.2.0-dev-4' +> } +> ``` +> **Gradle Kotlin DSL:** +> +> ```kotlin +> repositories { +> maven("https://dl.bintray.com/kotlin/kotlin-eap") +> maven("https://dl.bintray.com/mipt-npm/kscience") +> maven("https://dl.bintray.com/mipt-npm/dev") +> maven("https://dl.bintray.com/hotkeytlt/maven") +> } +> +> dependencies { +> implementation("kscience.kmath:kmath-for-real:0.2.0-dev-4") +> } +> ``` diff --git a/kmath-for-real/build.gradle.kts b/kmath-for-real/build.gradle.kts index 46d2682f7..f26f98c2c 100644 --- a/kmath-for-real/build.gradle.kts +++ b/kmath-for-real/build.gradle.kts @@ -1,6 +1,37 @@ -plugins { id("scientifik.mpp") } - -kotlin.sourceSets { - all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") } - commonMain { dependencies { api(project(":kmath-core")) } } +plugins { + id("ru.mipt.npm.mpp") +} + +kotlin.sourceSets.commonMain { + dependencies { + api(project(":kmath-core")) + } +} + +readme { + description = """ + Extension module that should be used to achieve numpy-like behavior. + All operations are specialized to work with `Double` numbers without declaring algebraic contexts. + One can still use generic algebras though. + """.trimIndent() + maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL + propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) + + feature( + id = "RealVector", + description = "Numpy-like operations for Buffers/Points", + ref = "src/commonMain/kotlin/kscience/kmath/real/RealVector.kt" + ) + + feature( + id = "RealMatrix", + description = "Numpy-like operations for 2d real structures", + ref = "src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt" + ) + + feature( + id = "grids", + description = "Uniform grid generators", + ref = "src/commonMain/kotlin/kscience/kmath/structures/grids.kt" + ) } diff --git a/kmath-for-real/docs/README-TEMPLATE.md b/kmath-for-real/docs/README-TEMPLATE.md new file mode 100644 index 000000000..670844bd0 --- /dev/null +++ b/kmath-for-real/docs/README-TEMPLATE.md @@ -0,0 +1,5 @@ +# Real number specialization module (`kmath-for-real`) + +${features} + +${artifact} diff --git a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt new file mode 100644 index 000000000..274030aff --- /dev/null +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt @@ -0,0 +1,177 @@ +package kscience.kmath.real + +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.VirtualMatrix +import kscience.kmath.linear.inverseWithLUP +import kscience.kmath.linear.real +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.RealBuffer +import kscience.kmath.structures.asIterable +import kotlin.math.pow + +/* + * Functions for convenient "numpy-like" operations with Double matrices. + * + * Initial implementation of these functions is taken from: + * https://github.com/thomasnield/numky/blob/master/src/main/kotlin/org/nield/numky/linear/DoubleOperators.kt + * + */ + +/* + * Functions that help create a real (Double) matrix + */ + +public typealias RealMatrix = Matrix + +public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum, initializer) + +public fun Array.toMatrix(): RealMatrix { + return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } +} + +public fun Sequence.toMatrix(): RealMatrix = toList().let { + MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } +} + +public fun RealMatrix.repeatStackVertical(n: Int): RealMatrix = + VirtualMatrix(rowNum * n, colNum) { row, col -> + get(if (row == 0) 0 else row % rowNum, col) + } + +/* + * Operations for matrix and real number + */ + +public operator fun RealMatrix.times(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] * double + } + +public operator fun RealMatrix.plus(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] + double + } + +public operator fun RealMatrix.minus(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] - double + } + +public operator fun RealMatrix.div(double: Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> + this[row, col] / double + } + +public operator fun Double.times(matrix: RealMatrix): RealMatrix = + MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> + this * matrix[row, col] + } + +public operator fun Double.plus(matrix: RealMatrix): RealMatrix = + MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> + this + matrix[row, col] + } + +public operator fun Double.minus(matrix: RealMatrix): RealMatrix = + MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> + this - matrix[row, col] + } + +// TODO: does this operation make sense? Should it be 'this/matrix[row, col]'? +//operator fun Double.div(matrix: RealMatrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { +// row, col -> matrix[row, col] / this +//} + +/* + * Operations on two matrices (per-element!) + */ + +@UnstableKMathAPI +public operator fun RealMatrix.times(other: RealMatrix): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> this[row, col] * other[row, col] } + +public operator fun RealMatrix.plus(other: RealMatrix): RealMatrix = + MatrixContext.real.add(this, other) + +public operator fun RealMatrix.minus(other: RealMatrix): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { row, col -> this[row, col] - other[row, col] } + +/* + * Operations on columns + */ + +public inline fun RealMatrix.appendColumn(crossinline mapper: (Buffer) -> Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> + if (col < colNum) + this[row, col] + else + mapper(rows[row]) + } + +public fun RealMatrix.extractColumns(columnRange: IntRange): RealMatrix = + MatrixContext.real.produce(rowNum, columnRange.count()) { row, col -> + this[row, columnRange.first + col] + } + +public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix = + extractColumns(columnIndex..columnIndex) + +public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> + columns[j].asIterable().sum() +} + +public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> + columns[j].asIterable().minOrNull() ?: error("Cannot produce min on empty column") +} + +public fun RealMatrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> + columns[j].asIterable().maxOrNull() ?: error("Cannot produce min on empty column") +} + +public fun RealMatrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> + columns[j].asIterable().average() +} + +/* + * Operations processing all elements + */ + +public fun RealMatrix.sum(): Double = elements().map { (_, value) -> value }.sum() +public fun RealMatrix.min(): Double? = elements().map { (_, value) -> value }.minOrNull() +public fun RealMatrix.max(): Double? = elements().map { (_, value) -> value }.maxOrNull() +public fun RealMatrix.average(): Double = elements().map { (_, value) -> value }.average() + +public inline fun RealMatrix.map(transform: (Double) -> Double): RealMatrix = + MatrixContext.real.produce(rowNum, colNum) { i, j -> + transform(get(i, j)) + } + +/** + * Inverse a square real matrix using LUP decomposition + */ +public fun RealMatrix.inverseWithLUP(): RealMatrix = MatrixContext.real.inverseWithLUP(this) + +//extended operations + +public fun RealMatrix.pow(p: Double): RealMatrix = map { it.pow(p) } + +public fun RealMatrix.pow(p: Int): RealMatrix = map { it.pow(p) } + +public fun exp(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.exp(it) } + +public fun sqrt(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.sqrt(it) } + +public fun RealMatrix.square(): RealMatrix = map { it.pow(2) } + +public fun sin(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.sin(it) } + +public fun cos(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.cos(it) } + +public fun tan(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.tan(it) } + +public fun ln(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.ln(it) } + +public fun log10(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.log10(it) } \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealVector.kt new file mode 100644 index 000000000..596692782 --- /dev/null +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealVector.kt @@ -0,0 +1,82 @@ +package kscience.kmath.real + +import kscience.kmath.linear.Point +import kscience.kmath.operations.Norm +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.asBuffer +import kscience.kmath.structures.asIterable +import kotlin.math.pow +import kotlin.math.sqrt + +public typealias RealVector = Point + +public object VectorL2Norm : Norm, Double> { + override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble(Number::toDouble)) +} + +public operator fun Buffer.Companion.invoke(vararg doubles: Double): RealVector = doubles.asBuffer() + +/** + * Fill the vector of given [size] with given [value] + */ +public fun Buffer.Companion.same(size: Int, value: Number): RealVector = real(size) { value.toDouble() } + +// Transformation methods + +public inline fun RealVector.map(transform: (Double) -> Double): RealVector = + Buffer.real(size) { transform(get(it)) } + +public inline fun RealVector.mapIndexed(transform: (index: Int, value: Double) -> Double): RealVector = + Buffer.real(size) { transform(it, get(it)) } + +public operator fun RealVector.plus(other: RealVector): RealVector = + mapIndexed { index, value -> value + other[index] } + +public operator fun RealVector.plus(number: Number): RealVector = map { it + number.toDouble() } + +public operator fun Number.plus(vector: RealVector): RealVector = vector + this + +public operator fun RealVector.unaryMinus(): Buffer = map { -it } + +public operator fun RealVector.minus(other: RealVector): RealVector = + mapIndexed { index, value -> value - other[index] } + +public operator fun RealVector.minus(number: Number): RealVector = map { it - number.toDouble() } + +public operator fun Number.minus(vector: RealVector): RealVector = vector.map { toDouble() - it } + +public operator fun RealVector.times(other: RealVector): RealVector = + mapIndexed { index, value -> value * other[index] } + +public operator fun RealVector.times(number: Number): RealVector = map { it * number.toDouble() } + +public operator fun Number.times(vector: RealVector): RealVector = vector * this + +public operator fun RealVector.div(other: RealVector): RealVector = + mapIndexed { index, value -> value / other[index] } + +public operator fun RealVector.div(number: Number): RealVector = map { it / number.toDouble() } + +public operator fun Number.div(vector: RealVector): RealVector = vector.map { toDouble() / it } + +//extended operations + +public fun RealVector.pow(p: Double): RealVector = map { it.pow(p) } + +public fun RealVector.pow(p: Int): RealVector = map { it.pow(p) } + +public fun exp(vector: RealVector): RealVector = vector.map { kotlin.math.exp(it) } + +public fun sqrt(vector: RealVector): RealVector = vector.map { kotlin.math.sqrt(it) } + +public fun RealVector.square(): RealVector = map { it.pow(2) } + +public fun sin(vector: RealVector): RealVector = vector.map { kotlin.math.sin(it) } + +public fun cos(vector: RealVector): RealVector = vector.map { kotlin.math.cos(it) } + +public fun tan(vector: RealVector): RealVector = vector.map { kotlin.math.tan(it) } + +public fun ln(vector: RealVector): RealVector = vector.map { kotlin.math.ln(it) } + +public fun log10(vector: RealVector): RealVector = vector.map { kotlin.math.log10(it) } diff --git a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/dot.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/dot.kt new file mode 100644 index 000000000..9beffe6bb --- /dev/null +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/dot.kt @@ -0,0 +1,31 @@ +package kscience.kmath.real + +import kscience.kmath.linear.BufferMatrix +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.RealBuffer + + +/** + * Optimized dot product for real matrices + */ +public infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + val resultArray = DoubleArray(this.rowNum * other.colNum) + + //convert to array to insure there is no memory indirection + fun Buffer.unsafeArray() = if (this is RealBuffer) + this.array + else + DoubleArray(size) { get(it) } + + val a = this.buffer.unsafeArray() + val b = other.buffer.unsafeArray() + + for (i in (0 until rowNum)) + for (j in (0 until other.colNum)) + for (k in (0 until colNum)) + resultArray[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j] + + val buffer = RealBuffer(resultArray) + return BufferMatrix(rowNum, other.colNum, buffer) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/grids.kt similarity index 60% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt rename to kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/grids.kt index 1272ddd1c..69a149fb8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/grids.kt @@ -1,5 +1,6 @@ -package scientifik.kmath.misc +package kscience.kmath.real +import kscience.kmath.structures.asBuffer import kotlin.math.abs /** @@ -10,17 +11,21 @@ import kotlin.math.abs * * If step is negative, the same goes from upper boundary downwards */ -fun ClosedFloatingPointRange.toSequenceWithStep(step: Double): Sequence = when { +public fun ClosedFloatingPointRange.toSequenceWithStep(step: Double): Sequence = when { step == 0.0 -> error("Zero step in double progression") + step > 0 -> sequence { var current = start + while (current <= endInclusive) { yield(current) current += step } } + else -> sequence { var current = endInclusive + while (current >= start) { yield(current) current += step @@ -28,19 +33,13 @@ fun ClosedFloatingPointRange.toSequenceWithStep(step: Double): Sequence< } } +public infix fun ClosedFloatingPointRange.step(step: Double): RealVector = + toSequenceWithStep(step).toList().asBuffer() + /** * Convert double range to sequence with the fixed number of points */ -fun ClosedFloatingPointRange.toSequenceWithPoints(numPoints: Int): Sequence { +public fun ClosedFloatingPointRange.toSequenceWithPoints(numPoints: Int): Sequence { require(numPoints > 1) { "The number of points should be more than 2" } return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1)) } - -/** - * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] - */ -@Deprecated("Replace by 'toSequenceWithPoints'") -fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { - require(numPoints >= 2) { "Can't create generic grid with less than two points" } - return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } -} diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt deleted file mode 100644 index 811b54d7c..000000000 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/RealVector.kt +++ /dev/null @@ -1,51 +0,0 @@ -package scientifik.kmath.real - -import scientifik.kmath.linear.BufferVectorSpace -import scientifik.kmath.linear.Point -import scientifik.kmath.linear.VectorSpace -import scientifik.kmath.operations.Norm -import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.SpaceElement -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.RealBuffer -import scientifik.kmath.structures.asBuffer -import scientifik.kmath.structures.asIterable -import kotlin.math.sqrt - -typealias RealPoint = Point - -fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer()) -fun List.asVector(): RealVector = RealVector(this.asBuffer()) - -object VectorL2Norm : Norm, Double> { - override fun norm(arg: Point): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) -} - -inline class RealVector(private val point: Point) : - SpaceElement>, RealPoint { - - override val context: VectorSpace get() = space(point.size) - - override fun unwrap(): RealPoint = point - - override fun RealPoint.wrap(): RealVector = RealVector(this) - - override val size: Int get() = point.size - - override operator fun get(index: Int): Double = point[index] - - override operator fun iterator(): Iterator = point.iterator() - - companion object { - private val spaceCache: MutableMap> = hashMapOf() - - inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector = - RealVector(RealBuffer(dim, initializer)) - - operator fun invoke(vararg values: Double): RealVector = values.asVector() - - fun space(dim: Int): BufferVectorSpace = spaceCache.getOrPut(dim) { - BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } - } - } -} diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt deleted file mode 100644 index 82c0e86b2..000000000 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realBuffer.kt +++ /dev/null @@ -1,8 +0,0 @@ -package scientifik.kmath.real - -import scientifik.kmath.structures.RealBuffer - -/** - * Simplified [RealBuffer] to array comparison - */ -fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt b/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt deleted file mode 100644 index 3752fc3ca..000000000 --- a/kmath-for-real/src/commonMain/kotlin/scientifik/kmath/real/realMatrix.kt +++ /dev/null @@ -1,166 +0,0 @@ -package scientifik.kmath.real - -import scientifik.kmath.linear.MatrixContext -import scientifik.kmath.linear.RealMatrixContext.elementContext -import scientifik.kmath.linear.VirtualMatrix -import scientifik.kmath.operations.invoke -import scientifik.kmath.operations.sum -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.RealBuffer -import scientifik.kmath.structures.asIterable -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.contract -import kotlin.math.pow - -/* - * Functions for convenient "numpy-like" operations with Double matrices. - * - * Initial implementation of these functions is taken from: - * https://github.com/thomasnield/numky/blob/master/src/main/kotlin/org/nield/numky/linear/DoubleOperators.kt - * - */ - -/* - * Functions that help create a real (Double) matrix - */ - -typealias RealMatrix = Matrix - -fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = - MatrixContext.real.produce(rowNum, colNum, initializer) - -fun Array.toMatrix(): RealMatrix { - return MatrixContext.real.produce(size, this[0].size) { row, col -> this[row][col] } -} - -fun Sequence.toMatrix(): RealMatrix = toList().let { - MatrixContext.real.produce(it.size, it[0].size) { row, col -> it[row][col] } -} - -fun Matrix.repeatStackVertical(n: Int): RealMatrix = - VirtualMatrix(rowNum * n, colNum) { row, col -> - get(if (row == 0) 0 else row % rowNum, col) - } - -/* - * Operations for matrix and real number - */ - -operator fun Matrix.times(double: Double): RealMatrix = - MatrixContext.real.produce(rowNum, colNum) { row, col -> - this[row, col] * double - } - -operator fun Matrix.plus(double: Double): RealMatrix = - MatrixContext.real.produce(rowNum, colNum) { row, col -> - this[row, col] + double - } - -operator fun Matrix.minus(double: Double): RealMatrix = - MatrixContext.real.produce(rowNum, colNum) { row, col -> - this[row, col] - double - } - -operator fun Matrix.div(double: Double): RealMatrix = - MatrixContext.real.produce(rowNum, colNum) { row, col -> - this[row, col] / double - } - -operator fun Double.times(matrix: Matrix): RealMatrix = - MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> - this * matrix[row, col] - } - -operator fun Double.plus(matrix: Matrix): RealMatrix = - MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> - this + matrix[row, col] - } - -operator fun Double.minus(matrix: Matrix): RealMatrix = - MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { row, col -> - this - matrix[row, col] - } - -// TODO: does this operation make sense? Should it be 'this/matrix[row, col]'? -//operator fun Double.div(matrix: Matrix) = MatrixContext.real.produce(matrix.rowNum, matrix.colNum) { -// row, col -> matrix[row, col] / this -//} - -/* - * Per-element (!) square and power operations - */ - -fun Matrix.square(): RealMatrix = MatrixContext.real.produce(rowNum, colNum) { row, col -> - this[row, col].pow(2) -} - -fun Matrix.pow(n: Int): RealMatrix = MatrixContext.real.produce(rowNum, colNum) { i, j -> - this[i, j].pow(n) -} - -/* - * Operations on two matrices (per-element!) - */ - -operator fun Matrix.times(other: Matrix): RealMatrix = - MatrixContext.real.produce(rowNum, colNum) { row, col -> - this[row, col] * other[row, col] - } - -operator fun Matrix.plus(other: Matrix): RealMatrix = - MatrixContext.real.add(this, other) - -operator fun Matrix.minus(other: Matrix): RealMatrix = - MatrixContext.real.produce(rowNum, colNum) { row, col -> - this[row, col] - other[row, col] - } - -/* - * Operations on columns - */ - -inline fun Matrix.appendColumn(crossinline mapper: (Buffer) -> Double): Matrix { - contract { callsInPlace(mapper) } - - return MatrixContext.real.produce(rowNum, colNum + 1) { row, col -> - if (col < colNum) - this[row, col] - else - mapper(rows[row]) - } -} - -fun Matrix.extractColumns(columnRange: IntRange): RealMatrix = - MatrixContext.real.produce(rowNum, columnRange.count()) { row, col -> - this[row, columnRange.first + col] - } - -fun Matrix.extractColumn(columnIndex: Int): RealMatrix = - extractColumns(columnIndex..columnIndex) - -fun Matrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> - val column = columns[j] - elementContext { sum(column.asIterable()) } -} - -fun Matrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> - columns[j].asIterable().min() ?: error("Cannot produce min on empty column") -} - -fun Matrix.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> - columns[j].asIterable().max() ?: error("Cannot produce min on empty column") -} - -fun Matrix.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> - columns[j].asIterable().average() -} - -/* - * Operations processing all elements - */ - -fun Matrix.sum(): Double = elements().map { (_, value) -> value }.sum() -fun Matrix.min(): Double? = elements().map { (_, value) -> value }.min() -fun Matrix.max(): Double? = elements().map { (_, value) -> value }.max() -fun Matrix.average(): Double = elements().map { (_, value) -> value }.average() diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt new file mode 100644 index 000000000..5f19e94b7 --- /dev/null +++ b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt @@ -0,0 +1,13 @@ +package kaceince.kmath.real + +import kscience.kmath.real.step +import kotlin.test.Test +import kotlin.test.assertEquals + +class GridTest { + @Test + fun testStepGrid(){ + val grid = 0.0..1.0 step 0.2 + assertEquals(6, grid.size) + } +} \ No newline at end of file diff --git a/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt similarity index 94% rename from kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt rename to kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt index 8918fb300..a89f99b3c 100644 --- a/kmath-for-real/src/commonTest/kotlin/scientific.kmath.real/RealMatrixTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt @@ -1,14 +1,14 @@ -package scientific.kmath.real +package kaceince.kmath.real -import scientifik.kmath.linear.VirtualMatrix -import scientifik.kmath.linear.build -import scientifik.kmath.real.* -import scientifik.kmath.structures.Matrix +import kscience.kmath.linear.build +import kscience.kmath.real.* +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.contentEquals import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue -class RealMatrixTest { +internal class RealMatrixTest { @Test fun testSum() { val m = realMatrix(10, 10) { i, j -> (i + j).toDouble() } @@ -41,7 +41,7 @@ class RealMatrixTest { 1.0, 0.0, 0.0, 0.0, 1.0, 2.0 ) - assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3)) + assertEquals(matrix2, matrix1.repeatStackVertical(3)) } @Test diff --git a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealVectorTest.kt similarity index 52% rename from kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt rename to kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealVectorTest.kt index ef7f40afe..6215ba5e8 100644 --- a/kmath-for-real/src/commonTest/kotlin/scientifik/kmath/linear/VectorTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealVectorTest.kt @@ -1,30 +1,33 @@ -package scientifik.kmath.linear +package kaceince.kmath.real -import scientifik.kmath.operations.invoke -import scientifik.kmath.real.RealVector +import kscience.kmath.linear.* +import kscience.kmath.operations.invoke +import kscience.kmath.real.RealVector +import kscience.kmath.real.plus +import kscience.kmath.structures.Buffer import kotlin.test.Test import kotlin.test.assertEquals -class VectorTest { +internal class RealVectorTest { @Test fun testSum() { - val vector1 = RealVector(5) { it.toDouble() } - val vector2 = RealVector(5) { 5 - it.toDouble() } + val vector1 = Buffer.real(5) { it.toDouble() } + val vector2 = Buffer.real(5) { 5 - it.toDouble() } val sum = vector1 + vector2 assertEquals(5.0, sum[2]) } @Test fun testVectorToMatrix() { - val vector = RealVector(5) { it.toDouble() } + val vector = Buffer.real(5) { it.toDouble() } val matrix = vector.asMatrix() assertEquals(4.0, matrix[4, 0]) } @Test fun testDot() { - val vector1 = RealVector(5) { it.toDouble() } - val vector2 = RealVector(5) { 5 - it.toDouble() } + val vector1 = Buffer.real(5) { it.toDouble() } + val vector2 = Buffer.real(5) { 5 - it.toDouble() } val matrix1 = vector1.asMatrix() val matrix2 = vector2.asMatrix().transpose() val product = MatrixContext.real { matrix1 dot matrix2 } diff --git a/kmath-functions/build.gradle.kts b/kmath-functions/build.gradle.kts index 46d2682f7..2a4539c10 100644 --- a/kmath-functions/build.gradle.kts +++ b/kmath-functions/build.gradle.kts @@ -1,6 +1,9 @@ -plugins { id("scientifik.mpp") } - -kotlin.sourceSets { - all { languageSettings.useExperimentalAnnotation("kotlin.contracts.ExperimentalContracts") } - commonMain { dependencies { api(project(":kmath-core")) } } +plugins { + id("ru.mipt.npm.mpp") +} + +kotlin.sourceSets.commonMain { + dependencies { + api(project(":kmath-core")) + } } diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Piecewise.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Piecewise.kt similarity index 56% rename from kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Piecewise.kt rename to kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Piecewise.kt index 16f8aa12b..a8c020c05 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Piecewise.kt +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Piecewise.kt @@ -1,48 +1,46 @@ -package scientifik.kmath.functions +package kscience.kmath.functions -import scientifik.kmath.operations.Ring +import kscience.kmath.operations.Ring -interface Piecewise { - fun findPiece(arg: T): R? +public fun interface Piecewise { + public fun findPiece(arg: T): R? } -interface PiecewisePolynomial : +public fun interface PiecewisePolynomial : Piecewise> /** * Ordered list of pieces in piecewise function */ -class OrderedPiecewisePolynomial>(delimeter: T) : +public class OrderedPiecewisePolynomial>(delimiter: T) : PiecewisePolynomial { - - private val delimiters: ArrayList = arrayListOf(delimeter) - private val pieces: ArrayList> = ArrayList() + private val delimiters: MutableList = arrayListOf(delimiter) + private val pieces: MutableList> = arrayListOf() /** * Dynamically add a piece to the "right" side (beyond maximum argument value of previous piece) * @param right new rightmost position. If is less then current rightmost position, a error is thrown. */ - fun putRight(right: T, piece: Polynomial) { + public fun putRight(right: T, piece: Polynomial) { require(right > delimiters.last()) { "New delimiter should be to the right of old one" } delimiters.add(right) pieces.add(piece) } - fun putLeft(left: T, piece: Polynomial) { + public fun putLeft(left: T, piece: Polynomial) { require(left < delimiters.first()) { "New delimiter should be to the left of old one" } delimiters.add(0, left) pieces.add(0, piece) } override fun findPiece(arg: T): Polynomial? { - if (arg < delimiters.first() || arg >= delimiters.last()) { + if (arg < delimiters.first() || arg >= delimiters.last()) return null - } else { - for (index in 1 until delimiters.size) { - if (arg < delimiters[index]) { + else { + for (index in 1 until delimiters.size) + if (arg < delimiters[index]) return pieces[index - 1] - } - } + error("Piece not found") } } @@ -51,7 +49,7 @@ class OrderedPiecewisePolynomial>(delimeter: T) : /** * Return a value of polynomial function with given [ring] an given [arg] or null if argument is outside of piecewise definition. */ -fun , C : Ring> PiecewisePolynomial.value(ring: C, arg: T): T? = +public fun , C : Ring> PiecewisePolynomial.value(ring: C, arg: T): T? = findPiece(arg)?.value(ring, arg) -fun , C : Ring> PiecewisePolynomial.asFunction(ring: C): (T) -> T? = { value(ring, it) } \ No newline at end of file +public fun , C : Ring> PiecewisePolynomial.asFunction(ring: C): (T) -> T? = { value(ring, it) } \ No newline at end of file diff --git a/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Polynomial.kt new file mode 100644 index 000000000..820076c4c --- /dev/null +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/functions/Polynomial.kt @@ -0,0 +1,66 @@ +package kscience.kmath.functions + +import kscience.kmath.operations.Ring +import kscience.kmath.operations.Space +import kscience.kmath.operations.invoke +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract +import kotlin.math.max +import kotlin.math.pow + +/** + * Polynomial coefficients without fixation on specific context they are applied to + * @param coefficients constant is the leftmost coefficient + */ +public inline class Polynomial(public val coefficients: List) + +@Suppress("FunctionName") +public fun Polynomial(vararg coefficients: T): Polynomial = Polynomial(coefficients.toList()) + +public fun Polynomial.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) } + +public fun > Polynomial.value(ring: C, arg: T): T = ring { + if (coefficients.isEmpty()) return@ring zero + var res = coefficients.first() + var powerArg = arg + + for (index in 1 until coefficients.size) { + res += coefficients[index] * powerArg + // recalculating power on each step to avoid power costs on long polynomials + powerArg *= arg + } + + res +} + +/** + * Represent the polynomial as a regular context-less function + */ +public fun > Polynomial.asFunction(ring: C): (T) -> T = { value(ring, it) } + +/** + * An algebra for polynomials + */ +public class PolynomialSpace>(private val ring: C) : Space> { + public override val zero: Polynomial = Polynomial(emptyList()) + + public override fun add(a: Polynomial, b: Polynomial): Polynomial { + val dim = max(a.coefficients.size, b.coefficients.size) + + return ring { + Polynomial(List(dim) { index -> + a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } + }) + } + } + + public override fun multiply(a: Polynomial, k: Number): Polynomial = + ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) } + + public operator fun Polynomial.invoke(arg: T): T = value(ring, arg) +} + +public inline fun , R> C.polynomial(block: PolynomialSpace.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return PolynomialSpace(this).block() +} diff --git a/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/Interpolator.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/Interpolator.kt new file mode 100644 index 000000000..0620b4aa8 --- /dev/null +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/Interpolator.kt @@ -0,0 +1,45 @@ +package kscience.kmath.interpolation + +import kscience.kmath.functions.PiecewisePolynomial +import kscience.kmath.functions.value +import kscience.kmath.operations.Ring +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.asBuffer + +public fun interface Interpolator { + public fun interpolate(points: XYPointSet): (X) -> Y +} + +public interface PolynomialInterpolator> : Interpolator { + public val algebra: Ring + + public fun getDefaultValue(): T = error("Out of bounds") + + public fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial + + override fun interpolate(points: XYPointSet): (T) -> T = { x -> + interpolatePolynomials(points).value(algebra, x) ?: getDefaultValue() + } +} + +public fun > PolynomialInterpolator.interpolatePolynomials( + x: Buffer, + y: Buffer +): PiecewisePolynomial { + val pointSet = BufferXYPointSet(x, y) + return interpolatePolynomials(pointSet) +} + +public fun > PolynomialInterpolator.interpolatePolynomials( + data: Map +): PiecewisePolynomial { + val pointSet = BufferXYPointSet(data.keys.toList().asBuffer(), data.values.toList().asBuffer()) + return interpolatePolynomials(pointSet) +} + +public fun > PolynomialInterpolator.interpolatePolynomials( + data: List> +): PiecewisePolynomial { + val pointSet = BufferXYPointSet(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer()) + return interpolatePolynomials(pointSet) +} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/LinearInterpolator.kt similarity index 57% rename from kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt rename to kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/LinearInterpolator.kt index a7925180d..377aa1fbe 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LinearInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/LinearInterpolator.kt @@ -1,16 +1,16 @@ -package scientifik.kmath.interpolation +package kscience.kmath.interpolation -import scientifik.kmath.functions.OrderedPiecewisePolynomial -import scientifik.kmath.functions.PiecewisePolynomial -import scientifik.kmath.functions.Polynomial -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.invoke +import kscience.kmath.functions.OrderedPiecewisePolynomial +import kscience.kmath.functions.PiecewisePolynomial +import kscience.kmath.functions.Polynomial +import kscience.kmath.operations.Field +import kscience.kmath.operations.invoke /** * Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java */ -class LinearInterpolator>(override val algebra: Field) : PolynomialInterpolator { - override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { +public class LinearInterpolator>(public override val algebra: Field) : PolynomialInterpolator { + public override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { require(points.size > 0) { "Point array should not be empty" } insureSorted(points) diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LoessInterpolator.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/LoessInterpolator.kt similarity index 98% rename from kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LoessInterpolator.kt rename to kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/LoessInterpolator.kt index 6707bd8bc..6931857b1 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/LoessInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/LoessInterpolator.kt @@ -1,8 +1,8 @@ -package scientifik.kmath.interpolation +package kscience.kmath.interpolation // -//import scientifik.kmath.functions.PiecewisePolynomial -//import scientifik.kmath.operations.Ring -//import scientifik.kmath.structures.Buffer +//import kscience.kmath.functions.PiecewisePolynomial +//import kscience.kmath.operations.Ring +//import kscience.kmath.structures.Buffer //import kotlin.math.abs //import kotlin.math.sqrt // diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/SplineInterpolator.kt similarity index 69% rename from kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt rename to kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/SplineInterpolator.kt index b709c4e87..6cda45f72 100644 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/SplineInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/SplineInterpolator.kt @@ -1,29 +1,25 @@ -package scientifik.kmath.interpolation +package kscience.kmath.interpolation -import scientifik.kmath.functions.OrderedPiecewisePolynomial -import scientifik.kmath.functions.PiecewisePolynomial -import scientifik.kmath.functions.Polynomial -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.MutableBufferFactory +import kscience.kmath.functions.OrderedPiecewisePolynomial +import kscience.kmath.functions.PiecewisePolynomial +import kscience.kmath.functions.Polynomial +import kscience.kmath.operations.Field +import kscience.kmath.operations.invoke +import kscience.kmath.structures.MutableBufferFactory /** * Generic spline interpolator. Not recommended for performance critical places, use platform-specific and type specific ones. * Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java */ -class SplineInterpolator>( - override val algebra: Field, - val bufferFactory: MutableBufferFactory +public class SplineInterpolator>( + public override val algebra: Field, + public val bufferFactory: MutableBufferFactory ) : PolynomialInterpolator { - //TODO possibly optimize zeroed buffers - override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { - if (points.size < 3) { - error("Can't use spline interpolator with less than 3 points") - } + public override fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial = algebra { + require(points.size >= 3) { "Can't use spline interpolator with less than 3 points" } insureSorted(points) - // Number of intervals. The number of data points is n + 1. val n = points.size - 1 // Differences between knot points @@ -34,6 +30,7 @@ class SplineInterpolator>( for (i in 1 until n) { val g = 2.0 * (points.x[i + 1] - points.x[i - 1]) - h[i - 1] * mu[i - 1] mu[i] = h[i] / g + z[i] = (3.0 * (points.y[i + 1] * h[i - 1] - points.x[i] * (points.x[i + 1] - points.x[i - 1]) + points.y[i - 1] * h[i]) / (h[i - 1] * h[i]) - h[i - 1] * z[i - 1]) / g @@ -41,8 +38,9 @@ class SplineInterpolator>( // cubic spline coefficients -- b is linear, c quadratic, d is cubic (original y's are constants) - OrderedPiecewisePolynomial(points.x[points.size - 1]).apply { + OrderedPiecewisePolynomial(points.x[points.size - 1]).apply { var cOld = zero + for (j in n - 1 downTo 0) { val c = z[j] - mu[j] * cOld val a = points.y[j] @@ -53,7 +51,5 @@ class SplineInterpolator>( putLeft(points.x[j], polynomial) } } - } - } diff --git a/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/XYPointSet.kt b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/XYPointSet.kt new file mode 100644 index 000000000..2abb7742c --- /dev/null +++ b/kmath-functions/src/commonMain/kotlin/kscience/kmath/interpolation/XYPointSet.kt @@ -0,0 +1,53 @@ +package kscience.kmath.interpolation + +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.Structure2D + +public interface XYPointSet { + public val size: Int + public val x: Buffer + public val y: Buffer +} + +public interface XYZPointSet : XYPointSet { + public val z: Buffer +} + +internal fun > insureSorted(points: XYPointSet) { + for (i in 0 until points.size - 1) + require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" } +} + +public class NDStructureColumn(public val structure: Structure2D, public val column: Int) : Buffer { + public override val size: Int + get() = structure.rowNum + + init { + require(column < structure.colNum) { "Column index is outside of structure column range" } + } + + public override operator fun get(index: Int): T = structure[index, column] + public override operator fun iterator(): Iterator = sequence { repeat(size) { yield(get(it)) } }.iterator() +} + +public class BufferXYPointSet( + public override val x: Buffer, + public override val y: Buffer +) : XYPointSet { + public override val size: Int + get() = x.size + + init { + require(x.size == y.size) { "Sizes of x and y buffers should be the same" } + } +} + +public fun Structure2D.asXYPointSet(): XYPointSet { + require(shape[1] == 2) { "Structure second dimension should be of size 2" } + + return object : XYPointSet { + override val size: Int get() = this@asXYPointSet.shape[0] + override val x: Buffer get() = NDStructureColumn(this@asXYPointSet, 0) + override val y: Buffer get() = NDStructureColumn(this@asXYPointSet, 1) + } +} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt deleted file mode 100644 index c4470ad27..000000000 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/Polynomial.kt +++ /dev/null @@ -1,77 +0,0 @@ -package scientifik.kmath.functions - -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract -import kotlin.math.max -import kotlin.math.pow - -/** - * Polynomial coefficients without fixation on specific context they are applied to - * @param coefficients constant is the leftmost coefficient - */ -inline class Polynomial(val coefficients: List) { - constructor(vararg coefficients: T) : this(coefficients.toList()) -} - -fun Polynomial.value(): Double = - coefficients.reduceIndexed { index: Int, acc: Double, d: Double -> acc + d.pow(index) } - -fun > Polynomial.value(ring: C, arg: T): T = ring { - if (coefficients.isEmpty()) return@ring zero - var res = coefficients.first() - var powerArg = arg - - for (index in 1 until coefficients.size) { - res += coefficients[index] * powerArg - //recalculating power on each step to avoid power costs on long polynomials - powerArg *= arg - } - - res -} - -/** - * Represent a polynomial as a context-dependent function - */ -fun > Polynomial.asMathFunction(): MathFunction = object : - MathFunction { - override operator fun C.invoke(arg: T): T = value(this, arg) -} - -/** - * Represent the polynomial as a regular context-less function - */ -fun > Polynomial.asFunction(ring: C): (T) -> T = { value(ring, it) } - -/** - * An algebra for polynomials - */ -class PolynomialSpace>(val ring: C) : Space> { - - override fun add(a: Polynomial, b: Polynomial): Polynomial { - val dim = max(a.coefficients.size, b.coefficients.size) - - return ring { - Polynomial(List(dim) { index -> - a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } - }) - } - } - - override fun multiply(a: Polynomial, k: Number): Polynomial = - ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * k }) } - - override val zero: Polynomial = - Polynomial(emptyList()) - - operator fun Polynomial.invoke(arg: T): T = value(ring, arg) -} - -inline fun , R> C.polynomial(block: PolynomialSpace.() -> R): R { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return PolynomialSpace(this).block() -} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/functions.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/functions.kt deleted file mode 100644 index 2b822b3ba..000000000 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/functions/functions.kt +++ /dev/null @@ -1,33 +0,0 @@ -package scientifik.kmath.functions - -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.RealField - -/** - * A regular function that could be called only inside specific algebra context - * @param T source type - * @param C source algebra constraint - * @param R result type - */ -interface MathFunction, R> { - operator fun C.invoke(arg: T): R -} - -fun MathFunction.invoke(arg: Double): R = RealField.invoke(arg) - -/** - * A suspendable function defined in algebraic context - */ -interface SuspendableMathFunction, R> { - suspend operator fun C.invoke(arg: T): R -} - -suspend fun SuspendableMathFunction.invoke(arg: Double) = RealField.invoke(arg) - - -/** - * A parametric function with parameter - */ -interface ParametricFunction> { - operator fun C.invoke(arg: T, parameter: P): T -} diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/Interpolator.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/Interpolator.kt deleted file mode 100644 index 8d83e4198..000000000 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/Interpolator.kt +++ /dev/null @@ -1,45 +0,0 @@ -package scientifik.kmath.interpolation - -import scientifik.kmath.functions.PiecewisePolynomial -import scientifik.kmath.functions.value -import scientifik.kmath.operations.Ring -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.asBuffer - -interface Interpolator { - fun interpolate(points: XYPointSet): (X) -> Y -} - -interface PolynomialInterpolator> : Interpolator { - val algebra: Ring - - fun getDefaultValue(): T = error("Out of bounds") - - fun interpolatePolynomials(points: XYPointSet): PiecewisePolynomial - - override fun interpolate(points: XYPointSet): (T) -> T = { x -> - interpolatePolynomials(points).value(algebra, x) ?: getDefaultValue() - } -} - -fun > PolynomialInterpolator.interpolatePolynomials( - x: Buffer, - y: Buffer -): PiecewisePolynomial { - val pointSet = BufferXYPointSet(x, y) - return interpolatePolynomials(pointSet) -} - -fun > PolynomialInterpolator.interpolatePolynomials( - data: Map -): PiecewisePolynomial { - val pointSet = BufferXYPointSet(data.keys.toList().asBuffer(), data.values.toList().asBuffer()) - return interpolatePolynomials(pointSet) -} - -fun > PolynomialInterpolator.interpolatePolynomials( - data: List> -): PiecewisePolynomial { - val pointSet = BufferXYPointSet(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer()) - return interpolatePolynomials(pointSet) -} \ No newline at end of file diff --git a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt b/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt deleted file mode 100644 index 56953f9fc..000000000 --- a/kmath-functions/src/commonMain/kotlin/scientifik/kmath/interpolation/XYPointSet.kt +++ /dev/null @@ -1,52 +0,0 @@ -package scientifik.kmath.interpolation - -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.Structure2D - -interface XYPointSet { - val size: Int - val x: Buffer - val y: Buffer -} - -interface XYZPointSet : XYPointSet { - val z: Buffer -} - -internal fun > insureSorted(points: XYPointSet) { - for (i in 0 until points.size - 1) require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" } -} - -class NDStructureColumn(val structure: Structure2D, val column: Int) : Buffer { - init { - require(column < structure.colNum) { "Column index is outside of structure column range" } - } - - override val size: Int get() = structure.rowNum - - override operator fun get(index: Int): T = structure[index, column] - - override operator fun iterator(): Iterator = sequence { - repeat(size) { - yield(get(it)) - } - }.iterator() -} - -class BufferXYPointSet(override val x: Buffer, override val y: Buffer) : XYPointSet { - init { - require(x.size == y.size) { "Sizes of x and y buffers should be the same" } - } - - override val size: Int - get() = x.size -} - -fun Structure2D.asXYPointSet(): XYPointSet { - require(shape[1] == 2) { "Structure second dimension should be of size 2" } - return object : XYPointSet { - override val size: Int get() = this@asXYPointSet.shape[0] - override val x: Buffer get() = NDStructureColumn(this@asXYPointSet, 0) - override val y: Buffer get() = NDStructureColumn(this@asXYPointSet, 1) - } -} \ No newline at end of file diff --git a/kmath-functions/src/commonTest/kotlin/scientifik/kmath/interpolation/LinearInterpolatorTest.kt b/kmath-functions/src/commonTest/kotlin/kscience/kmath/interpolation/LinearInterpolatorTest.kt similarity index 72% rename from kmath-functions/src/commonTest/kotlin/scientifik/kmath/interpolation/LinearInterpolatorTest.kt rename to kmath-functions/src/commonTest/kotlin/kscience/kmath/interpolation/LinearInterpolatorTest.kt index 23acd835c..303615676 100644 --- a/kmath-functions/src/commonTest/kotlin/scientifik/kmath/interpolation/LinearInterpolatorTest.kt +++ b/kmath-functions/src/commonTest/kotlin/kscience/kmath/interpolation/LinearInterpolatorTest.kt @@ -1,13 +1,12 @@ -package scientifik.kmath.interpolation +package kscience.kmath.interpolation -import scientifik.kmath.functions.PiecewisePolynomial -import scientifik.kmath.functions.asFunction -import scientifik.kmath.operations.RealField +import kscience.kmath.functions.PiecewisePolynomial +import kscience.kmath.functions.asFunction +import kscience.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals - -class LinearInterpolatorTest { +internal class LinearInterpolatorTest { @Test fun testInterpolation() { val data = listOf( @@ -16,12 +15,12 @@ class LinearInterpolatorTest { 2.0 to 3.0, 3.0 to 4.0 ) + val polynomial: PiecewisePolynomial = LinearInterpolator(RealField).interpolatePolynomials(data) val function = polynomial.asFunction(RealField) - assertEquals(null, function(-1.0)) assertEquals(0.5, function(0.5)) assertEquals(2.0, function(1.5)) assertEquals(3.0, function(2.0)) } -} \ No newline at end of file +} diff --git a/kmath-geometry/build.gradle.kts b/kmath-geometry/build.gradle.kts index 39aa833ad..00abcb934 100644 --- a/kmath-geometry/build.gradle.kts +++ b/kmath-geometry/build.gradle.kts @@ -1,9 +1,7 @@ -plugins { - id("scientifik.mpp") -} +plugins { id("ru.mipt.npm.mpp") } kotlin.sourceSets.commonMain { dependencies { api(project(":kmath-core")) } -} \ No newline at end of file +} diff --git a/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Euclidean2DSpace.kt new file mode 100644 index 000000000..c2a883a64 --- /dev/null +++ b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Euclidean2DSpace.kt @@ -0,0 +1,47 @@ +package kscience.kmath.geometry + +import kscience.kmath.linear.Point +import kscience.kmath.operations.SpaceElement +import kscience.kmath.operations.invoke +import kotlin.math.sqrt + +public interface Vector2D : Point, Vector, SpaceElement { + public val x: Double + public val y: Double + public override val context: Euclidean2DSpace get() = Euclidean2DSpace + public override val size: Int get() = 2 + + public override operator fun get(index: Int): Double = when (index) { + 1 -> x + 2 -> y + else -> error("Accessing outside of point bounds") + } + + public override operator fun iterator(): Iterator = listOf(x, y).iterator() + public override fun unwrap(): Vector2D = this + public override fun Vector2D.wrap(): Vector2D = this +} + +public val Vector2D.r: Double + get() = Euclidean2DSpace { sqrt(norm()) } + +@Suppress("FunctionName") +public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y) + +private data class Vector2DImpl( + override val x: Double, + override val y: Double +) : Vector2D + +/** + * 2D Euclidean space + */ +public object Euclidean2DSpace : GeometrySpace { + public override val zero: Vector2D by lazy { Vector2D(0.0, 0.0) } + + public fun Vector2D.norm(): Double = sqrt(x * x + y * y) + public override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm() + public override fun add(a: Vector2D, b: Vector2D): Vector2D = Vector2D(a.x + b.x, a.y + b.y) + public override fun multiply(a: Vector2D, k: Number): Vector2D = Vector2D(a.x * k.toDouble(), a.y * k.toDouble()) + public override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y +} diff --git a/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Euclidean3DSpace.kt b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Euclidean3DSpace.kt new file mode 100644 index 000000000..e0052d791 --- /dev/null +++ b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Euclidean3DSpace.kt @@ -0,0 +1,53 @@ +package kscience.kmath.geometry + +import kscience.kmath.linear.Point +import kscience.kmath.operations.SpaceElement +import kscience.kmath.operations.invoke +import kotlin.math.sqrt + +public interface Vector3D : Point, Vector, SpaceElement { + public val x: Double + public val y: Double + public val z: Double + public override val context: Euclidean3DSpace get() = Euclidean3DSpace + public override val size: Int get() = 3 + + public override operator fun get(index: Int): Double = when (index) { + 1 -> x + 2 -> y + 3 -> z + else -> error("Accessing outside of point bounds") + } + + public override operator fun iterator(): Iterator = listOf(x, y, z).iterator() + public override fun unwrap(): Vector3D = this + public override fun Vector3D.wrap(): Vector3D = this +} + +@Suppress("FunctionName") +public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z) + +public val Vector3D.r: Double get() = Euclidean3DSpace { sqrt(norm()) } + +private data class Vector3DImpl( + override val x: Double, + override val y: Double, + override val z: Double +) : Vector3D + +public object Euclidean3DSpace : GeometrySpace { + public override val zero: Vector3D by lazy { Vector3D(0.0, 0.0, 0.0) } + + public fun Vector3D.norm(): Double = sqrt(x * x + y * y + z * z) + + public override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm() + + public override fun add(a: Vector3D, b: Vector3D): Vector3D = + Vector3D(a.x + b.x, a.y + b.y, a.z + b.z) + + public override fun multiply(a: Vector3D, k: Number): Vector3D = + Vector3D(a.x * k.toDouble(), a.y * k.toDouble(), a.z * k.toDouble()) + + public override fun Vector3D.dot(other: Vector3D): Double = + x * other.x + y * other.y + z * other.z +} diff --git a/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/GeometrySpace.kt b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/GeometrySpace.kt new file mode 100644 index 000000000..54d2510cf --- /dev/null +++ b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/GeometrySpace.kt @@ -0,0 +1,17 @@ +package kscience.kmath.geometry + +import kscience.kmath.operations.Space + +public interface Vector + +public interface GeometrySpace : Space { + /** + * L2 distance + */ + public fun V.distanceTo(other: V): Double + + /** + * Scalar product + */ + public infix fun V.dot(other: V): Double +} \ No newline at end of file diff --git a/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Line.kt b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Line.kt new file mode 100644 index 000000000..ec2ce31ca --- /dev/null +++ b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/Line.kt @@ -0,0 +1,6 @@ +package kscience.kmath.geometry + +public data class Line(val base: V, val direction: V) + +public typealias Line2D = Line +public typealias Line3D = Line diff --git a/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/ReferenceFrame.kt b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/ReferenceFrame.kt new file mode 100644 index 000000000..f9de7b51f --- /dev/null +++ b/kmath-geometry/src/commonMain/kotlin/kscience/kmath/geometry/ReferenceFrame.kt @@ -0,0 +1,3 @@ +package kscience.kmath.geometry + +public interface ReferenceFrame diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt deleted file mode 100644 index f0dc49882..000000000 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean2DSpace.kt +++ /dev/null @@ -1,54 +0,0 @@ -package scientifik.kmath.geometry - -import scientifik.kmath.linear.Point -import scientifik.kmath.operations.SpaceElement -import scientifik.kmath.operations.invoke -import kotlin.math.sqrt - - -interface Vector2D : Point, Vector, SpaceElement { - val x: Double - val y: Double - override val context: Euclidean2DSpace get() = Euclidean2DSpace - override val size: Int get() = 2 - - override operator fun get(index: Int): Double = when (index) { - 1 -> x - 2 -> y - else -> error("Accessing outside of point bounds") - } - - override operator fun iterator(): Iterator = listOf(x, y).iterator() - override fun unwrap(): Vector2D = this - override fun Vector2D.wrap(): Vector2D = this -} - -val Vector2D.r: Double get() = Euclidean2DSpace { sqrt(norm()) } - -@Suppress("FunctionName") -fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y) - -private data class Vector2DImpl( - override val x: Double, - override val y: Double -) : Vector2D - -/** - * 2D Euclidean space - */ -object Euclidean2DSpace : GeometrySpace { - fun Vector2D.norm(): Double = sqrt(x * x + y * y) - - override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm() - - override fun add(a: Vector2D, b: Vector2D): Vector2D = - Vector2D(a.x + b.x, a.y + b.y) - - override fun multiply(a: Vector2D, k: Number): Vector2D = - Vector2D(a.x * k.toDouble(), a.y * k.toDouble()) - - override val zero: Vector2D = Vector2D(0.0, 0.0) - - override fun Vector2D.dot(other: Vector2D): Double = - x * other.x + y * other.y -} \ No newline at end of file diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt deleted file mode 100644 index 3748e58c7..000000000 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Euclidean3DSpace.kt +++ /dev/null @@ -1,56 +0,0 @@ -package scientifik.kmath.geometry - -import scientifik.kmath.linear.Point -import scientifik.kmath.operations.SpaceElement -import scientifik.kmath.operations.invoke -import kotlin.math.sqrt - - -interface Vector3D : Point, Vector, SpaceElement { - val x: Double - val y: Double - val z: Double - override val context: Euclidean3DSpace get() = Euclidean3DSpace - override val size: Int get() = 3 - - override operator fun get(index: Int): Double = when (index) { - 1 -> x - 2 -> y - 3 -> z - else -> error("Accessing outside of point bounds") - } - - override operator fun iterator(): Iterator = listOf(x, y, z).iterator() - - override fun unwrap(): Vector3D = this - - override fun Vector3D.wrap(): Vector3D = this -} - -@Suppress("FunctionName") -fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z) - -val Vector3D.r: Double get() = Euclidean3DSpace { sqrt(norm()) } - -private data class Vector3DImpl( - override val x: Double, - override val y: Double, - override val z: Double -) : Vector3D - -object Euclidean3DSpace : GeometrySpace { - override val zero: Vector3D = Vector3D(0.0, 0.0, 0.0) - - fun Vector3D.norm(): Double = sqrt(x * x + y * y + z * z) - - override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm() - - override fun add(a: Vector3D, b: Vector3D): Vector3D = - Vector3D(a.x + b.x, a.y + b.y, a.z + b.z) - - override fun multiply(a: Vector3D, k: Number): Vector3D = - Vector3D(a.x * k.toDouble(), a.y * k.toDouble(), a.z * k.toDouble()) - - override fun Vector3D.dot(other: Vector3D): Double = - x * other.x + y * other.y + z * other.z -} diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/GeometrySpace.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/GeometrySpace.kt deleted file mode 100644 index b65a8dd3a..000000000 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/GeometrySpace.kt +++ /dev/null @@ -1,17 +0,0 @@ -package scientifik.kmath.geometry - -import scientifik.kmath.operations.Space - -interface Vector - -interface GeometrySpace: Space { - /** - * L2 distance - */ - fun V.distanceTo(other: V): Double - - /** - * Scalar product - */ - infix fun V.dot(other: V): Double -} \ No newline at end of file diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Line.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Line.kt deleted file mode 100644 index d802a103f..000000000 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/Line.kt +++ /dev/null @@ -1,6 +0,0 @@ -package scientifik.kmath.geometry - -data class Line(val base: V, val direction: V) - -typealias Line2D = Line -typealias Line3D = Line diff --git a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/ReferenceFrame.kt b/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/ReferenceFrame.kt deleted file mode 100644 index 420e38ce2..000000000 --- a/kmath-geometry/src/commonMain/kotlin/scientifik/kmath/geometry/ReferenceFrame.kt +++ /dev/null @@ -1,4 +0,0 @@ -package scientifik.kmath.geometry - -interface ReferenceFrame { -} \ No newline at end of file diff --git a/kmath-histograms/build.gradle.kts b/kmath-histograms/build.gradle.kts index 993bfed8e..7de21ad89 100644 --- a/kmath-histograms/build.gradle.kts +++ b/kmath-histograms/build.gradle.kts @@ -1,10 +1,8 @@ -plugins { - id("scientifik.mpp") -} +plugins { id("ru.mipt.npm.mpp") } kotlin.sourceSets.commonMain { dependencies { api(project(":kmath-core")) api(project(":kmath-for-real")) } -} \ No newline at end of file +} diff --git a/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Counters.kt b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Counters.kt new file mode 100644 index 000000000..7a263a9fc --- /dev/null +++ b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Counters.kt @@ -0,0 +1,20 @@ +package kscience.kmath.histogram + +/* + * Common representation for atomic counters + * TODO replace with atomics + */ + +public expect class LongCounter() { + public fun decrement() + public fun increment() + public fun reset() + public fun sum(): Long + public fun add(l: Long) +} + +public expect class DoubleCounter() { + public fun reset() + public fun sum(): Double + public fun add(d: Double) +} diff --git a/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Histogram.kt new file mode 100644 index 000000000..370a01215 --- /dev/null +++ b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/Histogram.kt @@ -0,0 +1,54 @@ +package kscience.kmath.histogram + +import kscience.kmath.domains.Domain +import kscience.kmath.linear.Point +import kscience.kmath.structures.ArrayBuffer +import kscience.kmath.structures.RealBuffer + +/** + * The bin in the histogram. The histogram is by definition always done in the real space + */ +public interface Bin : Domain { + /** + * The value of this bin. + */ + public val value: Number + + public val center: Point +} + +public interface Histogram> : Iterable { + /** + * Find existing bin, corresponding to given coordinates + */ + public operator fun get(point: Point): B? + + /** + * Dimension of the histogram + */ + public val dimension: Int +} + +public interface MutableHistogram> : Histogram { + + /** + * Increment appropriate bin + */ + public fun putWithWeight(point: Point, weight: Double) + + public fun put(point: Point): Unit = putWithWeight(point, 1.0) +} + +public fun MutableHistogram.put(vararg point: T): Unit = put(ArrayBuffer(point)) + +public fun MutableHistogram.put(vararg point: Number): Unit = + put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) + +public fun MutableHistogram.put(vararg point: Double): Unit = put(RealBuffer(point)) +public fun MutableHistogram.fill(sequence: Iterable>): Unit = sequence.forEach { put(it) } + +/** + * Pass a sequence builder into histogram + */ +public fun MutableHistogram.fill(block: suspend SequenceScope>.() -> Unit): Unit = + fill(sequence(block).asIterable()) diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/RealHistogram.kt similarity index 56% rename from kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt rename to kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/RealHistogram.kt index f05ae1694..f95264ee1 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/kscience/kmath/histogram/RealHistogram.kt @@ -1,15 +1,17 @@ -package scientifik.kmath.histogram +package kscience.kmath.histogram -import scientifik.kmath.linear.Point -import scientifik.kmath.operations.SpaceOperations -import scientifik.kmath.operations.invoke -import scientifik.kmath.real.asVector -import scientifik.kmath.structures.* +import kscience.kmath.linear.Point +import kscience.kmath.operations.SpaceOperations +import kscience.kmath.operations.invoke +import kscience.kmath.structures.* import kotlin.math.floor - -data class BinDef>(val space: SpaceOperations>, val center: Point, val sizes: Point) { - fun contains(vector: Point): Boolean { +public data class BinDef>( + public val space: SpaceOperations>, + public val center: Point, + public val sizes: Point +) { + public fun contains(vector: Point): Boolean { require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" } val upper = space { center + sizes / 2.0 } val lower = space { center - sizes / 2.0 } @@ -18,21 +20,20 @@ data class BinDef>(val space: SpaceOperations>, val c } -class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { - override operator fun contains(point: Point): Boolean = def.contains(point) - - override val dimension: Int +public class MultivariateBin>(public val def: BinDef, public override val value: Number) : Bin { + public override val dimension: Int get() = def.center.size - override val center: Point + public override val center: Point get() = def.center + public override operator fun contains(point: Point): Boolean = def.contains(point) } /** * Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions. */ -class RealHistogram( +public class RealHistogram( private val lower: Buffer, private val upper: Buffer, private val binNums: IntArray = IntArray(lower.size) { 20 } @@ -40,7 +41,7 @@ class RealHistogram( private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) private val values: NDStructure = NDStructure.auto(strides) { LongCounter() } private val weights: NDStructure = NDStructure.auto(strides) { DoubleCounter() } - override val dimension: Int get() = lower.size + public override val dimension: Int get() = lower.size private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { @@ -64,7 +65,7 @@ class RealHistogram( private fun getValue(index: IntArray): Long = values[index].sum() - fun getValue(point: Buffer): Long = getValue(getIndex(point)) + public fun getValue(point: Buffer): Long = getValue(getIndex(point)) private fun getDef(index: IntArray): BinDef { val center = index.mapIndexed { axis, i -> @@ -78,9 +79,9 @@ class RealHistogram( return BinDef(RealBufferFieldOperations, center, binSize) } - fun getDef(point: Buffer): BinDef = getDef(getIndex(point)) + public fun getDef(point: Buffer): BinDef = getDef(getIndex(point)) - override operator fun get(point: Buffer): MultivariateBin? { + public override operator fun get(point: Buffer): MultivariateBin? { val index = getIndex(point) return MultivariateBin(getDef(index), getValue(index)) } @@ -90,27 +91,27 @@ class RealHistogram( // values[index].increment() // } - override fun putWithWeight(point: Buffer, weight: Double) { + public override fun putWithWeight(point: Buffer, weight: Double) { val index = getIndex(point) values[index].increment() weights[index].add(weight) } - override operator fun iterator(): Iterator> = weights.elements().map { (index, value) -> - MultivariateBin(getDef(index), value.sum()) - }.iterator() + public override operator fun iterator(): Iterator> = + weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) } + .iterator() /** * Convert this histogram into NDStructure containing bin values but not bin descriptions */ - fun values(): NDStructure = NDStructure.auto(values.shape) { values[it].sum() } + public fun values(): NDStructure = NDStructure.auto(values.shape) { values[it].sum() } /** * Sum of weights */ - fun weights(): NDStructure = NDStructure.auto(weights.shape) { weights[it].sum() } + public fun weights(): NDStructure = NDStructure.auto(weights.shape) { weights[it].sum() } - companion object { + public companion object { /** * Use it like * ``` @@ -120,9 +121,9 @@ class RealHistogram( *) *``` */ - fun fromRanges(vararg ranges: ClosedFloatingPointRange): RealHistogram = RealHistogram( - ranges.map { it.start }.asVector(), - ranges.map { it.endInclusive }.asVector() + public fun fromRanges(vararg ranges: ClosedFloatingPointRange): RealHistogram = RealHistogram( + ranges.map(ClosedFloatingPointRange::start).asBuffer(), + ranges.map(ClosedFloatingPointRange::endInclusive).asBuffer() ) /** @@ -134,10 +135,21 @@ class RealHistogram( *) *``` */ - fun fromRanges(vararg ranges: Pair, Int>): RealHistogram = RealHistogram( - ListBuffer(ranges.map { it.first.start }), - ListBuffer(ranges.map { it.first.endInclusive }), - ranges.map { it.second }.toIntArray() - ) + public fun fromRanges(vararg ranges: Pair, Int>): RealHistogram = + RealHistogram( + ListBuffer( + ranges + .map(Pair, Int>::first) + .map(ClosedFloatingPointRange::start) + ), + + ListBuffer( + ranges + .map(Pair, Int>::first) + .map(ClosedFloatingPointRange::endInclusive) + ), + + ranges.map(Pair, Int>::second).toIntArray() + ) } } diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt deleted file mode 100644 index 9c7de3303..000000000 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt +++ /dev/null @@ -1,20 +0,0 @@ -package scientifik.kmath.histogram - -/* - * Common representation for atomic counters - * TODO replace with atomics - */ - -expect class LongCounter() { - fun decrement() - fun increment() - fun reset() - fun sum(): Long - fun add(l: Long) -} - -expect class DoubleCounter() { - fun reset() - fun sum(): Double - fun add(d: Double) -} \ No newline at end of file diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt deleted file mode 100644 index 9ff2aacf5..000000000 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ /dev/null @@ -1,59 +0,0 @@ -package scientifik.kmath.histogram - -import scientifik.kmath.domains.Domain -import scientifik.kmath.linear.Point -import scientifik.kmath.structures.ArrayBuffer -import scientifik.kmath.structures.RealBuffer -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract - -/** - * The bin in the histogram. The histogram is by definition always done in the real space - */ -interface Bin : Domain { - /** - * The value of this bin - */ - val value: Number - val center: Point -} - -interface Histogram> : Iterable { - - /** - * Find existing bin, corresponding to given coordinates - */ - operator fun get(point: Point): B? - - /** - * Dimension of the histogram - */ - val dimension: Int - -} - -interface MutableHistogram> : Histogram { - - /** - * Increment appropriate bin - */ - fun putWithWeight(point: Point, weight: Double) - - fun put(point: Point): Unit = putWithWeight(point, 1.0) -} - -fun MutableHistogram.put(vararg point: T): Unit = put(ArrayBuffer(point)) - -fun MutableHistogram.put(vararg point: Number): Unit = - put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) - -fun MutableHistogram.put(vararg point: Double): Unit = put(RealBuffer(point)) - -fun MutableHistogram.fill(sequence: Iterable>): Unit = sequence.forEach { put(it) } - -/** - * Pass a sequence builder into histogram - */ -fun MutableHistogram.fill(block: suspend SequenceScope>.() -> Unit): Unit = - fill(sequence(block).asIterable()) diff --git a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt index 5edecb5a5..af22afc6b 100644 --- a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt @@ -1,16 +1,15 @@ package scietifik.kmath.histogram -import scientifik.kmath.histogram.RealHistogram -import scientifik.kmath.histogram.fill -import scientifik.kmath.histogram.put -import scientifik.kmath.real.RealVector +import kscience.kmath.histogram.RealHistogram +import kscience.kmath.histogram.fill +import kscience.kmath.histogram.put +import kscience.kmath.real.RealVector +import kscience.kmath.real.invoke +import kscience.kmath.structures.Buffer import kotlin.random.Random -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertTrue +import kotlin.test.* -class MultivariateHistogramTest { +internal class MultivariateHistogramTest { @Test fun testSinglePutHistogram() { val histogram = RealHistogram.fromRanges( @@ -18,7 +17,7 @@ class MultivariateHistogramTest { (-1.0..1.0) ) histogram.put(0.55, 0.55) - val bin = histogram.find { it.value.toInt() > 0 }!! + val bin = histogram.find { it.value.toInt() > 0 } ?: fail() assertTrue { bin.contains(RealVector(0.55, 0.55)) } assertTrue { bin.contains(RealVector(0.6, 0.5)) } assertFalse { bin.contains(RealVector(-0.55, 0.55)) } diff --git a/kmath-histograms/src/jsMain/kotlin/kscience/kmath/histogram/Counters.kt b/kmath-histograms/src/jsMain/kotlin/kscience/kmath/histogram/Counters.kt new file mode 100644 index 000000000..d0fa1f4c2 --- /dev/null +++ b/kmath-histograms/src/jsMain/kotlin/kscience/kmath/histogram/Counters.kt @@ -0,0 +1,37 @@ +package kscience.kmath.histogram + +public actual class LongCounter { + private var sum: Long = 0L + + public actual fun decrement() { + sum-- + } + + public actual fun increment() { + sum++ + } + + public actual fun reset() { + sum = 0 + } + + public actual fun sum(): Long = sum + + public actual fun add(l: Long) { + sum += l + } +} + +public actual class DoubleCounter { + private var sum: Double = 0.0 + + public actual fun reset() { + sum = 0.0 + } + + public actual fun sum(): Double = sum + + public actual fun add(d: Double) { + sum += d + } +} diff --git a/kmath-histograms/src/jsMain/kotlin/scientifik/kmath/histogram/Counters.kt b/kmath-histograms/src/jsMain/kotlin/scientifik/kmath/histogram/Counters.kt deleted file mode 100644 index 3765220b9..000000000 --- a/kmath-histograms/src/jsMain/kotlin/scientifik/kmath/histogram/Counters.kt +++ /dev/null @@ -1,33 +0,0 @@ -package scientifik.kmath.histogram - -actual class LongCounter { - private var sum: Long = 0 - actual fun decrement() { - sum-- - } - - actual fun increment() { - sum++ - } - - actual fun reset() { - sum = 0 - } - - actual fun sum(): Long = sum - actual fun add(l: Long) { - sum += l - } -} - -actual class DoubleCounter { - private var sum: Double = 0.0 - actual fun reset() { - sum = 0.0 - } - - actual fun sum(): Double = sum - actual fun add(d: Double) { - sum += d - } -} \ No newline at end of file diff --git a/kmath-histograms/src/jvmMain/kotlin/kscience/kmath/histogram/Counters.kt b/kmath-histograms/src/jvmMain/kotlin/kscience/kmath/histogram/Counters.kt new file mode 100644 index 000000000..efbd185ef --- /dev/null +++ b/kmath-histograms/src/jvmMain/kotlin/kscience/kmath/histogram/Counters.kt @@ -0,0 +1,7 @@ +package kscience.kmath.histogram + +import java.util.concurrent.atomic.DoubleAdder +import java.util.concurrent.atomic.LongAdder + +public actual typealias LongCounter = LongAdder +public actual typealias DoubleCounter = DoubleAdder diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/kscience/kmath/histogram/UnivariateHistogram.kt similarity index 54% rename from kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt rename to kmath-histograms/src/jvmMain/kotlin/kscience/kmath/histogram/UnivariateHistogram.kt index e30a45f5a..2f3855892 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/kscience/kmath/histogram/UnivariateHistogram.kt @@ -1,32 +1,33 @@ -package scientifik.kmath.histogram +package kscience.kmath.histogram -import scientifik.kmath.real.RealVector -import scientifik.kmath.real.asVector -import scientifik.kmath.structures.Buffer +import kscience.kmath.real.RealVector +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.asBuffer import java.util.* import kotlin.math.floor //TODO move to common -class UnivariateBin(val position: Double, val size: Double, val counter: LongCounter = LongCounter()) : Bin { +public class UnivariateBin( + public val position: Double, + public val size: Double, + public val counter: LongCounter = LongCounter() +) : Bin { //TODO add weighting - override val value: Number get() = counter.sum() + public override val value: Number get() = counter.sum() - override val center: RealVector get() = doubleArrayOf(position).asVector() + public override val center: RealVector get() = doubleArrayOf(position).asBuffer() + public override val dimension: Int get() = 1 - operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) - - override fun contains(point: Buffer): Boolean = contains(point[0]) - - internal operator fun inc() = this.also { counter.increment() } - - override val dimension: Int get() = 1 + public operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) + public override fun contains(point: Buffer): Boolean = contains(point[0]) + internal operator fun inc(): UnivariateBin = this.also { counter.increment() } } /** * Univariate histogram with log(n) bin search speed */ -class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) : +public class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) : MutableHistogram { private val bins: TreeMap = TreeMap() @@ -43,19 +44,19 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U } private fun createBin(value: Double): UnivariateBin = factory(value).also { - synchronized(this) { bins.put(it.position, it) } + synchronized(this) { bins[it.position] = it } } - override operator fun get(point: Buffer): UnivariateBin? = get(point[0]) + public override operator fun get(point: Buffer): UnivariateBin? = get(point[0]) - override val dimension: Int get() = 1 + public override val dimension: Int get() = 1 - override operator fun iterator(): Iterator = bins.values.iterator() + public override operator fun iterator(): Iterator = bins.values.iterator() /** * Thread safe put operation */ - fun put(value: Double) { + public fun put(value: Double) { (get(value) ?: createBin(value)).inc() } @@ -64,13 +65,13 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U put(point[0]) } - companion object { - fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value -> + public companion object { + public fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value -> val center = start + binSize * floor((value - start) / binSize + 0.5) UnivariateBin(center, binSize) } - fun custom(borders: DoubleArray): UnivariateHistogram { + public fun custom(borders: DoubleArray): UnivariateHistogram { val sorted = borders.sortedArray() return UnivariateHistogram { value -> @@ -79,12 +80,14 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U Double.NEGATIVE_INFINITY, Double.MAX_VALUE ) + value > sorted.last() -> UnivariateBin( Double.POSITIVE_INFINITY, Double.MAX_VALUE ) + else -> { - val index = (0 until sorted.size).first { value > sorted[it] } + val index = sorted.indices.first { value > sorted[it] } val left = sorted[index] val right = sorted[index + 1] UnivariateBin((left + right) / 2, (right - left)) @@ -95,4 +98,4 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U } } -fun UnivariateHistogram.fill(sequence: Iterable) = sequence.forEach { put(it) } +public fun UnivariateHistogram.fill(sequence: Iterable): Unit = sequence.forEach(::put) diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/Counters.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/Counters.kt deleted file mode 100644 index bb3667f7d..000000000 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/Counters.kt +++ /dev/null @@ -1,7 +0,0 @@ -package scientifik.kmath.histogram - -import java.util.concurrent.atomic.DoubleAdder -import java.util.concurrent.atomic.LongAdder - -actual typealias LongCounter = LongAdder -actual typealias DoubleCounter = DoubleAdder \ No newline at end of file diff --git a/kmath-koma/build.gradle.kts b/kmath-koma/build.gradle.kts deleted file mode 100644 index 26955bca7..000000000 --- a/kmath-koma/build.gradle.kts +++ /dev/null @@ -1,31 +0,0 @@ -plugins { - id("scientifik.mpp") -} - -repositories { - maven("http://dl.bintray.com/kyonifer/maven") -} - -kotlin.sourceSets { - commonMain { - dependencies { - api(project(":kmath-core")) - api("com.kyonifer:koma-core-api-common:0.12") - } - } - jvmMain { - dependencies { - api("com.kyonifer:koma-core-api-jvm:0.12") - } - } - jvmTest { - dependencies { - implementation("com.kyonifer:koma-core-ejml:0.12") - } - } - jsMain { - dependencies { - api("com.kyonifer:koma-core-api-js:0.12") - } - } -} diff --git a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt deleted file mode 100644 index bd8fa782a..000000000 --- a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt +++ /dev/null @@ -1,110 +0,0 @@ -package scientifik.kmath.linear - -import koma.extensions.fill -import koma.matrix.MatrixFactory -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke -import scientifik.kmath.structures.Matrix -import scientifik.kmath.structures.NDStructure - -class KomaMatrixContext( - private val factory: MatrixFactory>, - private val space: Space -) : MatrixContext { - - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix = - KomaMatrix(factory.zeros(rows, columns).fill(initializer)) - - fun Matrix.toKoma(): KomaMatrix = if (this is KomaMatrix) { - this - } else { - produce(rowNum, colNum) { i, j -> get(i, j) } - } - - fun Point.toKoma(): KomaVector = if (this is KomaVector) { - this - } else { - KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) }) - } - - - override fun Matrix.dot(other: Matrix): KomaMatrix = - KomaMatrix(toKoma().origin * other.toKoma().origin) - - override fun Matrix.dot(vector: Point): KomaVector = - KomaVector(toKoma().origin * vector.toKoma().origin) - - override operator fun Matrix.unaryMinus(): KomaMatrix = - KomaMatrix(toKoma().origin.unaryMinus()) - - override fun add(a: Matrix, b: Matrix): KomaMatrix = - KomaMatrix(a.toKoma().origin + b.toKoma().origin) - - override operator fun Matrix.minus(b: Matrix): KomaMatrix = - KomaMatrix(toKoma().origin - b.toKoma().origin) - - override fun multiply(a: Matrix, k: Number): Matrix = - produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } } - - override operator fun Matrix.times(value: T): KomaMatrix = - KomaMatrix(toKoma().origin * value) - - companion object -} - -fun KomaMatrixContext.solve(a: Matrix, b: Matrix) = - KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin)) - -fun KomaMatrixContext.solve(a: Matrix, b: Point) = - KomaVector(a.toKoma().origin.solve(b.toKoma().origin)) - -fun KomaMatrixContext.inverse(a: Matrix) = - KomaMatrix(a.toKoma().origin.inv()) - -class KomaMatrix(val origin: koma.matrix.Matrix, features: Set? = null) : FeaturedMatrix { - override val rowNum: Int get() = origin.numRows() - override val colNum: Int get() = origin.numCols() - - override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols()) - - override val features: Set = features ?: hashSetOf( - object : DeterminantFeature { - override val determinant: T get() = origin.det() - }, - - object : LUPDecompositionFeature { - private val lup by lazy { origin.LU() } - override val l: FeaturedMatrix get() = KomaMatrix(lup.second) - override val u: FeaturedMatrix get() = KomaMatrix(lup.third) - override val p: FeaturedMatrix get() = KomaMatrix(lup.first) - } - ) - - override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix = - KomaMatrix(this.origin, this.features + features) - - override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j) - - override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) - } - - override fun hashCode(): Int { - var result = origin.hashCode() - result = 31 * result + features.hashCode() - return result - } - - -} - -class KomaVector internal constructor(val origin: koma.matrix.Matrix) : Point { - override val size: Int get() = origin.numRows() - - init { - require(origin.numCols() == 1) { error("Only single column matrices are allowed") } - } - - override operator fun get(index: Int): T = origin.getGeneric(index) - override operator fun iterator(): Iterator = origin.toIterable().iterator() -} diff --git a/kmath-kotlingrad/build.gradle.kts b/kmath-kotlingrad/build.gradle.kts new file mode 100644 index 000000000..3925a744c --- /dev/null +++ b/kmath-kotlingrad/build.gradle.kts @@ -0,0 +1,9 @@ +plugins { + id("ru.mipt.npm.jvm") +} + +dependencies { + implementation("com.github.breandan:kaliningraph:0.1.4") + implementation("com.github.breandan:kotlingrad:0.4.0") + api(project(":kmath-ast")) +} diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt new file mode 100644 index 000000000..abde9e54d --- /dev/null +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt @@ -0,0 +1,53 @@ +package kscience.kmath.kotlingrad + +import edu.umontreal.kotlingrad.api.SFun +import kscience.kmath.ast.MST +import kscience.kmath.ast.MstAlgebra +import kscience.kmath.ast.MstExpression +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Symbol +import kscience.kmath.operations.NumericAlgebra + +/** + * Represents wrapper of [MstExpression] implementing [DifferentiableExpression]. + * + * The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting + * [SFun] back to [MST]. + * + * @param T the type of number. + * @param A the [NumericAlgebra] of [T]. + * @property expr the underlying [MstExpression]. + */ +public inline class DifferentiableMstExpression(public val expr: MstExpression) : + DifferentiableExpression> where A : NumericAlgebra, T : Number { + public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst)) + + /** + * The [MstExpression.algebra] of [expr]. + */ + public val algebra: A + get() = expr.algebra + + /** + * The [MstExpression.mst] of [expr]. + */ + public val mst: MST + get() = expr.mst + + public override fun invoke(arguments: Map): T = expr(arguments) + + public override fun derivativeOrNull(symbols: List): MstExpression = MstExpression( + algebra, + symbols.map(Symbol::identity) + .map(MstAlgebra::symbol) + .map { it.toSVar>() } + .fold(mst.toSFun(), SFun>::d) + .toMst(), + ) +} + +/** + * Wraps this [MstExpression] into [DifferentiableMstExpression]. + */ +public fun > MstExpression.differentiable(): DifferentiableMstExpression = + DifferentiableMstExpression(this) diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/KMathNumber.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/KMathNumber.kt new file mode 100644 index 000000000..2a4db4258 --- /dev/null +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/KMathNumber.kt @@ -0,0 +1,18 @@ +package kscience.kmath.kotlingrad + +import edu.umontreal.kotlingrad.api.RealNumber +import edu.umontreal.kotlingrad.api.SConst +import kscience.kmath.operations.NumericAlgebra + +/** + * Implements [RealNumber] by delegating its functionality to [NumericAlgebra]. + * + * @param T the type of number. + * @param A the [NumericAlgebra] of [T]. + * @property algebra the algebra. + * @param value the value of this number. + */ +public class KMathNumber(public val algebra: A, value: T) : + RealNumber, T>(value) where T : Number, A : NumericAlgebra { + public override fun wrap(number: Number): SConst> = SConst(algebra.number(number)) +} diff --git a/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt new file mode 100644 index 000000000..8dc1d3958 --- /dev/null +++ b/kmath-kotlingrad/src/main/kotlin/kscience/kmath/kotlingrad/ScalarsAdapters.kt @@ -0,0 +1,124 @@ +package kscience.kmath.kotlingrad + +import edu.umontreal.kotlingrad.api.* +import kscience.kmath.ast.MST +import kscience.kmath.ast.MstAlgebra +import kscience.kmath.ast.MstExtendedField +import kscience.kmath.ast.MstExtendedField.unaryMinus +import kscience.kmath.operations.* + +/** + * Maps [SVar] to [MST.Symbolic] directly. + * + * @receiver the variable. + * @return a node. + */ +public fun > SVar.toMst(): MST.Symbolic = MstAlgebra.symbol(name) + +/** + * Maps [SVar] to [MST.Numeric] directly. + * + * @receiver the constant. + * @return a node. + */ +public fun > SConst.toMst(): MST.Numeric = MstAlgebra.number(doubleValue) + +/** + * Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then. + * [Power] operation is limited to constant right-hand side arguments. + * + * Detailed mapping is: + * + * - [SVar] -> [MstExtendedField.symbol]; + * - [SConst] -> [MstExtendedField.number]; + * - [Sum] -> [MstExtendedField.add]; + * - [Prod] -> [MstExtendedField.multiply]; + * - [Power] -> [MstExtendedField.power] (limited to constant exponents only); + * - [Negative] -> [MstExtendedField.unaryMinus]; + * - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right); + * - [Sine] -> [MstExtendedField.sin]; + * - [Cosine] -> [MstExtendedField.cos]; + * - [Tangent] -> [MstExtendedField.tan]; + * - [DProd] is vector operation, and it is requested to be evaluated; + * - [SComposition] is also requested to be evaluated eagerly; + * - [VSumAll] is requested to be evaluated; + * - [Derivative] is requested to be evaluated. + * + * @receiver the scalar function. + * @return a node. + */ +public fun > SFun.toMst(): MST = MstExtendedField { + when (this@toMst) { + is SVar -> toMst() + is SConst -> toMst() + is Sum -> left.toMst() + right.toMst() + is Prod -> left.toMst() * right.toMst() + is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue) + is Negative -> -input.toMst() + is Log -> ln(left.toMst()) / ln(right.toMst()) + is Sine -> sin(input.toMst()) + is Cosine -> cos(input.toMst()) + is Tangent -> tan(input.toMst()) + is DProd -> this@toMst().toMst() + is SComposition -> this@toMst().toMst() + is VSumAll -> this@toMst().toMst() + is Derivative -> this@toMst().toMst() + } +} + +/** + * Maps [MST.Numeric] to [SConst] directly. + * + * @receiver the node. + * @return a new constant. + */ +public fun > MST.Numeric.toSConst(): SConst = SConst(value) + +/** + * Maps [MST.Symbolic] to [SVar] directly. + * + * @receiver the node. + * @param proto the prototype instance. + * @return a new variable. + */ +internal fun > MST.Symbolic.toSVar(): SVar = SVar(value) + +/** + * Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException]. + * + * Detailed mapping is: + * + * - [MST.Numeric] -> [SConst]; + * - [MST.Symbolic] -> [SVar]; + * - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log]; + * - [MST.Binary] -> [Sum], [Prod], [Power]. + * + * @receiver the node. + * @param proto the prototype instance. + * @return a scalar function. + */ +public fun > MST.toSFun(): SFun = when (this) { + is MST.Numeric -> toSConst() + is MST.Symbolic -> toSVar() + + is MST.Unary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> +value.toSFun() + SpaceOperations.MINUS_OPERATION -> -value.toSFun() + TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) + TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) + TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) + PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun()) + ExponentialOperations.EXP_OPERATION -> exp(value.toSFun()) + ExponentialOperations.LN_OPERATION -> value.toSFun().ln() + else -> error("Unary operation $operation not defined in $this") + } + + is MST.Binary -> when (operation) { + SpaceOperations.PLUS_OPERATION -> left.toSFun() + right.toSFun() + SpaceOperations.MINUS_OPERATION -> left.toSFun() - right.toSFun() + RingOperations.TIMES_OPERATION -> left.toSFun() * right.toSFun() + FieldOperations.DIV_OPERATION -> left.toSFun() / right.toSFun() + PowerOperations.POW_OPERATION -> left.toSFun() pow (right as MST.Numeric).toSConst() + else -> error("Binary operation $operation not defined in $this") + } +} diff --git a/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt new file mode 100644 index 000000000..aa4ddd703 --- /dev/null +++ b/kmath-kotlingrad/src/test/kotlin/kscience/kmath/kotlingrad/AdaptingTests.kt @@ -0,0 +1,64 @@ +package kscience.kmath.kotlingrad + +import edu.umontreal.kotlingrad.api.* +import kscience.kmath.asm.compile +import kscience.kmath.ast.MstAlgebra +import kscience.kmath.ast.MstExpression +import kscience.kmath.ast.parseMath +import kscience.kmath.expressions.invoke +import kscience.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.fail + +internal class AdaptingTests { + @Test + fun symbol() { + val c1 = MstAlgebra.symbol("x") + assertTrue(c1.toSVar>().name == "x") + val c2 = "kitten".parseMath().toSFun>() + if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() + } + + @Test + fun number() { + val c1 = MstAlgebra.number(12354324) + assertTrue(c1.toSConst().doubleValue == 12354324.0) + val c2 = "0.234".parseMath().toSFun>() + if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail() + val c3 = "1e-3".parseMath().toSFun>() + if (c3 is SConst) assertEquals(0.001, c3.value) else fail() + } + + @Test + fun simpleFunctionShape() { + val linear = "2*x+16".parseMath().toSFun>() + if (linear !is Sum) fail() + if (linear.left !is Prod) fail() + if (linear.right !is SConst) fail() + } + + @Test + fun simpleFunctionDerivative() { + val x = MstAlgebra.symbol("x").toSVar>() + val quadratic = "x^2-4*x-44".parseMath().toSFun>() + val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile() + val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile() + assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) + } + + @Test + fun moreComplexDerivative() { + val x = MstAlgebra.symbol("x").toSVar>() + val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun>() + val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile() + + val expectedDerivative = MstExpression( + RealField, + "-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath() + ).compile() + + assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1)) + } +} diff --git a/kmath-memory/build.gradle.kts b/kmath-memory/build.gradle.kts index 75b4f174e..9f92cca92 100644 --- a/kmath-memory/build.gradle.kts +++ b/kmath-memory/build.gradle.kts @@ -1,3 +1,4 @@ plugins { - id("scientifik.mpp") -} \ No newline at end of file + id("ru.mipt.npm.mpp") + id("ru.mipt.npm.native") +} diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt b/kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/Memory.kt similarity index 63% rename from kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt rename to kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/Memory.kt index 177c6b46b..344a1f1d3 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/Memory.kt +++ b/kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/Memory.kt @@ -1,4 +1,4 @@ -package scientifik.memory +package kscience.kmath.memory import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -6,84 +6,84 @@ import kotlin.contracts.contract /** * Represents a display of certain memory structure. */ -interface Memory { +public interface Memory { /** * The length of this memory in bytes. */ - val size: Int + public val size: Int /** * Get a projection of this memory (it reflects the changes in the parent memory block). */ - fun view(offset: Int, length: Int): Memory + public fun view(offset: Int, length: Int): Memory /** * Creates an independent copy of this memory. */ - fun copy(): Memory + public fun copy(): Memory /** * Gets or creates a reader of this memory. */ - fun reader(): MemoryReader + public fun reader(): MemoryReader /** * Gets or creates a writer of this memory. */ - fun writer(): MemoryWriter + public fun writer(): MemoryWriter - companion object + public companion object } /** * The interface to read primitive types in this memory. */ -interface MemoryReader { +public interface MemoryReader { /** * The underlying memory. */ - val memory: Memory + public val memory: Memory /** * Reads [Double] at certain [offset]. */ - fun readDouble(offset: Int): Double + public fun readDouble(offset: Int): Double /** * Reads [Float] at certain [offset]. */ - fun readFloat(offset: Int): Float + public fun readFloat(offset: Int): Float /** * Reads [Byte] at certain [offset]. */ - fun readByte(offset: Int): Byte + public fun readByte(offset: Int): Byte /** * Reads [Short] at certain [offset]. */ - fun readShort(offset: Int): Short + public fun readShort(offset: Int): Short /** * Reads [Int] at certain [offset]. */ - fun readInt(offset: Int): Int + public fun readInt(offset: Int): Int /** * Reads [Long] at certain [offset]. */ - fun readLong(offset: Int): Long + public fun readLong(offset: Int): Long /** * Disposes this reader if needed. */ - fun release() + public fun release() } /** * Uses the memory for read then releases the reader. */ -inline fun Memory.read(block: MemoryReader.() -> R): R { +public inline fun Memory.read(block: MemoryReader.() -> R): R { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } val reader = reader() val result = reader.block() @@ -94,52 +94,52 @@ inline fun Memory.read(block: MemoryReader.() -> R): R { /** * The interface to write primitive types into this memory. */ -interface MemoryWriter { +public interface MemoryWriter { /** * The underlying memory. */ - val memory: Memory + public val memory: Memory /** * Writes [Double] at certain [offset]. */ - fun writeDouble(offset: Int, value: Double) + public fun writeDouble(offset: Int, value: Double) /** * Writes [Float] at certain [offset]. */ - fun writeFloat(offset: Int, value: Float) + public fun writeFloat(offset: Int, value: Float) /** * Writes [Byte] at certain [offset]. */ - fun writeByte(offset: Int, value: Byte) + public fun writeByte(offset: Int, value: Byte) /** * Writes [Short] at certain [offset]. */ - fun writeShort(offset: Int, value: Short) + public fun writeShort(offset: Int, value: Short) /** * Writes [Int] at certain [offset]. */ - fun writeInt(offset: Int, value: Int) + public fun writeInt(offset: Int, value: Int) /** * Writes [Long] at certain [offset]. */ - fun writeLong(offset: Int, value: Long) + public fun writeLong(offset: Int, value: Long) /** * Disposes this writer if needed. */ - fun release() + public fun release() } /** * Uses the memory for write then releases the writer. */ -inline fun Memory.write(block: MemoryWriter.() -> Unit) { +public inline fun Memory.write(block: MemoryWriter.() -> Unit) { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } writer().apply(block).release() } @@ -147,10 +147,10 @@ inline fun Memory.write(block: MemoryWriter.() -> Unit) { /** * Allocates the most effective platform-specific memory. */ -expect fun Memory.Companion.allocate(length: Int): Memory +public expect fun Memory.Companion.allocate(length: Int): Memory /** * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied * and could be mutated independently from the resulting [Memory]. */ -expect fun Memory.Companion.wrap(array: ByteArray): Memory +public expect fun Memory.Companion.wrap(array: ByteArray): Memory diff --git a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/MemorySpec.kt similarity index 58% rename from kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt rename to kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/MemorySpec.kt index 1381afbec..572dab0fa 100644 --- a/kmath-memory/src/commonMain/kotlin/scientifik/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/kscience/kmath/memory/MemorySpec.kt @@ -1,49 +1,49 @@ -package scientifik.memory +package kscience.kmath.memory /** * A specification to read or write custom objects with fixed size in bytes. * * @param T the type of object this spec manages. */ -interface MemorySpec { +public interface MemorySpec { /** * Size of [T] in bytes after serialization. */ - val objectSize: Int + public val objectSize: Int /** * Reads the object starting from [offset]. */ - fun MemoryReader.read(offset: Int): T + public fun MemoryReader.read(offset: Int): T // TODO consider thread safety /** * Writes the object [value] starting from [offset]. */ - fun MemoryWriter.write(offset: Int, value: T) + public fun MemoryWriter.write(offset: Int, value: T) } /** * Reads the object with [spec] starting from [offset]. */ -fun MemoryReader.read(spec: MemorySpec, offset: Int): T = with(spec) { read(offset) } +public fun MemoryReader.read(spec: MemorySpec, offset: Int): T = with(spec) { read(offset) } /** * Writes the object [value] with [spec] starting from [offset]. */ -fun MemoryWriter.write(spec: MemorySpec, offset: Int, value: T): Unit = with(spec) { write(offset, value) } +public fun MemoryWriter.write(spec: MemorySpec, offset: Int, value: T): Unit = with(spec) { write(offset, value) } /** * Reads array of [size] objects mapped by [spec] at certain [offset]. */ -inline fun MemoryReader.readArray(spec: MemorySpec, offset: Int, size: Int): Array = +public inline fun MemoryReader.readArray(spec: MemorySpec, offset: Int, size: Int): Array = Array(size) { i -> with(spec) { read(offset + i * objectSize) } } /** * Writes [array] of objects mapped by [spec] at certain [offset]. */ -fun MemoryWriter.writeArray(spec: MemorySpec, offset: Int, array: Array): Unit = +public fun MemoryWriter.writeArray(spec: MemorySpec, offset: Int, array: Array): Unit = with(spec) { array.indices.forEach { i -> write(offset + i * objectSize, array[i]) } } // TODO It is possible to add elastic MemorySpec with unknown object size diff --git a/kmath-memory/src/jsMain/kotlin/scientifik/memory/DataViewMemory.kt b/kmath-memory/src/jsMain/kotlin/kscience/kmath/memory/DataViewMemory.kt similarity index 95% rename from kmath-memory/src/jsMain/kotlin/scientifik/memory/DataViewMemory.kt rename to kmath-memory/src/jsMain/kotlin/kscience/kmath/memory/DataViewMemory.kt index 974750502..2146cd4e1 100644 --- a/kmath-memory/src/jsMain/kotlin/scientifik/memory/DataViewMemory.kt +++ b/kmath-memory/src/jsMain/kotlin/kscience/kmath/memory/DataViewMemory.kt @@ -1,4 +1,4 @@ -package scientifik.memory +package kscience.kmath.memory import org.khronos.webgl.ArrayBuffer import org.khronos.webgl.DataView @@ -83,7 +83,7 @@ private class DataViewMemory(val view: DataView) : Memory { /** * Allocates memory based on a [DataView]. */ -actual fun Memory.Companion.allocate(length: Int): Memory { +public actual fun Memory.Companion.allocate(length: Int): Memory { val buffer = ArrayBuffer(length) return DataViewMemory(DataView(buffer, 0, length)) } @@ -92,7 +92,7 @@ actual fun Memory.Companion.allocate(length: Int): Memory { * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied * and could be mutated independently from the resulting [Memory]. */ -actual fun Memory.Companion.wrap(array: ByteArray): Memory { +public actual fun Memory.Companion.wrap(array: ByteArray): Memory { @Suppress("CAST_NEVER_SUCCEEDS") val int8Array = array as Int8Array return DataViewMemory(DataView(int8Array.buffer, int8Array.byteOffset, int8Array.length)) } diff --git a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt b/kmath-memory/src/jvmMain/kotlin/kscience/kmath/memory/ByteBufferMemory.kt similarity index 89% rename from kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt rename to kmath-memory/src/jvmMain/kotlin/kscience/kmath/memory/ByteBufferMemory.kt index f4967bf5c..7a75b423e 100644 --- a/kmath-memory/src/jvmMain/kotlin/scientifik/memory/ByteBufferMemory.kt +++ b/kmath-memory/src/jvmMain/kotlin/kscience/kmath/memory/ByteBufferMemory.kt @@ -1,4 +1,4 @@ -package scientifik.memory +package kscience.kmath.memory import java.io.IOException import java.nio.ByteBuffer @@ -6,7 +6,6 @@ import java.nio.channels.FileChannel import java.nio.file.Files import java.nio.file.Path import java.nio.file.StandardOpenOption -import kotlin.contracts.ExperimentalContracts import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -94,14 +93,14 @@ internal class ByteBufferMemory( /** * Allocates memory based on a [ByteBuffer]. */ -actual fun Memory.Companion.allocate(length: Int): Memory = +public actual fun Memory.Companion.allocate(length: Int): Memory = ByteBufferMemory(checkNotNull(ByteBuffer.allocate(length))) /** * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied * and could be mutated independently from the resulting [Memory]. */ -actual fun Memory.Companion.wrap(array: ByteArray): Memory = ByteBufferMemory(checkNotNull(ByteBuffer.wrap(array))) +public actual fun Memory.Companion.wrap(array: ByteArray): Memory = ByteBufferMemory(checkNotNull(ByteBuffer.wrap(array))) /** * Wraps this [ByteBuffer] to [Memory] object. @@ -111,14 +110,14 @@ actual fun Memory.Companion.wrap(array: ByteArray): Memory = ByteBufferMemory(ch * @param size the size of memory to map. * @return the [Memory] object. */ -fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory = +public fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory = ByteBufferMemory(this, startOffset, size) /** * Uses direct memory-mapped buffer from file to read something and close it afterwards. */ @Throws(IOException::class) -inline fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R { +public inline fun Path.readAsMemory(position: Long = 0, size: Long = Files.size(this), block: Memory.() -> R): R { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } return FileChannel diff --git a/kmath-memory/src/nativeMain/kotlin/kscience/kmath/memory/NativeMemory.kt b/kmath-memory/src/nativeMain/kotlin/kscience/kmath/memory/NativeMemory.kt new file mode 100644 index 000000000..0e007a8ab --- /dev/null +++ b/kmath-memory/src/nativeMain/kotlin/kscience/kmath/memory/NativeMemory.kt @@ -0,0 +1,93 @@ +package kscience.kmath.memory + +@PublishedApi +internal class NativeMemory( + val array: ByteArray, + val startOffset: Int = 0, + override val size: Int = array.size +) : Memory { + @Suppress("NOTHING_TO_INLINE") + private inline fun position(o: Int): Int = startOffset + o + + override fun view(offset: Int, length: Int): Memory { + require(offset >= 0) { "offset shouldn't be negative: $offset" } + require(length >= 0) { "length shouldn't be negative: $length" } + require(offset + length <= size) { "Can't view memory outside the parent region." } + return NativeMemory(array, position(offset), length) + } + + override fun copy(): Memory { + val copy = array.copyOfRange(startOffset, startOffset + size) + return NativeMemory(copy) + } + + private val reader: MemoryReader = object : MemoryReader { + override val memory: Memory get() = this@NativeMemory + + override fun readDouble(offset: Int) = array.getDoubleAt(position(offset)) + + override fun readFloat(offset: Int) = array.getFloatAt(position(offset)) + + override fun readByte(offset: Int) = array[position(offset)] + + override fun readShort(offset: Int) = array.getShortAt(position(offset)) + + override fun readInt(offset: Int) = array.getIntAt(position(offset)) + + override fun readLong(offset: Int) = array.getLongAt(position(offset)) + + override fun release() { + // does nothing on JVM + } + } + + override fun reader(): MemoryReader = reader + + private val writer: MemoryWriter = object : MemoryWriter { + override val memory: Memory get() = this@NativeMemory + + override fun writeDouble(offset: Int, value: Double) { + array.setDoubleAt(position(offset), value) + } + + override fun writeFloat(offset: Int, value: Float) { + array.setFloatAt(position(offset), value) + } + + override fun writeByte(offset: Int, value: Byte) { + array.set(position(offset), value) + } + + override fun writeShort(offset: Int, value: Short) { + array.setShortAt(position(offset), value) + } + + override fun writeInt(offset: Int, value: Int) { + array.setIntAt(position(offset), value) + } + + override fun writeLong(offset: Int, value: Long) { + array.setLongAt(position(offset), value) + } + + override fun release() { + // does nothing on JVM + } + } + + override fun writer(): MemoryWriter = writer +} + +/** + * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied + * and could be mutated independently from the resulting [Memory]. + */ +public actual fun Memory.Companion.wrap(array: ByteArray): Memory = NativeMemory(array) + +/** + * Allocates the most effective platform-specific memory. + */ +public actual fun Memory.Companion.allocate(length: Int): Memory { + val array = ByteArray(length) + return NativeMemory(array) +} \ No newline at end of file diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md new file mode 100644 index 000000000..ff4ff4542 --- /dev/null +++ b/kmath-nd4j/README.md @@ -0,0 +1,82 @@ +# ND4J NDStructure implementation (`kmath-nd4j`) + +This subproject implements the following features: + + - [nd4jarraystructure](src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray + - [nd4jarrayrings](src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long + - [nd4jarrayfields](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double + + +> #### Artifact: +> +> This module artifact: `kscience.kmath:kmath-nd4j:0.2.0-dev-4`. +> +> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-nd4j/_latestVersion) +> +> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-nd4j/_latestVersion) +> +> **Gradle:** +> +> ```gradle +> repositories { +> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } +> maven { url 'https://dl.bintray.com/mipt-npm/kscience' } +> maven { url 'https://dl.bintray.com/mipt-npm/dev' } +> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } +> +> } +> +> dependencies { +> implementation 'kscience.kmath:kmath-nd4j:0.2.0-dev-4' +> } +> ``` +> **Gradle Kotlin DSL:** +> +> ```kotlin +> repositories { +> maven("https://dl.bintray.com/kotlin/kotlin-eap") +> maven("https://dl.bintray.com/mipt-npm/kscience") +> maven("https://dl.bintray.com/mipt-npm/dev") +> maven("https://dl.bintray.com/hotkeytlt/maven") +> } +> +> dependencies { +> implementation("kscience.kmath:kmath-nd4j:0.2.0-dev-4") +> } +> ``` + +## Examples + +NDStructure wrapper for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.structures.* + +val array = Nd4j.ones(2, 2).asRealStructure() +println(array[0, 0]) // 1.0 +array[intArrayOf(0, 0)] = 24.0 +println(array[0, 0]) // 24.0 +``` + +Fast element-wise and in-place arithmetics for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.operations.* + +val field = RealNd4jArrayField(intArrayOf(2, 2)) +val array = Nd4j.rand(2, 2).asRealStructure() + +val res = field { + (25.0 / array + 20) * 4 +} + +println(res.ndArray) +// [[ 250.6449, 428.5840], +// [ 269.7913, 202.2077]] +``` + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts new file mode 100644 index 000000000..391727c45 --- /dev/null +++ b/kmath-nd4j/build.gradle.kts @@ -0,0 +1,37 @@ +import ru.mipt.npm.gradle.Maturity + +plugins { + id("ru.mipt.npm.jvm") +} + +dependencies { + api(project(":kmath-core")) + api("org.nd4j:nd4j-api:1.0.0-beta7") + testImplementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7") + testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") + testImplementation("org.slf4j:slf4j-simple:1.7.30") +} + +readme { + description = "ND4J NDStructure implementation and according NDAlgebra classes" + maturity = Maturity.EXPERIMENTAL + propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) + + feature( + id = "nd4jarraystructure", + description = "NDStructure wrapper for INDArray", + ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt" + ) + + feature( + id = "nd4jarrayrings", + description = "Rings over Nd4jArrayStructure of Int and Long", + ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt" + ) + + feature( + id = "nd4jarrayfields", + description = "Fields over Nd4jArrayStructure of Float and Double", + ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt" + ) +} diff --git a/kmath-nd4j/docs/README-TEMPLATE.md b/kmath-nd4j/docs/README-TEMPLATE.md new file mode 100644 index 000000000..76ce8c9a7 --- /dev/null +++ b/kmath-nd4j/docs/README-TEMPLATE.md @@ -0,0 +1,43 @@ +# ND4J NDStructure implementation (`kmath-nd4j`) + +This subproject implements the following features: + +${features} + +${artifact} + +## Examples + +NDStructure wrapper for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.structures.* + +val array = Nd4j.ones(2, 2).asRealStructure() +println(array[0, 0]) // 1.0 +array[intArrayOf(0, 0)] = 24.0 +println(array[0, 0]) // 24.0 +``` + +Fast element-wise and in-place arithmetics for INDArray: + +```kotlin +import org.nd4j.linalg.factory.* +import scientifik.kmath.nd4j.* +import scientifik.kmath.operations.* + +val field = RealNd4jArrayField(intArrayOf(2, 2)) +val array = Nd4j.rand(2, 2).asRealStructure() + +val res = field { + (25.0 / array + 20) * 4 +} + +println(res.ndArray) +// [[ 250.6449, 428.5840], +// [ 269.7913, 202.2077]] +``` + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt new file mode 100644 index 000000000..db2a44861 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt @@ -0,0 +1,353 @@ +package kscience.kmath.nd4j + +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.* +import kscience.kmath.structures.NDAlgebra +import kscience.kmath.structures.NDField +import kscience.kmath.structures.NDRing +import kscience.kmath.structures.NDSpace +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.factory.Nd4j + +/** + * Represents [NDAlgebra] over [Nd4jArrayAlgebra]. + * + * @param T the type of ND-structure element. + * @param C the type of the element context. + */ +public interface Nd4jArrayAlgebra : NDAlgebra> { + /** + * Wraps [INDArray] to [N]. + */ + public fun INDArray.wrap(): Nd4jArrayStructure + + public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure { + val struct = Nd4j.create(*shape)!!.wrap() + struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } + return struct + } + + public override fun map(arg: Nd4jArrayStructure, transform: C.(T) -> T): Nd4jArrayStructure { + check(arg) + val newStruct = arg.ndArray.dup().wrap() + newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } + return newStruct + } + + public override fun mapIndexed( + arg: Nd4jArrayStructure, + transform: C.(index: IntArray, T) -> T, + ): Nd4jArrayStructure { + check(arg) + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) } + return new + } + + public override fun combine( + a: Nd4jArrayStructure, + b: Nd4jArrayStructure, + transform: C.(T, T) -> T, + ): Nd4jArrayStructure { + check(a, b) + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } + return new + } +} + +/** + * Represents [NDSpace] over [Nd4jArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param S the type of space of structure elements. + */ +public interface Nd4jArraySpace> : NDSpace>, Nd4jArrayAlgebra { + + public override val zero: Nd4jArrayStructure + get() = Nd4j.zeros(*shape).wrap() + + public override fun add(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { + check(a, b) + return a.ndArray.add(b.ndArray).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { + check(this, b) + return ndArray.sub(b.ndArray).wrap() + } + + public override operator fun Nd4jArrayStructure.unaryMinus(): Nd4jArrayStructure { + check(this) + return ndArray.neg().wrap() + } + + public override fun multiply(a: Nd4jArrayStructure, k: Number): Nd4jArrayStructure { + check(a) + return a.ndArray.mul(k).wrap() + } + + public override operator fun Nd4jArrayStructure.div(k: Number): Nd4jArrayStructure { + check(this) + return ndArray.div(k).wrap() + } + + public override operator fun Nd4jArrayStructure.times(k: Number): Nd4jArrayStructure { + check(this) + return ndArray.mul(k).wrap() + } +} + +/** + * Represents [NDRing] over [Nd4jArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param R the type of ring of structure elements. + */ +@OptIn(UnstableKMathAPI::class) +public interface Nd4jArrayRing> : NDRing>, Nd4jArraySpace { + + public override val one: Nd4jArrayStructure + get() = Nd4j.ones(*shape).wrap() + + public override fun multiply(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { + check(a, b) + return a.ndArray.mul(b.ndArray).wrap() + } +// +// public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { +// check(this) +// return ndArray.sub(b).wrap() +// } +// +// public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { +// check(this) +// return ndArray.add(b).wrap() +// } +// +// public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { +// check(b) +// return b.ndArray.rsub(this).wrap() +// } + + public companion object { + private val intNd4jArrayRingCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + private val longNd4jArrayRingCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + /** + * Creates an [NDRing] for [Int] values or pull it from cache if it was created previously. + */ + public fun int(vararg shape: Int): Nd4jArrayRing = + intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) } + + /** + * Creates an [NDRing] for [Long] values or pull it from cache if it was created previously. + */ + public fun long(vararg shape: Int): Nd4jArrayRing = + longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) } + + /** + * Creates a most suitable implementation of [NDRing] using reified class. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(vararg shape: Int): Nd4jArrayRing> = when { + T::class == Int::class -> int(*shape) as Nd4jArrayRing> + T::class == Long::class -> long(*shape) as Nd4jArrayRing> + else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.") + } + } +} + +/** + * Represents [NDField] over [Nd4jArrayStructure]. + * + * @param T the type of the element contained in ND structure. + * @param N the type of ND structure. + * @param F the type field of structure elements. + */ +public interface Nd4jArrayField> : NDField>, Nd4jArrayRing { + + public override fun divide(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { + check(a, b) + return a.ndArray.div(b.ndArray).wrap() + } + + public override operator fun Number.div(b: Nd4jArrayStructure): Nd4jArrayStructure { + check(b) + return b.ndArray.rdiv(this).wrap() + } + + + public companion object { + private val floatNd4jArrayFieldCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + private val realNd4jArrayFieldCache: ThreadLocal> = + ThreadLocal.withInitial { hashMapOf() } + + /** + * Creates an [NDField] for [Float] values or pull it from cache if it was created previously. + */ + public fun float(vararg shape: Int): Nd4jArrayRing = + floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) } + + /** + * Creates an [NDField] for [Double] values or pull it from cache if it was created previously. + */ + public fun real(vararg shape: Int): Nd4jArrayRing = + realNd4jArrayFieldCache.get().getOrPut(shape) { RealNd4jArrayField(shape) } + + /** + * Creates a most suitable implementation of [NDRing] using reified class. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(vararg shape: Int): Nd4jArrayField> = when { + T::class == Float::class -> float(*shape) as Nd4jArrayField> + T::class == Double::class -> real(*shape) as Nd4jArrayField> + else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.") + } + } +} + +/** + * Represents [NDField] over [Nd4jArrayRealStructure]. + */ +public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField { + public override val elementContext: RealField + get() = RealField + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asRealStructure()) + + public override operator fun Nd4jArrayStructure.div(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.div(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.plus(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Double): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Double.div(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + public override operator fun Double.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDField] over [Nd4jArrayStructure] of [Float]. + */ +public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField { + public override val elementContext: FloatField + get() = FloatField + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asFloatStructure()) + + public override operator fun Nd4jArrayStructure.div(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.div(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.plus(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Float): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Float.div(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rdiv(this).wrap() + } + + public override operator fun Float.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDRing] over [Nd4jArrayIntStructure]. + */ +public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing { + public override val elementContext: IntRing + get() = IntRing + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asIntStructure()) + + public override operator fun Nd4jArrayStructure.plus(arg: Int): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Int): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Int): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Int.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} + +/** + * Represents [NDRing] over [Nd4jArrayStructure] of [Long]. + */ +public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing { + public override val elementContext: LongRing + get() = LongRing + + public override fun INDArray.wrap(): Nd4jArrayStructure = check(asLongStructure()) + + public override operator fun Nd4jArrayStructure.plus(arg: Long): Nd4jArrayStructure { + check(this) + return ndArray.add(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.minus(arg: Long): Nd4jArrayStructure { + check(this) + return ndArray.sub(arg).wrap() + } + + public override operator fun Nd4jArrayStructure.times(arg: Long): Nd4jArrayStructure { + check(this) + return ndArray.mul(arg).wrap() + } + + public override operator fun Long.minus(arg: Nd4jArrayStructure): Nd4jArrayStructure { + check(arg) + return arg.ndArray.rsub(this).wrap() + } +} diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayIterator.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayIterator.kt new file mode 100644 index 000000000..1463a92fe --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayIterator.kt @@ -0,0 +1,62 @@ +package kscience.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.api.shape.Shape + +private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iterator { + private var i: Int = 0 + + override fun hasNext(): Boolean = i < iterateOver.length() + + override fun next(): IntArray { + val la = if (iterateOver.ordering() == 'c') + Shape.ind2subC(iterateOver, i++.toLong())!! + else + Shape.ind2sub(iterateOver, i++.toLong())!! + + return la.toIntArray() + } +} + +internal fun INDArray.indicesIterator(): Iterator = Nd4jArrayIndicesIterator(this) + +private sealed class Nd4jArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { + private var i: Int = 0 + + final override fun hasNext(): Boolean = i < iterateOver.length() + + abstract fun getSingle(indices: LongArray): T + + final override fun next(): Pair { + val la = if (iterateOver.ordering() == 'c') + Shape.ind2subC(iterateOver, i++.toLong())!! + else + Shape.ind2sub(iterateOver, i++.toLong())!! + + return la.toIntArray() to getSingle(la) + } +} + +private class Nd4jArrayRealIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices) +} + +internal fun INDArray.realIterator(): Iterator> = Nd4jArrayRealIterator(this) + +private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) +} + +internal fun INDArray.longIterator(): Iterator> = Nd4jArrayLongIterator(this) + +private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray()) +} + +internal fun INDArray.intIterator(): Iterator> = Nd4jArrayIntIterator(this) + +private class Nd4jArrayFloatIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { + override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices) +} + +internal fun INDArray.floatIterator(): Iterator> = Nd4jArrayFloatIterator(this) diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayStructure.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayStructure.kt new file mode 100644 index 000000000..d47a293c3 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayStructure.kt @@ -0,0 +1,68 @@ +package kscience.kmath.nd4j + +import kscience.kmath.structures.MutableNDStructure +import kscience.kmath.structures.NDStructure +import org.nd4j.linalg.api.ndarray.INDArray + +/** + * Represents a [NDStructure] wrapping an [INDArray] object. + * + * @param T the type of items. + */ +public sealed class Nd4jArrayStructure : MutableNDStructure { + /** + * The wrapped [INDArray]. + */ + public abstract val ndArray: INDArray + + public override val shape: IntArray + get() = ndArray.shape().toIntArray() + + internal abstract fun elementsIterator(): Iterator> + internal fun indicesIterator(): Iterator = ndArray.indicesIterator() + public override fun elements(): Sequence> = Sequence(::elementsIterator) +} + +private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.intIterator() + override fun get(index: IntArray): Int = ndArray.getInt(*index) + override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) } +} + +/** + * Wraps this [INDArray] to [Nd4jArrayStructure]. + */ +public fun INDArray.asIntStructure(): Nd4jArrayStructure = Nd4jArrayIntStructure(this) + +private data class Nd4jArrayLongStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.longIterator() + override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) + override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) } +} + +/** + * Wraps this [INDArray] to [Nd4jArrayStructure]. + */ +public fun INDArray.asLongStructure(): Nd4jArrayStructure = Nd4jArrayLongStructure(this) + +private data class Nd4jArrayRealStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.realIterator() + override fun get(index: IntArray): Double = ndArray.getDouble(*index) + override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) } +} + +/** + * Wraps this [INDArray] to [Nd4jArrayStructure]. + */ +public fun INDArray.asRealStructure(): Nd4jArrayStructure = Nd4jArrayRealStructure(this) + +private data class Nd4jArrayFloatStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { + override fun elementsIterator(): Iterator> = ndArray.floatIterator() + override fun get(index: IntArray): Float = ndArray.getFloat(*index) + override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) } +} + +/** + * Wraps this [INDArray] to [Nd4jArrayStructure]. + */ +public fun INDArray.asFloatStructure(): Nd4jArrayStructure = Nd4jArrayFloatStructure(this) diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt new file mode 100644 index 000000000..798f81c35 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt @@ -0,0 +1,4 @@ +package kscience.kmath.nd4j + +internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() } +internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() } diff --git a/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt new file mode 100644 index 000000000..650d5670c --- /dev/null +++ b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt @@ -0,0 +1,42 @@ +package kscience.kmath.nd4j + +import org.nd4j.linalg.factory.Nd4j +import kscience.kmath.operations.invoke +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.fail + +internal class Nd4jArrayAlgebraTest { + @Test + fun testProduce() { + val res = (RealNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } + val expected = (Nd4j.create(2, 2) ?: fail()).asRealStructure() + expected[intArrayOf(0, 0)] = 0.0 + expected[intArrayOf(0, 1)] = 1.0 + expected[intArrayOf(1, 0)] = 1.0 + expected[intArrayOf(1, 1)] = 2.0 + assertEquals(expected, res) + } + + @Test + fun testMap() { + val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } } + val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() + expected[intArrayOf(0, 0)] = 3 + expected[intArrayOf(0, 1)] = 3 + expected[intArrayOf(1, 0)] = 3 + expected[intArrayOf(1, 1)] = 3 + assertEquals(expected, res) + } + + @Test + fun testAdd() { + val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 } + val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() + expected[intArrayOf(0, 0)] = 26 + expected[intArrayOf(0, 1)] = 26 + expected[intArrayOf(1, 0)] = 26 + expected[intArrayOf(1, 1)] = 26 + assertEquals(expected, res) + } +} diff --git a/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt new file mode 100644 index 000000000..7e46211c1 --- /dev/null +++ b/kmath-nd4j/src/test/kotlin/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt @@ -0,0 +1,72 @@ +package kscience.kmath.nd4j + +import kscience.kmath.structures.get +import org.nd4j.linalg.factory.Nd4j +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.fail + +internal class Nd4jArrayStructureTest { + @Test + fun testElements() { + val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! + val struct = nd.asRealStructure() + val res = struct.elements().map(Pair::second).toList() + assertEquals(listOf(1.0, 2.0, 3.0), res) + } + + @Test + fun testShape() { + val nd = Nd4j.rand(10, 2, 3, 6) ?: fail() + val struct = nd.asRealStructure() + assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList()) + } + + @Test + fun testEquals() { + val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail() + val struct1 = nd1.asRealStructure() + assertEquals(struct1, struct1) + assertNotEquals(struct1 as Any?, null) + val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail() + val struct2 = nd2.asRealStructure() + assertEquals(struct1, struct2) + assertEquals(struct2, struct1) + val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail() + val struct3 = nd3.asRealStructure() + assertEquals(struct2, struct3) + assertEquals(struct1, struct3) + } + + @Test + fun testHashCode() { + val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail() + val struct1 = nd1.asRealStructure() + val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail() + val struct2 = nd2.asRealStructure() + assertEquals(struct1.hashCode(), struct2.hashCode()) + } + + @Test + fun testDimension() { + val nd = Nd4j.rand(8, 16, 3, 7, 1)!! + val struct = nd.asFloatStructure() + assertEquals(5, struct.dimension) + } + + @Test + fun testGet() { + val nd = Nd4j.rand(10, 2, 3, 6)?:fail() + val struct = nd.asIntStructure() + assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0]) + } + + @Test + fun testSet() { + val nd = Nd4j.rand(17, 12, 4, 8)!! + val struct = nd.asLongStructure() + struct[intArrayOf(1, 2, 3, 4)] = 777 + assertEquals(777, struct[1, 2, 3, 4]) + } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt deleted file mode 100644 index ea526c058..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt +++ /dev/null @@ -1,47 +0,0 @@ -package scientifik.kmath.prob - -import scientifik.kmath.chains.Chain -import scientifik.kmath.chains.SimpleChain - -/** - * A multivariate distribution which takes a map of parameters - */ -interface NamedDistribution : Distribution> - -/** - * A multivariate distribution that has independent distributions for separate axis - */ -class FactorizedDistribution(val distributions: Collection>) : NamedDistribution { - - override fun probability(arg: Map): Double { - return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) } - } - - override fun sample(generator: RandomGenerator): Chain> { - val chains = distributions.map { it.sample(generator) } - return SimpleChain> { - chains.fold(emptyMap()) { acc, chain -> acc + chain.next() } - } - } -} - -class NamedDistributionWrapper(val name: String, val distribution: Distribution) : NamedDistribution { - override fun probability(arg: Map): Double = distribution.probability( - arg[name] ?: error("Argument with name $name not found in input parameters") - ) - - override fun sample(generator: RandomGenerator): Chain> { - val chain = distribution.sample(generator) - return SimpleChain { - mapOf(name to chain.next()) - } - } -} - -class DistributionBuilder{ - private val distributions = ArrayList>() - - infix fun String.to(distribution: Distribution){ - distributions.add(NamedDistributionWrapper(this,distribution)) - } -} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt deleted file mode 100644 index 49163c701..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt +++ /dev/null @@ -1,14 +0,0 @@ -package scientifik.kmath.prob - -import scientifik.kmath.chains.Chain - -/** - * A possibly stateful chain producing random values. - */ -class RandomChain(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain { - override suspend fun next(): R = generator.gen() - - override fun fork(): Chain = RandomChain(generator.fork(), gen) -} - -fun RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, gen) diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt deleted file mode 100644 index 2a225fe47..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.prob - -import kotlin.random.Random - -/** - * A basic generator - */ -interface RandomGenerator { - fun nextBoolean(): Boolean - - fun nextDouble(): Double - fun nextInt(): Int - fun nextInt(until: Int): Int - fun nextLong(): Long - fun nextLong(until: Long): Long - - fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size) - fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) } - - /** - * Create a new generator which is independent from current generator (operations on new generator do not affect this one - * and vise versa). The statistical properties of new generator should be the same as for this one. - * For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run. - * - * The thread safety of this operation is not guaranteed since it could affect the state of the generator. - */ - fun fork(): RandomGenerator - - companion object { - val default by lazy { DefaultGenerator() } - - fun default(seed: Long) = DefaultGenerator(Random(seed)) - } -} - -inline class DefaultGenerator(val random: Random = Random) : RandomGenerator { - override fun nextBoolean(): Boolean = random.nextBoolean() - - override fun nextDouble(): Double = random.nextDouble() - - override fun nextInt(): Int = random.nextInt() - override fun nextInt(until: Int): Int = random.nextInt(until) - - override fun nextLong(): Long = random.nextLong() - - override fun nextLong(until: Long): Long = random.nextLong(until) - - override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { - random.nextBytes(array, fromIndex, toIndex) - } - - override fun nextBytes(size: Int): ByteArray = random.nextBytes(size) - - override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong()) -} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt deleted file mode 100644 index 02f98439e..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt +++ /dev/null @@ -1,32 +0,0 @@ -package scientifik.kmath.prob - -import scientifik.kmath.chains.Chain -import scientifik.kmath.chains.ConstantChain -import scientifik.kmath.chains.map -import scientifik.kmath.chains.zip -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke - -class BasicSampler(val chainBuilder: (RandomGenerator) -> Chain) : Sampler { - override fun sample(generator: RandomGenerator): Chain = chainBuilder(generator) -} - -class ConstantSampler(val value: T) : Sampler { - override fun sample(generator: RandomGenerator): Chain = ConstantChain(value) -} - -/** - * A space for samplers. Allows to perform simple operations on distributions - */ -class SamplerSpace(val space: Space) : Space> { - - override val zero: Sampler = ConstantSampler(space.zero) - - override fun add(a: Sampler, b: Sampler): Sampler = BasicSampler { generator -> - a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } } - } - - override fun multiply(a: Sampler, k: Number): Sampler = BasicSampler { generator -> - a.sample(generator).map { space { it * k.toDouble() } } - } -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/UniformDistribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/UniformDistribution.kt deleted file mode 100644 index 9d96bff59..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/UniformDistribution.kt +++ /dev/null @@ -1,34 +0,0 @@ -package scientifik.kmath.prob - -import scientifik.kmath.chains.Chain -import scientifik.kmath.chains.SimpleChain - -class UniformDistribution(val range: ClosedFloatingPointRange) : UnivariateDistribution { - - private val length = range.endInclusive - range.start - - override fun probability(arg: Double): Double { - return if (arg in range) { - return 1.0 / length - } else { - 0.0 - } - } - - override fun sample(generator: RandomGenerator): Chain { - return SimpleChain { - range.start + generator.nextDouble() * length - } - } - - override fun cumulative(arg: Double): Double { - return when { - arg < range.start -> 0.0 - arg >= range.endInclusive -> 1.0 - else -> (arg - range.start) / length - } - } -} - -fun Distribution.Companion.uniform(range: ClosedFloatingPointRange): UniformDistribution = - UniformDistribution(range) \ No newline at end of file diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt deleted file mode 100644 index f5a73a08b..000000000 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt +++ /dev/null @@ -1,67 +0,0 @@ -package scientifik.kmath.prob - -import org.apache.commons.rng.UniformRandomProvider -import org.apache.commons.rng.simple.RandomSource - -class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator { - internal val random: UniformRandomProvider = seed?.let { - RandomSource.create(source, seed) - } ?: RandomSource.create(source) - - override fun nextBoolean(): Boolean = random.nextBoolean() - - override fun nextDouble(): Double = random.nextDouble() - - override fun nextInt(): Int = random.nextInt() - override fun nextInt(until: Int): Int = random.nextInt(until) - - override fun nextLong(): Long = random.nextLong() - override fun nextLong(until: Long): Long = random.nextLong(until) - - override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { - require(toIndex > fromIndex) - random.nextBytes(array, fromIndex, toIndex - fromIndex) - } - - override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) -} - -inline class RandomGeneratorProvider(val generator: RandomGenerator) : UniformRandomProvider { - override fun nextBoolean(): Boolean = generator.nextBoolean() - - override fun nextFloat(): Float = generator.nextDouble().toFloat() - - override fun nextBytes(bytes: ByteArray) { - generator.fillBytes(bytes) - } - - override fun nextBytes(bytes: ByteArray, start: Int, len: Int) { - generator.fillBytes(bytes, start, start + len) - } - - override fun nextInt(): Int = generator.nextInt() - - override fun nextInt(n: Int): Int = generator.nextInt(n) - - override fun nextDouble(): Double = generator.nextDouble() - - override fun nextLong(): Long = generator.nextLong() - - override fun nextLong(n: Long): Long = generator.nextLong(n) -} - -/** - * Represent this [RandomGenerator] as commons-rng [UniformRandomProvider] preserving and mirroring its current state. - * Getting new value from one of those changes the state of another. - */ -fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this is RandomSourceGenerator) { - random -} else { - RandomGeneratorProvider(this) -} - -fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator = - RandomSourceGenerator(source, seed) - -fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator = - fromSource(RandomSource.MT, seed) diff --git a/kmath-prob/build.gradle.kts b/kmath-stat/build.gradle.kts similarity index 80% rename from kmath-prob/build.gradle.kts rename to kmath-stat/build.gradle.kts index a69d61b73..186aff944 100644 --- a/kmath-prob/build.gradle.kts +++ b/kmath-stat/build.gradle.kts @@ -1,5 +1,5 @@ plugins { - id("scientifik.mpp") + id("ru.mipt.npm.mpp") } kotlin.sourceSets { @@ -8,10 +8,11 @@ kotlin.sourceSets { api(project(":kmath-coroutines")) } } - jvmMain{ - dependencies{ + + jvmMain { + dependencies { api("org.apache.commons:commons-rng-sampling:1.3") api("org.apache.commons:commons-rng-simple:1.3") } } -} \ No newline at end of file +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Distribution.kt similarity index 60% rename from kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Distribution.kt index 3b874adaa..c4ceb29eb 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Distribution.kt @@ -1,23 +1,23 @@ -package scientifik.kmath.prob +package kscience.kmath.stat -import scientifik.kmath.chains.Chain -import scientifik.kmath.chains.collect -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.BufferFactory +import kscience.kmath.chains.Chain +import kscience.kmath.chains.collect +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.BufferFactory -interface Sampler { - fun sample(generator: RandomGenerator): Chain +public interface Sampler { + public fun sample(generator: RandomGenerator): Chain } /** * A distribution of typed objects */ -interface Distribution : Sampler { +public interface Distribution : Sampler { /** * A probability value for given argument [arg]. * For continuous distributions returns PDF */ - fun probability(arg: T): Double + public fun probability(arg: T): Double /** * Create a chain of samples from this distribution. @@ -28,20 +28,20 @@ interface Distribution : Sampler { /** * An empty companion. Distribution factories should be written as its extensions */ - companion object + public companion object } -interface UnivariateDistribution> : Distribution { +public interface UnivariateDistribution> : Distribution { /** * Cumulative distribution for ordered parameter (CDF) */ - fun cumulative(arg: T): Double + public fun cumulative(arg: T): Double } /** * Compute probability integral in an interval */ -fun > UnivariateDistribution.integral(from: T, to: T): Double { +public fun > UnivariateDistribution.integral(from: T, to: T): Double { require(to > from) return cumulative(to) - cumulative(from) } @@ -49,7 +49,7 @@ fun > UnivariateDistribution.integral(from: T, to: T): Doub /** * Sample a bunch of values */ -fun Sampler.sampleBuffer( +public fun Sampler.sampleBuffer( generator: RandomGenerator, size: Int, bufferFactory: BufferFactory = Buffer.Companion::boxing @@ -57,11 +57,12 @@ fun Sampler.sampleBuffer( require(size > 1) //creating temporary storage once val tmp = ArrayList(size) + return sample(generator).collect { chain -> //clear list from previous run tmp.clear() //Fill list - repeat(size){ + repeat(size) { tmp.add(chain.next()) } //return new buffer with elements from tmp @@ -72,5 +73,5 @@ fun Sampler.sampleBuffer( /** * Generate a bunch of samples from real distributions */ -fun Sampler.sampleBuffer(generator: RandomGenerator, size: Int) = - sampleBuffer(generator, size, Buffer.Companion::real) \ No newline at end of file +public fun Sampler.sampleBuffer(generator: RandomGenerator, size: Int): Chain> = + sampleBuffer(generator, size, Buffer.Companion::real) diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/FactorizedDistribution.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/FactorizedDistribution.kt new file mode 100644 index 000000000..1ed9deba9 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/FactorizedDistribution.kt @@ -0,0 +1,43 @@ +package kscience.kmath.stat + +import kscience.kmath.chains.Chain +import kscience.kmath.chains.SimpleChain + +/** + * A multivariate distribution which takes a map of parameters + */ +public interface NamedDistribution : Distribution> + +/** + * A multivariate distribution that has independent distributions for separate axis + */ +public class FactorizedDistribution(public val distributions: Collection>) : + NamedDistribution { + override fun probability(arg: Map): Double = + distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) } + + override fun sample(generator: RandomGenerator): Chain> { + val chains = distributions.map { it.sample(generator) } + return SimpleChain { chains.fold(emptyMap()) { acc, chain -> acc + chain.next() } } + } +} + +public class NamedDistributionWrapper(public val name: String, public val distribution: Distribution) : + NamedDistribution { + override fun probability(arg: Map): Double = distribution.probability( + arg[name] ?: error("Argument with name $name not found in input parameters") + ) + + override fun sample(generator: RandomGenerator): Chain> { + val chain = distribution.sample(generator) + return SimpleChain { mapOf(name to chain.next()) } + } +} + +public class DistributionBuilder { + private val distributions = ArrayList>() + + public infix fun String.to(distribution: Distribution) { + distributions.add(NamedDistributionWrapper(this, distribution)) + } +} diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt new file mode 100644 index 000000000..9d4655df2 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Fitting.kt @@ -0,0 +1,63 @@ +package kscience.kmath.stat + +import kscience.kmath.expressions.* +import kscience.kmath.operations.ExtendedField +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.indices +import kotlin.math.pow + +public object Fitting { + + /** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation + */ + public fun chiSquared( + autoDiff: AutoDiffProcessor>, + x: Buffer, + y: Buffer, + yErr: Buffer, + model: A.(I) -> I, + ): DifferentiableExpression> where A : ExtendedField, A : ExpressionAlgebra { + require(x.size == y.size) { "X and y buffers should be of the same size" } + require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + + return autoDiff.process { + var sum = zero + + x.indices.forEach { + val xValue = const(x[it]) + val yValue = const(y[it]) + val yErrValue = const(yErr[it]) + val modelValue = model(xValue) + sum += ((yValue - modelValue) / yErrValue).pow(2) + } + + sum + } + } + + /** + * Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives + */ + public fun chiSquared( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: Expression, + xSymbol: Symbol = StringSymbol("x"), + ): Expression { + require(x.size == y.size) { "X and y buffers should be of the same size" } + require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + + return Expression { arguments -> + x.indices.sumByDouble { + val xValue = x[it] + val yValue = y[it] + val yErrValue = yErr[it] + val modifiedArgs = arguments + (xSymbol to xValue) + val modelValue = model(modifiedArgs) + ((yValue - modelValue) / yErrValue).pow(2) + } + } + } +} diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/MCScope.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/MCScope.kt new file mode 100644 index 000000000..5dc567db8 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/MCScope.kt @@ -0,0 +1,58 @@ +package kscience.kmath.stat + +import kotlinx.coroutines.* +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.coroutines.coroutineContext + +/** + * A scope for a Monte-Carlo computations or multi-coroutine random number generation. + * The scope preserves the order of random generator calls as long as all concurrency calls is done via [launch] and [async] + * functions. + */ +public class MCScope( + public val coroutineContext: CoroutineContext, + public val random: RandomGenerator, +) + +/** + * Launches a supervised Monte-Carlo scope + */ +public suspend inline fun mcScope(generator: RandomGenerator, block: MCScope.() -> T): T = + MCScope(coroutineContext, generator).block() + +/** + * Launch mc scope with a given seed + */ +public suspend inline fun mcScope(seed: Long, block: MCScope.() -> T): T = + mcScope(RandomGenerator.default(seed), block) + +/** + * Specialized launch for [MCScope]. Behaves the same way as regular [CoroutineScope.launch], but also stores the generator fork. + * The method itself is not thread safe. + */ +public inline fun MCScope.launch( + context: CoroutineContext = EmptyCoroutineContext, + start: CoroutineStart = CoroutineStart.DEFAULT, + crossinline block: suspend MCScope.() -> Unit, +): Job { + val newRandom = random.fork() + return CoroutineScope(coroutineContext).launch(context, start) { + MCScope(coroutineContext, newRandom).block() + } +} + +/** + * Specialized async for [MCScope]. Behaves the same way as regular [CoroutineScope.async], but also stores the generator fork. + * The method itself is not thread safe. + */ +public inline fun MCScope.async( + context: CoroutineContext = EmptyCoroutineContext, + start: CoroutineStart = CoroutineStart.DEFAULT, + crossinline block: suspend MCScope.() -> T, +): Deferred { + val newRandom = random.fork() + return CoroutineScope(coroutineContext).async(context, start) { + MCScope(coroutineContext, newRandom).block() + } +} \ No newline at end of file diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt new file mode 100644 index 000000000..0f3cd9dd9 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/OptimizationProblem.kt @@ -0,0 +1,88 @@ +package kscience.kmath.stat + +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression +import kscience.kmath.expressions.Symbol + +public interface OptimizationFeature + +public class OptimizationResult( + public val point: Map, + public val value: T, + public val features: Set = emptySet(), +) { + override fun toString(): String { + return "OptimizationResult(point=$point, value=$value)" + } +} + +public operator fun OptimizationResult.plus( + feature: OptimizationFeature, +): OptimizationResult = OptimizationResult(point, value, features + feature) + +/** + * A configuration builder for optimization problem + */ +public interface OptimizationProblem { + /** + * Define the initial guess for the optimization problem + */ + public fun initialGuess(map: Map) + + /** + * Set an objective function expression + */ + public fun expression(expression: Expression) + + /** + * Set a differentiable expression as objective function as function and gradient provider + */ + public fun diffExpression(expression: DifferentiableExpression>) + + /** + * Update the problem from previous optimization run + */ + public fun update(result: OptimizationResult) + + /** + * Make an optimization run + */ + public fun optimize(): OptimizationResult +} + +public fun interface OptimizationProblemFactory> { + public fun build(symbols: List): P +} + +public operator fun > OptimizationProblemFactory.invoke( + symbols: List, + block: P.() -> Unit, +): P = build(symbols).apply(block) + +/** + * Optimize expression without derivatives using specific [OptimizationProblemFactory] + */ +public fun > Expression.optimizeWith( + factory: OptimizationProblemFactory, + vararg symbols: Symbol, + configuration: F.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = factory(symbols.toList(),configuration) + problem.expression(this) + return problem.optimize() +} + +/** + * Optimize differentiable expression using specific [OptimizationProblemFactory] + */ +public fun > DifferentiableExpression>.optimizeWith( + factory: OptimizationProblemFactory, + vararg symbols: Symbol, + configuration: F.() -> Unit, +): OptimizationResult { + require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } + val problem = factory(symbols.toList(), configuration) + problem.diffExpression(this) + return problem.optimize() +} diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomChain.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomChain.kt new file mode 100644 index 000000000..0f10851b9 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomChain.kt @@ -0,0 +1,17 @@ +package kscience.kmath.stat + +import kscience.kmath.chains.Chain + +/** + * A possibly stateful chain producing random values. + */ +public class RandomChain( + public val generator: RandomGenerator, + private val gen: suspend RandomGenerator.() -> R +) : Chain { + override suspend fun next(): R = generator.gen() + + override fun fork(): Chain = RandomChain(generator.fork(), gen) +} + +public fun RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, gen) diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomGenerator.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomGenerator.kt new file mode 100644 index 000000000..4486ae016 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/RandomGenerator.kt @@ -0,0 +1,100 @@ +package kscience.kmath.stat + +import kotlin.random.Random + +/** + * An interface that is implemented by random number generator algorithms. + */ +public interface RandomGenerator { + /** + * Gets the next random [Boolean] value. + */ + public fun nextBoolean(): Boolean + + /** + * Gets the next random [Double] value uniformly distributed between 0 (inclusive) and 1 (exclusive). + */ + public fun nextDouble(): Double + + /** + * Gets the next random `Int` from the random number generator. + * + * Generates an `Int` random value uniformly distributed between [Int.MIN_VALUE] and [Int.MAX_VALUE] (inclusive). + */ + public fun nextInt(): Int + + /** + * Gets the next random non-negative `Int` from the random number generator less than the specified [until] bound. + * + * Generates an `Int` random value uniformly distributed between `0` (inclusive) and the specified [until] bound + * (exclusive). + */ + public fun nextInt(until: Int): Int + + /** + * Gets the next random `Long` from the random number generator. + * + * Generates a `Long` random value uniformly distributed between [Long.MIN_VALUE] and [Long.MAX_VALUE] (inclusive). + */ + public fun nextLong(): Long + + /** + * Gets the next random non-negative `Long` from the random number generator less than the specified [until] bound. + * + * Generates a `Long` random value uniformly distributed between `0` (inclusive) and the specified [until] bound (exclusive). + */ + public fun nextLong(until: Long): Long + + /** + * Fills a subrange of the specified byte [array] starting from [fromIndex] inclusive and ending [toIndex] exclusive + * with random bytes. + * + * @return [array] with the subrange filled with random bytes. + */ + public fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size) + + /** + * Creates a byte array of the specified [size], filled with random bytes. + */ + public fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) } + + /** + * Create a new generator which is independent from current generator (operations on new generator do not affect this one + * and vise versa). The statistical properties of new generator should be the same as for this one. + * For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run. + * + * The thread safety of this operation is not guaranteed since it could affect the state of the generator. + */ + public fun fork(): RandomGenerator + + public companion object { + /** + * The [DefaultGenerator] instance. + */ + public val default: DefaultGenerator by lazy(::DefaultGenerator) + + /** + * Returns [DefaultGenerator] of given [seed]. + */ + public fun default(seed: Long): DefaultGenerator = DefaultGenerator(Random(seed)) + } +} + +/** + * Implements [RandomGenerator] by delegating all operations to [Random]. + */ +public inline class DefaultGenerator(public val random: Random = Random) : RandomGenerator { + public override fun nextBoolean(): Boolean = random.nextBoolean() + public override fun nextDouble(): Double = random.nextDouble() + public override fun nextInt(): Int = random.nextInt() + public override fun nextInt(until: Int): Int = random.nextInt(until) + public override fun nextLong(): Long = random.nextLong() + public override fun nextLong(until: Long): Long = random.nextLong(until) + + public override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { + random.nextBytes(array, fromIndex, toIndex) + } + + public override fun nextBytes(size: Int): ByteArray = random.nextBytes(size) + public override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong()) +} diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/SamplerAlgebra.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/SamplerAlgebra.kt new file mode 100644 index 000000000..f416028a5 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/SamplerAlgebra.kt @@ -0,0 +1,31 @@ +package kscience.kmath.stat + +import kscience.kmath.chains.Chain +import kscience.kmath.chains.ConstantChain +import kscience.kmath.chains.map +import kscience.kmath.chains.zip +import kscience.kmath.operations.Space +import kscience.kmath.operations.invoke + +public class BasicSampler(public val chainBuilder: (RandomGenerator) -> Chain) : Sampler { + public override fun sample(generator: RandomGenerator): Chain = chainBuilder(generator) +} + +public class ConstantSampler(public val value: T) : Sampler { + public override fun sample(generator: RandomGenerator): Chain = ConstantChain(value) +} + +/** + * A space for samplers. Allows to perform simple operations on distributions + */ +public class SamplerSpace(public val space: Space) : Space> { + public override val zero: Sampler = ConstantSampler(space.zero) + + public override fun add(a: Sampler, b: Sampler): Sampler = BasicSampler { generator -> + a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } } + } + + public override fun multiply(a: Sampler, k: Number): Sampler = BasicSampler { generator -> + a.sample(generator).map { space { it * k.toDouble() } } + } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Statistic.kt similarity index 52% rename from kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt rename to kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Statistic.kt index c82d262bf..a4624fc21 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Statistic.kt +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/Statistic.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers @@ -6,18 +6,18 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map -import kotlinx.coroutines.flow.scanReduce -import scientifik.kmath.coroutines.mapParallel -import scientifik.kmath.operations.* -import scientifik.kmath.structures.Buffer -import scientifik.kmath.structures.asIterable -import scientifik.kmath.structures.asSequence +import kotlinx.coroutines.flow.runningReduce +import kscience.kmath.coroutines.mapParallel +import kscience.kmath.operations.* +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.asIterable +import kscience.kmath.structures.asSequence /** * A function, that transforms a buffer of random quantities to some resulting value */ -interface Statistic { - suspend operator fun invoke(data: Buffer): R +public interface Statistic { + public suspend operator fun invoke(data: Buffer): R } /** @@ -26,17 +26,17 @@ interface Statistic { * @param I - intermediate block type * @param R - result type */ -interface ComposableStatistic : Statistic { +public interface ComposableStatistic : Statistic { //compute statistic on a single block - suspend fun computeIntermediate(data: Buffer): I + public suspend fun computeIntermediate(data: Buffer): I //Compose two blocks - suspend fun composeIntermediate(first: I, second: I): I + public suspend fun composeIntermediate(first: I, second: I): I //Transform block to result - suspend fun toResult(intermediate: I): R + public suspend fun toResult(intermediate: I): R - override suspend fun invoke(data: Buffer): R = toResult(computeIntermediate(data)) + public override suspend fun invoke(data: Buffer): R = toResult(computeIntermediate(data)) } @FlowPreview @@ -46,7 +46,7 @@ private fun ComposableStatistic.flowIntermediate( dispatcher: CoroutineDispatcher = Dispatchers.Default ): Flow = flow .mapParallel(dispatcher) { computeIntermediate(it) } - .scanReduce(::composeIntermediate) + .runningReduce(::composeIntermediate) /** @@ -57,7 +57,7 @@ private fun ComposableStatistic.flowIntermediate( */ @FlowPreview @ExperimentalCoroutinesApi -fun ComposableStatistic.flow( +public fun ComposableStatistic.flow( flow: Flow>, dispatcher: CoroutineDispatcher = Dispatchers.Default ): Flow = flowIntermediate(flow, dispatcher).map(::toResult) @@ -65,32 +65,32 @@ fun ComposableStatistic.flow( /** * Arithmetic mean */ -class Mean(val space: Space) : ComposableStatistic, T> { - override suspend fun computeIntermediate(data: Buffer): Pair = +public class Mean(public val space: Space) : ComposableStatistic, T> { + public override suspend fun computeIntermediate(data: Buffer): Pair = space { sum(data.asIterable()) } to data.size - override suspend fun composeIntermediate(first: Pair, second: Pair): Pair = + public override suspend fun composeIntermediate(first: Pair, second: Pair): Pair = space { first.first + second.first } to (first.second + second.second) - override suspend fun toResult(intermediate: Pair): T = + public override suspend fun toResult(intermediate: Pair): T = space { intermediate.first / intermediate.second } - companion object { + public companion object { //TODO replace with optimized version which respects overflow - val real: Mean = Mean(RealField) - val int: Mean = Mean(IntRing) - val long: Mean = Mean(LongRing) + public val real: Mean = Mean(RealField) + public val int: Mean = Mean(IntRing) + public val long: Mean = Mean(LongRing) } } /** * Non-composable median */ -class Median(private val comparator: Comparator) : Statistic { - override suspend fun invoke(data: Buffer): T = +public class Median(private val comparator: Comparator) : Statistic { + public override suspend fun invoke(data: Buffer): T = data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct - companion object { - val real: Median = Median(Comparator { a: Double, b: Double -> a.compareTo(b) }) + public companion object { + public val real: Median = Median { a: Double, b: Double -> a.compareTo(b) } } } diff --git a/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/UniformDistribution.kt b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/UniformDistribution.kt new file mode 100644 index 000000000..1ba5c96f1 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/kscience/kmath/stat/UniformDistribution.kt @@ -0,0 +1,22 @@ +package kscience.kmath.stat + +import kscience.kmath.chains.Chain +import kscience.kmath.chains.SimpleChain + +public class UniformDistribution(public val range: ClosedFloatingPointRange) : UnivariateDistribution { + private val length: Double = range.endInclusive - range.start + + override fun probability(arg: Double): Double = if (arg in range) 1.0 / length else 0.0 + + override fun sample(generator: RandomGenerator): Chain = + SimpleChain { range.start + generator.nextDouble() * length } + + override fun cumulative(arg: Double): Double = when { + arg < range.start -> 0.0 + arg >= range.endInclusive -> 1.0 + else -> (arg - range.start) / length + } +} + +public fun Distribution.Companion.uniform(range: ClosedFloatingPointRange): UniformDistribution = + UniformDistribution(range) diff --git a/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/RandomSourceGenerator.kt b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/RandomSourceGenerator.kt new file mode 100644 index 000000000..5cba28a95 --- /dev/null +++ b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/RandomSourceGenerator.kt @@ -0,0 +1,58 @@ +package kscience.kmath.stat + +import org.apache.commons.rng.UniformRandomProvider +import org.apache.commons.rng.simple.RandomSource + +public class RandomSourceGenerator(public val source: RandomSource, seed: Long?) : RandomGenerator { + internal val random: UniformRandomProvider = seed?.let { + RandomSource.create(source, seed) + } ?: RandomSource.create(source) + + public override fun nextBoolean(): Boolean = random.nextBoolean() + public override fun nextDouble(): Double = random.nextDouble() + public override fun nextInt(): Int = random.nextInt() + public override fun nextInt(until: Int): Int = random.nextInt(until) + public override fun nextLong(): Long = random.nextLong() + public override fun nextLong(until: Long): Long = random.nextLong(until) + + public override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { + require(toIndex > fromIndex) + random.nextBytes(array, fromIndex, toIndex - fromIndex) + } + + public override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) +} + +public inline class RandomGeneratorProvider(public val generator: RandomGenerator) : UniformRandomProvider { + public override fun nextBoolean(): Boolean = generator.nextBoolean() + public override fun nextFloat(): Float = generator.nextDouble().toFloat() + + public override fun nextBytes(bytes: ByteArray) { + generator.fillBytes(bytes) + } + + public override fun nextBytes(bytes: ByteArray, start: Int, len: Int) { + generator.fillBytes(bytes, start, start + len) + } + + public override fun nextInt(): Int = generator.nextInt() + public override fun nextInt(n: Int): Int = generator.nextInt(n) + public override fun nextDouble(): Double = generator.nextDouble() + public override fun nextLong(): Long = generator.nextLong() + public override fun nextLong(n: Long): Long = generator.nextLong(n) +} + +/** + * Represent this [RandomGenerator] as commons-rng [UniformRandomProvider] preserving and mirroring its current state. + * Getting new value from one of those changes the state of another. + */ +public fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this is RandomSourceGenerator) + random +else + RandomGeneratorProvider(this) + +public fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator = + RandomSourceGenerator(source, seed) + +public fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator = + fromSource(RandomSource.MT, seed) diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/distributions.kt similarity index 53% rename from kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt rename to kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/distributions.kt index 412454994..6cc18a37c 100644 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt +++ b/kmath-stat/src/jvmMain/kotlin/kscience/kmath/stat/distributions.kt @@ -1,47 +1,42 @@ -package scientifik.kmath.prob +package kscience.kmath.stat +import kscience.kmath.chains.BlockingIntChain +import kscience.kmath.chains.BlockingRealChain +import kscience.kmath.chains.Chain import org.apache.commons.rng.UniformRandomProvider import org.apache.commons.rng.sampling.distribution.* -import scientifik.kmath.chains.BlockingIntChain -import scientifik.kmath.chains.BlockingRealChain -import scientifik.kmath.chains.Chain -import java.util.* import kotlin.math.PI import kotlin.math.exp import kotlin.math.pow import kotlin.math.sqrt -abstract class ContinuousSamplerDistribution : Distribution { - +public abstract class ContinuousSamplerDistribution : Distribution { private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() { private val sampler = buildCMSampler(generator) override fun nextDouble(): Double = sampler.sample() - override fun fork(): Chain = ContinuousSamplerChain(generator.fork()) } protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler - override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator) + public override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator) } -abstract class DiscreteSamplerDistribution : Distribution { - +public abstract class DiscreteSamplerDistribution : Distribution { private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() { private val sampler = buildSampler(generator) override fun nextInt(): Int = sampler.sample() - override fun fork(): Chain = ContinuousSamplerChain(generator.fork()) } protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler - override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator) + public override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator) } -enum class NormalSamplerMethod { +public enum class NormalSamplerMethod { BoxMuller, Marsaglia, Ziggurat @@ -54,20 +49,21 @@ private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomPr NormalSamplerMethod.Ziggurat -> ZigguratNormalizedGaussianSampler(provider) } -fun Distribution.Companion.normal( +public fun Distribution.Companion.normal( method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat -): Distribution = object : ContinuousSamplerDistribution() { +): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() { override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { - val provider: UniformRandomProvider = generator.asUniformRandomProvider() + val provider = generator.asUniformRandomProvider() return normalSampler(method, provider) } - override fun probability(arg: Double): Double { - return exp(-arg.pow(2) / 2) / sqrt(PI * 2) - } + override fun probability(arg: Double): Double = exp(-arg.pow(2) / 2) / sqrt(PI * 2) } -fun Distribution.Companion.normal( +/** + * A univariate normal distribution with given [mean] and [sigma]. [method] defines commons-rng generation method + */ +public fun Distribution.Companion.normal( mean: Double, sigma: Double, method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat @@ -76,34 +72,27 @@ fun Distribution.Companion.normal( private val norm = sigma * sqrt(PI * 2) override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { - val provider: UniformRandomProvider = generator.asUniformRandomProvider() + val provider = generator.asUniformRandomProvider() val normalizedSampler = normalSampler(method, provider) return GaussianSampler(normalizedSampler, mean, sigma) } - override fun probability(arg: Double): Double { - return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm - } + override fun probability(arg: Double): Double = exp(-(arg - mean).pow(2) / 2 / sigma2) / norm } -fun Distribution.Companion.poisson( - lambda: Double -): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() { +public fun Distribution.Companion.poisson(lambda: Double): DiscreteSamplerDistribution = + object : DiscreteSamplerDistribution() { + private val computedProb: MutableMap = hashMapOf(0 to exp(-lambda)) - override fun buildSampler(generator: RandomGenerator): DiscreteSampler { - return PoissonSampler.of(generator.asUniformRandomProvider(), lambda) - } + override fun buildSampler(generator: RandomGenerator): DiscreteSampler = + PoissonSampler.of(generator.asUniformRandomProvider(), lambda) - private val computedProb: HashMap = hashMapOf(0 to exp(-lambda)) + override fun probability(arg: Int): Double { + require(arg >= 0) { "The argument must be >= 0" } - override fun probability(arg: Int): Double { - require(arg >= 0) { "The argument must be >= 0" } - return if (arg > 40) { - exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda) - } else { - computedProb.getOrPut(arg) { - probability(arg - 1) * lambda / arg - } + return if (arg > 40) + exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda) + else + computedProb.getOrPut(arg) { probability(arg - 1) * lambda / arg } } } -} diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/CommonsDistributionsTest.kt similarity index 91% rename from kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt rename to kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/CommonsDistributionsTest.kt index 7638c695e..fe58fac08 100644 --- a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/CommonsDistributionsTest.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.flow.take import kotlinx.coroutines.flow.toList @@ -6,7 +6,7 @@ import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test -class CommonsDistributionsTest { +internal class CommonsDistributionsTest { @Test fun testNormalDistributionSuspend() { val distribution = Distribution.normal(7.0, 2.0) @@ -24,5 +24,4 @@ class CommonsDistributionsTest { val sample = distribution.sample(generator).nextBlock(1000) Assertions.assertEquals(7.0, sample.average(), 0.1) } - -} \ No newline at end of file +} diff --git a/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt new file mode 100644 index 000000000..4e29e6105 --- /dev/null +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt @@ -0,0 +1,86 @@ +package kscience.kmath.stat + +import kotlinx.coroutines.* +import java.util.* +import kotlin.collections.HashSet +import kotlin.test.Test +import kotlin.test.assertEquals + +data class RandomResult(val branch: String, val order: Int, val value: Int) + +typealias ATest = suspend CoroutineScope.() -> Set + +class MCScopeTest { + val simpleTest: ATest = { + mcScope(1111) { + val res = Collections.synchronizedSet(HashSet()) + + launch { + //println(random) + repeat(10) { + delay(10) + res.add(RandomResult("first", it, random.nextInt())) + } + launch { + //empty fork + } + } + + launch { + //println(random) + repeat(10) { + delay(10) + res.add(RandomResult("second", it, random.nextInt())) + } + } + + + res + } + } + + val testWithJoin: ATest = { + mcScope(1111) { + val res = Collections.synchronizedSet(HashSet()) + + val job = launch { + repeat(10) { + delay(10) + res.add(RandomResult("first", it, random.nextInt())) + } + } + launch { + repeat(10) { + delay(10) + if (it == 4) job.join() + res.add(RandomResult("second", it, random.nextInt())) + } + } + + res + } + } + + + @OptIn(ObsoleteCoroutinesApi::class) + fun compareResult(test: ATest) { + val res1 = runBlocking(Dispatchers.Default) { test() } + val res2 = runBlocking(newSingleThreadContext("test")) { test() } + assertEquals( + res1.find { it.branch == "first" && it.order == 7 }?.value, + res2.find { it.branch == "first" && it.order == 7 }?.value + ) + assertEquals(res1, res2) + } + + @Test + fun testParallel() { + compareResult(simpleTest) + } + + + @Test + fun testConditionalJoin() { + compareResult(testWithJoin) + } +} \ No newline at end of file diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/SamplerTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/SamplerTest.kt similarity index 84% rename from kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/SamplerTest.kt rename to kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/SamplerTest.kt index 1152f3057..afed4c5d0 100644 --- a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/SamplerTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/SamplerTest.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.runBlocking import kotlin.test.Test @@ -6,7 +6,7 @@ import kotlin.test.Test class SamplerTest { @Test - fun bufferSamplerTest(){ + fun bufferSamplerTest() { val sampler: Sampler = BasicSampler { it.chain { nextDouble() } } val data = sampler.sampleBuffer(RandomGenerator.default, 100) diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/StatisticTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/StatisticTest.kt similarity index 81% rename from kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/StatisticTest.kt rename to kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/StatisticTest.kt index 2613f71d5..5cee4d172 100644 --- a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/StatisticTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/StatisticTest.kt @@ -1,18 +1,20 @@ -package scientifik.kmath.prob +package kscience.kmath.stat import kotlinx.coroutines.flow.drop import kotlinx.coroutines.flow.first import kotlinx.coroutines.runBlocking -import scientifik.kmath.streaming.chunked +import kscience.kmath.streaming.chunked import kotlin.test.Test -class StatisticTest { +internal class StatisticTest { //create a random number generator. val generator = RandomGenerator.default(1) + //Create a stateless chain from generator. val data = generator.chain { nextDouble() } - //Convert a chaint to Flow and break it into chunks. + + //Convert a chain to Flow and break it into chunks. val chunked = data.chunked(1000) @Test @@ -22,7 +24,8 @@ class StatisticTest { .flow(chunked) //create a flow with results .drop(99) // Skip first 99 values and use one with total data .first() //get 1e5 data samples average + println(average) } } -} \ No newline at end of file +} diff --git a/kmath-viktor/build.gradle.kts b/kmath-viktor/build.gradle.kts index 52ee7c497..3e5c5912c 100644 --- a/kmath-viktor/build.gradle.kts +++ b/kmath-viktor/build.gradle.kts @@ -1,5 +1,5 @@ plugins { - id("scientifik.jvm") + id("ru.mipt.npm.jvm") } description = "Binding for https://github.com/JetBrains-Research/viktor" @@ -7,4 +7,4 @@ description = "Binding for https://github.com/JetBrains-Research/viktor" dependencies { api(project(":kmath-core")) api("org.jetbrains.bio:viktor:1.0.1") -} \ No newline at end of file +} diff --git a/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorBuffer.kt b/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorBuffer.kt new file mode 100644 index 000000000..5c9611758 --- /dev/null +++ b/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorBuffer.kt @@ -0,0 +1,19 @@ +package kscience.kmath.viktor + +import kscience.kmath.structures.MutableBuffer +import org.jetbrains.bio.viktor.F64FlatArray + +@Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE") +public inline class ViktorBuffer(public val flatArray: F64FlatArray) : MutableBuffer { + public override val size: Int + get() = flatArray.size + + public override inline fun get(index: Int): Double = flatArray[index] + + public override inline fun set(index: Int, value: Double) { + flatArray[index] = value + } + + public override fun copy(): MutableBuffer = ViktorBuffer(flatArray.copy().flatten()) + public override operator fun iterator(): Iterator = flatArray.data.iterator() +} diff --git a/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorNDStructure.kt b/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorNDStructure.kt new file mode 100644 index 000000000..2471362cb --- /dev/null +++ b/kmath-viktor/src/main/kotlin/kscience/kmath/viktor/ViktorNDStructure.kt @@ -0,0 +1,88 @@ +package kscience.kmath.viktor + +import kscience.kmath.operations.RealField +import kscience.kmath.structures.DefaultStrides +import kscience.kmath.structures.MutableNDStructure +import kscience.kmath.structures.NDField +import kscience.kmath.structures.Strides +import org.jetbrains.bio.viktor.F64Array + +@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public inline class ViktorNDStructure(public val f64Buffer: F64Array) : MutableNDStructure { + public override val shape: IntArray get() = f64Buffer.shape + + public override inline fun get(index: IntArray): Double = f64Buffer.get(*index) + + public override inline fun set(index: IntArray, value: Double) { + f64Buffer.set(*index, value = value) + } + + public override fun elements(): Sequence> = + DefaultStrides(shape).indices().map { it to get(it) } +} + +public fun F64Array.asStructure(): ViktorNDStructure = ViktorNDStructure(this) + +@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public class ViktorNDField(public override val shape: IntArray) : NDField { + public override val zero: ViktorNDStructure + get() = F64Array.full(init = 0.0, shape = shape).asStructure() + + public override val one: ViktorNDStructure + get() = F64Array.full(init = 1.0, shape = shape).asStructure() + + public val strides: Strides = DefaultStrides(shape) + + public override val elementContext: RealField get() = RealField + + public override fun produce(initializer: RealField.(IntArray) -> Double): ViktorNDStructure = + F64Array(*shape).apply { + this@ViktorNDField.strides.indices().forEach { index -> + set(value = RealField.initializer(index), indices = index) + } + }.asStructure() + + public override fun map(arg: ViktorNDStructure, transform: RealField.(Double) -> Double): ViktorNDStructure = + F64Array(*shape).apply { + this@ViktorNDField.strides.indices().forEach { index -> + set(value = RealField.transform(arg[index]), indices = index) + } + }.asStructure() + + public override fun mapIndexed( + arg: ViktorNDStructure, + transform: RealField.(index: IntArray, Double) -> Double + ): ViktorNDStructure = F64Array(*shape).apply { + this@ViktorNDField.strides.indices().forEach { index -> + set(value = RealField.transform(index, arg[index]), indices = index) + } + }.asStructure() + + public override fun combine( + a: ViktorNDStructure, + b: ViktorNDStructure, + transform: RealField.(Double, Double) -> Double + ): ViktorNDStructure = F64Array(*shape).apply { + this@ViktorNDField.strides.indices().forEach { index -> + set(value = RealField.transform(a[index], b[index]), indices = index) + } + }.asStructure() + + public override fun add(a: ViktorNDStructure, b: ViktorNDStructure): ViktorNDStructure = + (a.f64Buffer + b.f64Buffer).asStructure() + + public override fun multiply(a: ViktorNDStructure, k: Number): ViktorNDStructure = + (a.f64Buffer * k.toDouble()).asStructure() + + public override inline fun ViktorNDStructure.plus(b: ViktorNDStructure): ViktorNDStructure = + (f64Buffer + b.f64Buffer).asStructure() + + public override inline fun ViktorNDStructure.minus(b: ViktorNDStructure): ViktorNDStructure = + (f64Buffer - b.f64Buffer).asStructure() + + public override inline fun ViktorNDStructure.times(k: Number): ViktorNDStructure = + (f64Buffer * k.toDouble()).asStructure() + + public override inline fun ViktorNDStructure.plus(arg: Double): ViktorNDStructure = + (f64Buffer.plus(arg)).asStructure() +} \ No newline at end of file diff --git a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt b/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt deleted file mode 100644 index 551b877a7..000000000 --- a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorBuffer.kt +++ /dev/null @@ -1,20 +0,0 @@ -package scientifik.kmath.viktor - -import org.jetbrains.bio.viktor.F64FlatArray -import scientifik.kmath.structures.MutableBuffer - -@Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE") -inline class ViktorBuffer(val flatArray: F64FlatArray) : MutableBuffer { - override val size: Int get() = flatArray.size - - override inline fun get(index: Int): Double = flatArray[index] - override inline fun set(index: Int, value: Double) { - flatArray[index] = value - } - - override fun copy(): MutableBuffer { - return ViktorBuffer(flatArray.copy().flatten()) - } - - override operator fun iterator(): Iterator = flatArray.data.iterator() -} diff --git a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorNDStructure.kt b/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorNDStructure.kt deleted file mode 100644 index 84e927721..000000000 --- a/kmath-viktor/src/main/kotlin/scientifik/kmath/viktor/ViktorNDStructure.kt +++ /dev/null @@ -1,86 +0,0 @@ -package scientifik.kmath.viktor - -import org.jetbrains.bio.viktor.F64Array -import scientifik.kmath.operations.RealField -import scientifik.kmath.structures.DefaultStrides -import scientifik.kmath.structures.MutableNDStructure -import scientifik.kmath.structures.NDField - -@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -inline class ViktorNDStructure(val f64Buffer: F64Array) : MutableNDStructure { - - override val shape: IntArray get() = f64Buffer.shape - - override inline fun get(index: IntArray): Double = f64Buffer.get(*index) - - override inline fun set(index: IntArray, value: Double) { - f64Buffer.set(*index, value = value) - } - - override fun elements(): Sequence> { - return DefaultStrides(shape).indices().map { it to get(it) } - } -} - -fun F64Array.asStructure(): ViktorNDStructure = ViktorNDStructure(this) - -@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -class ViktorNDField(override val shape: IntArray) : NDField { - override val zero: ViktorNDStructure - get() = F64Array.full(init = 0.0, shape = *shape).asStructure() - override val one: ViktorNDStructure - get() = F64Array.full(init = 1.0, shape = *shape).asStructure() - - val strides = DefaultStrides(shape) - - override val elementContext: RealField get() = RealField - - override fun produce(initializer: RealField.(IntArray) -> Double): ViktorNDStructure = F64Array(*shape).apply { - this@ViktorNDField.strides.indices().forEach { index -> - set(value = RealField.initializer(index), indices = *index) - } - }.asStructure() - - override fun map(arg: ViktorNDStructure, transform: RealField.(Double) -> Double): ViktorNDStructure = - F64Array(*shape).apply { - this@ViktorNDField.strides.indices().forEach { index -> - set(value = RealField.transform(arg[index]), indices = *index) - } - }.asStructure() - - override fun mapIndexed( - arg: ViktorNDStructure, - transform: RealField.(index: IntArray, Double) -> Double - ): ViktorNDStructure = F64Array(*shape).apply { - this@ViktorNDField.strides.indices().forEach { index -> - set(value = RealField.transform(index, arg[index]), indices = *index) - } - }.asStructure() - - override fun combine( - a: ViktorNDStructure, - b: ViktorNDStructure, - transform: RealField.(Double, Double) -> Double - ): ViktorNDStructure = F64Array(*shape).apply { - this@ViktorNDField.strides.indices().forEach { index -> - set(value = RealField.transform(a[index], b[index]), indices = *index) - } - }.asStructure() - - override fun add(a: ViktorNDStructure, b: ViktorNDStructure): ViktorNDStructure { - return (a.f64Buffer + b.f64Buffer).asStructure() - } - - override fun multiply(a: ViktorNDStructure, k: Number): ViktorNDStructure = - (a.f64Buffer * k.toDouble()).asStructure() - - override inline fun ViktorNDStructure.plus(b: ViktorNDStructure): ViktorNDStructure = - (f64Buffer + b.f64Buffer).asStructure() - - override inline fun ViktorNDStructure.minus(b: ViktorNDStructure): ViktorNDStructure = - (f64Buffer - b.f64Buffer).asStructure() - - override inline fun ViktorNDStructure.times(k: Number): ViktorNDStructure = (f64Buffer * k.toDouble()).asStructure() - - override inline fun ViktorNDStructure.plus(arg: Double): ViktorNDStructure = (f64Buffer.plus(arg)).asStructure() -} \ No newline at end of file diff --git a/settings.gradle.kts b/settings.gradle.kts index 6601fd053..a1ea40148 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,43 +1,44 @@ pluginManagement { - - val toolsVersion = "0.5.2" - - plugins { - id("kotlinx.benchmark") version "0.2.0-dev-8" - id("scientifik.mpp") version toolsVersion - id("scientifik.jvm") version toolsVersion - id("scientifik.atomic") version toolsVersion - id("scientifik.publish") version toolsVersion - kotlin("plugin.allopen") version "1.3.72" - } - repositories { - mavenLocal() - jcenter() gradlePluginPortal() + jcenter() maven("https://dl.bintray.com/kotlin/kotlin-eap") - maven("https://dl.bintray.com/mipt-npm/scientifik") + maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/kotlin/kotlinx") } + + val toolsVersion = "0.7.3-1.4.30-RC" + val kotlinVersion = "1.4.30-RC" + + plugins { + id("kotlinx.benchmark") version "0.2.0-dev-20" + id("ru.mipt.npm.project") version toolsVersion + id("ru.mipt.npm.mpp") version toolsVersion + id("ru.mipt.npm.jvm") version toolsVersion + id("ru.mipt.npm.publish") version toolsVersion + kotlin("jvm") version kotlinVersion + kotlin("plugin.allopen") version kotlinVersion + } } rootProject.name = "kmath" + include( ":kmath-memory", ":kmath-core", ":kmath-functions", -// ":kmath-io", ":kmath-coroutines", ":kmath-histograms", ":kmath-commons", ":kmath-viktor", - ":kmath-koma", - ":kmath-prob", - ":kmath-io", + ":kmath-stat", + ":kmath-nd4j", ":kmath-dimensions", ":kmath-for-real", ":kmath-geometry", ":kmath-ast", + ":kmath-ejml", + ":kmath-kotlingrad", ":examples" )