forked from kscience/kmath
Merge branch 'dev' into gsl-experiment
# Conflicts: # build.gradle.kts # examples/build.gradle.kts # examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt # settings.gradle.kts
This commit is contained in:
commit
e5d5ac17da
@ -7,12 +7,15 @@
|
|||||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
||||||
- Automatic README generation for features (#139)
|
- Automatic README generation for features (#139)
|
||||||
- Native support for `memory`, `core` and `dimensions`
|
- Native support for `memory`, `core` and `dimensions`
|
||||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
|
- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136).
|
||||||
- A separate `Symbol` entity, which is used for global unbound symbol.
|
- A separate `Symbol` entity, which is used for global unbound symbol.
|
||||||
- A `Symbol` indexing scope.
|
- A `Symbol` indexing scope.
|
||||||
- Basic optimization API for Commons-math.
|
- Basic optimization API for Commons-math.
|
||||||
- Chi squared optimization for array-like data in CM
|
- Chi squared optimization for array-like data in CM
|
||||||
- `Fitting` utility object in prob/stat
|
- `Fitting` utility object in prob/stat
|
||||||
|
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`.
|
||||||
|
- Coroutine-deterministic Monte-Carlo scope with a random number generator.
|
||||||
|
- Some minor utilities to `kmath-for-real`.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Package changed from `scientifik` to `kscience.kmath`.
|
- Package changed from `scientifik` to `kscience.kmath`.
|
||||||
@ -23,12 +26,14 @@
|
|||||||
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
||||||
- Full autodiff refactoring based on `Symbol`
|
- Full autodiff refactoring based on `Symbol`
|
||||||
- `kmath-prob` renamed to `kmath-stat`
|
- `kmath-prob` renamed to `kmath-stat`
|
||||||
|
- Grid generators moved to `kmath-for-real`
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
### Removed
|
### Removed
|
||||||
- `kmath-koma` module because it doesn't support Kotlin 1.4.
|
- `kmath-koma` module because it doesn't support Kotlin 1.4.
|
||||||
- Support of `legacy` JS backend (we will support only IR)
|
- Support of `legacy` JS backend (we will support only IR)
|
||||||
|
- `toGrid` method.
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140)
|
- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140)
|
||||||
|
108
README.md
108
README.md
@ -8,41 +8,50 @@ Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience
|
|||||||
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
|
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
|
||||||
|
|
||||||
# KMath
|
# KMath
|
||||||
Could be pronounced as `key-math`.
|
|
||||||
The Kotlin MATHematics library was initially intended as a Kotlin-based analog to Python's `numpy` library. Later we found that kotlin is much more flexible language and allows superior architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to
|
||||||
|
Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture
|
||||||
|
designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could
|
||||||
|
be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
||||||
|
|
||||||
## Publications and talks
|
## Publications and talks
|
||||||
|
|
||||||
* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2)
|
* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2)
|
||||||
* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814)
|
* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814)
|
||||||
* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103)
|
* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103)
|
||||||
|
|
||||||
# Goal
|
# Goal
|
||||||
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future).
|
|
||||||
|
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native).
|
||||||
* Provide basic multiplatform implementations for those abstractions (without significant performance optimization).
|
* Provide basic multiplatform implementations for those abstractions (without significant performance optimization).
|
||||||
* Provide bindings and wrappers with those abstractions for popular optimized platform libraries.
|
* Provide bindings and wrappers with those abstractions for popular optimized platform libraries.
|
||||||
|
|
||||||
## Non-goals
|
## Non-goals
|
||||||
* Be like Numpy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
|
||||||
* Provide best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
||||||
|
* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
||||||
* Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually.
|
* Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually.
|
||||||
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better experience for those, who want to work with specific types.
|
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like
|
||||||
|
for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better
|
||||||
|
experience for those, who want to work with specific types.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
Actual feature list is [here](/docs/features.md)
|
Current feature list is [here](/docs/features.md)
|
||||||
|
|
||||||
* **Algebra**
|
* **Algebra**
|
||||||
* Algebraic structures like rings, spaces and field (**TODO** add example to wiki)
|
* Algebraic structures like rings, spaces and fields (**TODO** add example to wiki)
|
||||||
* Basic linear algebra operations (sums, products, etc.), backed by the `Space` API.
|
* Basic linear algebra operations (sums, products, etc.), backed by the `Space` API.
|
||||||
* Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays).
|
* Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and
|
||||||
|
N-dimensional arrays).
|
||||||
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
||||||
|
|
||||||
* **Array-like structures** Full support of many-dimensional array-like structures
|
* **Array-like structures** Full support of many-dimensional array-like structures
|
||||||
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
||||||
|
|
||||||
* **Expressions** By writing a single mathematical expression
|
* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of
|
||||||
once, users will be able to apply different types of objects to the expression by providing a context. Expressions
|
objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high
|
||||||
can be used for a wide variety of purposes from high performance calculations to code generation.
|
performance calculations to code generation.
|
||||||
|
|
||||||
* **Histograms** Fast multi-dimensional histograms.
|
* **Histograms** Fast multi-dimensional histograms.
|
||||||
|
|
||||||
@ -50,9 +59,10 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
|
|
||||||
* **Type-safe dimensions** Type-safe dimensions for matrix operations.
|
* **Type-safe dimensions** Type-safe dimensions for matrix operations.
|
||||||
|
|
||||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
* **Commons-math wrapper** It is planned to gradually wrap most parts of
|
||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some
|
||||||
to submit a feature request if you want something to be done first.
|
parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to
|
||||||
|
submit a feature request if you want something to be implemented first.
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
@ -151,6 +161,18 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
|
* ### [kmath-nd4j](kmath-nd4j)
|
||||||
|
> ND4J NDStructure implementation and according NDAlgebra classes
|
||||||
|
>
|
||||||
|
> **Maturity**: EXPERIMENTAL
|
||||||
|
>
|
||||||
|
> **Features:**
|
||||||
|
> - [nd4jarraystrucure](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray
|
||||||
|
> - [nd4jarrayrings](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long
|
||||||
|
> - [nd4jarrayfields](kmath-nd4j/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double
|
||||||
|
|
||||||
|
<hr/>
|
||||||
|
|
||||||
* ### [kmath-stat](kmath-stat)
|
* ### [kmath-stat](kmath-stat)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
@ -166,39 +188,69 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
|
|
||||||
## Multi-platform support
|
## Multi-platform support
|
||||||
|
|
||||||
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the [common module](/kmath-core/src/commonMain). Implementation is also done in the common module wherever possible. In some cases, features are delegated to platform-specific implementations even if they could be done in the common module for performance reasons. Currently, the JVM is the main focus of development, however Kotlin/Native and Kotlin/JS contributions are also welcome.
|
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the
|
||||||
|
[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features
|
||||||
|
are delegated to platform-specific implementations even if they could be provided in the common module for performance
|
||||||
|
reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and
|
||||||
|
feedback are also welcome.
|
||||||
|
|
||||||
## Performance
|
## Performance
|
||||||
|
|
||||||
Calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve both performance and flexibility. We expect to focus on creating convenient universal API first and then work on increasing performance for specific cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy.
|
Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve
|
||||||
|
both performance and flexibility.
|
||||||
|
|
||||||
### Dependency
|
We expect to focus on creating convenient universal API first and then work on increasing performance for specific
|
||||||
|
cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized
|
||||||
|
native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be
|
||||||
|
better than SciPy.
|
||||||
|
|
||||||
Release artifacts are accessible from bintray with following configuration (see documentation for [kotlin-multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) form more details):
|
### Repositories
|
||||||
|
|
||||||
|
Release artifacts are accessible from bintray with following configuration (see documentation of
|
||||||
|
[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details):
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
repositories{
|
repositories {
|
||||||
|
jcenter()
|
||||||
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies{
|
dependencies {
|
||||||
api("kscience.kmath:kmath-core:0.2.0-dev-2")
|
api("kscience.kmath:kmath-core:0.2.0-dev-3")
|
||||||
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
|
// api("kscience.kmath:kmath-core-jvm:0.2.0-dev-3") for jvm-specific version
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Gradle `6.0+` is required for multiplatform artifacts.
|
Gradle `6.0+` is required for multiplatform artifacts.
|
||||||
|
|
||||||
### Development
|
#### Development
|
||||||
|
|
||||||
|
Development builds are uploaded to the separate repository:
|
||||||
|
|
||||||
Development builds are accessible from the reposirtory
|
|
||||||
```kotlin
|
```kotlin
|
||||||
repositories{
|
repositories {
|
||||||
|
jcenter()
|
||||||
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
mavenCentral()
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
with the same artifact names.
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
The project requires a lot of additional work. The most important thing we need is a feedback about what features are
|
||||||
|
required the most. Feel free to create feature requests. We are also welcome to code contributions,
|
||||||
|
especially in issues marked with
|
||||||
|
[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label.
|
||||||
|
@ -1,17 +1,26 @@
|
|||||||
|
import ru.mipt.npm.gradle.KSciencePublishPlugin
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("ru.mipt.npm.project")
|
id("ru.mipt.npm.project")
|
||||||
}
|
}
|
||||||
|
|
||||||
internal val kmathVersion: String by extra("0.2.0-dev-3")
|
internal val kmathVersion: String by extra("0.2.0-dev-4")
|
||||||
internal val bintrayRepo: String by extra("kscience")
|
internal val bintrayRepo: String by extra("kscience")
|
||||||
internal val githubProject: String by extra("kmath")
|
internal val githubProject: String by extra("kmath")
|
||||||
|
|
||||||
allprojects {
|
allprojects {
|
||||||
repositories {
|
repositories {
|
||||||
jcenter()
|
jcenter()
|
||||||
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
maven("http://logicrunch.research.it.uu.se/maven/")
|
||||||
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
group = "kscience.kmath"
|
group = "kscience.kmath"
|
||||||
@ -19,7 +28,7 @@ allprojects {
|
|||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
if (name.startsWith("kmath")) apply<ru.mipt.npm.gradle.KSciencePublishPlugin>()
|
if (name.startsWith("kmath")) apply<KSciencePublishPlugin>()
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
78
docs/templates/README-TEMPLATE.md
vendored
78
docs/templates/README-TEMPLATE.md
vendored
@ -8,41 +8,50 @@ Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience
|
|||||||
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
|
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
|
||||||
|
|
||||||
# KMath
|
# KMath
|
||||||
Could be pronounced as `key-math`.
|
|
||||||
The Kotlin MATHematics library was initially intended as a Kotlin-based analog to Python's `numpy` library. Later we found that kotlin is much more flexible language and allows superior architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
Could be pronounced as `key-math`. The Kotlin MATHematics library was initially intended as a Kotlin-based analog to
|
||||||
|
Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture
|
||||||
|
designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could
|
||||||
|
be achieved with [kmath-for-real](/kmath-for-real) extension module.
|
||||||
|
|
||||||
## Publications and talks
|
## Publications and talks
|
||||||
|
|
||||||
* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2)
|
* [A conceptual article about context-oriented design](https://proandroiddev.com/an-introduction-context-oriented-programming-in-kotlin-2e79d316b0a2)
|
||||||
* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814)
|
* [Another article about context-oriented design](https://proandroiddev.com/diving-deeper-into-context-oriented-programming-in-kotlin-3ecb4ec38814)
|
||||||
* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103)
|
* [ACAT 2019 conference paper](https://aip.scitation.org/doi/abs/10.1063/1.5130103)
|
||||||
|
|
||||||
# Goal
|
# Goal
|
||||||
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM and JS for now and Native in future).
|
|
||||||
|
* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native).
|
||||||
* Provide basic multiplatform implementations for those abstractions (without significant performance optimization).
|
* Provide basic multiplatform implementations for those abstractions (without significant performance optimization).
|
||||||
* Provide bindings and wrappers with those abstractions for popular optimized platform libraries.
|
* Provide bindings and wrappers with those abstractions for popular optimized platform libraries.
|
||||||
|
|
||||||
## Non-goals
|
## Non-goals
|
||||||
* Be like Numpy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
|
||||||
* Provide best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API.
|
||||||
|
* Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them.
|
||||||
* Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually.
|
* Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually.
|
||||||
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better experience for those, who want to work with specific types.
|
* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like
|
||||||
|
for `Double` in the core. For that we will have specialization modules like `for-real`, which will give better
|
||||||
|
experience for those, who want to work with specific types.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
Actual feature list is [here](/docs/features.md)
|
Current feature list is [here](/docs/features.md)
|
||||||
|
|
||||||
* **Algebra**
|
* **Algebra**
|
||||||
* Algebraic structures like rings, spaces and field (**TODO** add example to wiki)
|
* Algebraic structures like rings, spaces and fields (**TODO** add example to wiki)
|
||||||
* Basic linear algebra operations (sums, products, etc.), backed by the `Space` API.
|
* Basic linear algebra operations (sums, products, etc.), backed by the `Space` API.
|
||||||
* Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays).
|
* Complex numbers backed by the `Field` API (meaning they will be usable in any structure like vectors and
|
||||||
|
N-dimensional arrays).
|
||||||
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
||||||
|
|
||||||
* **Array-like structures** Full support of many-dimensional array-like structures
|
* **Array-like structures** Full support of many-dimensional array-like structures
|
||||||
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
||||||
|
|
||||||
* **Expressions** By writing a single mathematical expression
|
* **Expressions** By writing a single mathematical expression once, users will be able to apply different types of
|
||||||
once, users will be able to apply different types of objects to the expression by providing a context. Expressions
|
objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high
|
||||||
can be used for a wide variety of purposes from high performance calculations to code generation.
|
performance calculations to code generation.
|
||||||
|
|
||||||
* **Histograms** Fast multi-dimensional histograms.
|
* **Histograms** Fast multi-dimensional histograms.
|
||||||
|
|
||||||
@ -50,9 +59,10 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
|
|
||||||
* **Type-safe dimensions** Type-safe dimensions for matrix operations.
|
* **Type-safe dimensions** Type-safe dimensions for matrix operations.
|
||||||
|
|
||||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
* **Commons-math wrapper** It is planned to gradually wrap most parts of
|
||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
[Apache commons-math](http://commons.apache.org/proper/commons-math/) library in Kotlin code and maybe rewrite some
|
||||||
to submit a feature request if you want something to be done first.
|
parts to better suit the Kotlin programming paradigm, however there is no established roadmap for that. Feel free to
|
||||||
|
submit a feature request if you want something to be implemented first.
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
@ -72,39 +82,53 @@ $modules
|
|||||||
|
|
||||||
## Multi-platform support
|
## Multi-platform support
|
||||||
|
|
||||||
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the [common module](/kmath-core/src/commonMain). Implementation is also done in the common module wherever possible. In some cases, features are delegated to platform-specific implementations even if they could be done in the common module for performance reasons. Currently, the JVM is the main focus of development, however Kotlin/Native and Kotlin/JS contributions are also welcome.
|
KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the
|
||||||
|
[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features
|
||||||
|
are delegated to platform-specific implementations even if they could be provided in the common module for performance
|
||||||
|
reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and
|
||||||
|
feedback are also welcome.
|
||||||
|
|
||||||
## Performance
|
## Performance
|
||||||
|
|
||||||
Calculation performance is one of major goals of KMath in the future, but in some cases it is not possible to achieve both performance and flexibility. We expect to focus on creating convenient universal API first and then work on increasing performance for specific cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy.
|
Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve
|
||||||
|
both performance and flexibility.
|
||||||
|
|
||||||
### Dependency
|
We expect to focus on creating convenient universal API first and then work on increasing performance for specific
|
||||||
|
cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized
|
||||||
|
native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be
|
||||||
|
better than SciPy.
|
||||||
|
|
||||||
Release artifacts are accessible from bintray with following configuration (see documentation for [kotlin-multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) form more details):
|
### Repositories
|
||||||
|
|
||||||
|
Release artifacts are accessible from bintray with following configuration (see documentation of
|
||||||
|
[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details):
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
repositories{
|
repositories {
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies{
|
dependencies {
|
||||||
api("kscience.kmath:kmath-core:$version")
|
api("kscience.kmath:kmath-core:$version")
|
||||||
//api("kscience.kmath:kmath-core-jvm:$version") for jvm-specific version
|
// api("kscience.kmath:kmath-core-jvm:$version") for jvm-specific version
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Gradle `6.0+` is required for multiplatform artifacts.
|
Gradle `6.0+` is required for multiplatform artifacts.
|
||||||
|
|
||||||
### Development
|
#### Development
|
||||||
|
|
||||||
|
Development builds are uploaded to the separate repository:
|
||||||
|
|
||||||
Development builds are accessible from the reposirtory
|
|
||||||
```kotlin
|
```kotlin
|
||||||
repositories{
|
repositories {
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
with the same artifact names.
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
The project requires a lot of additional work. The most important thing we need is a feedback about what features are
|
||||||
|
required the most. Feel free to create feature requests. We are also welcome to code contributions,
|
||||||
|
especially in issues marked with
|
||||||
|
[waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label.
|
||||||
|
@ -7,18 +7,25 @@ plugins {
|
|||||||
}
|
}
|
||||||
|
|
||||||
allOpen.annotation("org.openjdk.jmh.annotations.State")
|
allOpen.annotation("org.openjdk.jmh.annotations.State")
|
||||||
|
sourceSets.register("benchmarks")
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
jcenter()
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
maven("http://logicrunch.research.it.uu.se/maven/")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceSets.register("benchmarks")
|
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation(project(":kmath-ast"))
|
implementation(project(":kmath-ast"))
|
||||||
|
implementation(project(":kmath-kotlingrad"))
|
||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
@ -26,6 +33,21 @@ dependencies {
|
|||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation(project(":kmath-ejml"))
|
implementation(project(":kmath-ejml"))
|
||||||
|
implementation(project(":kmath-nd4j"))
|
||||||
|
implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
|
||||||
|
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||||
|
|
||||||
|
// uncomment if your system supports AVX2
|
||||||
|
// val os = System.getProperty("os.name")
|
||||||
|
//
|
||||||
|
// if (System.getProperty("os.arch") in arrayOf("x86_64", "amd64")) when {
|
||||||
|
// os.startsWith("Windows") -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:windows-x86_64-avx2")
|
||||||
|
// os == "Linux" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:linux-x86_64-avx2")
|
||||||
|
// os == "Mac OS X" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:macosx-x86_64-avx2")
|
||||||
|
// } else
|
||||||
|
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||||
|
|
||||||
|
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
||||||
implementation("org.slf4j:slf4j-simple:1.7.30")
|
implementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
"benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8")
|
"benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8")
|
||||||
@ -53,4 +75,6 @@ kotlin.sourceSets.all {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" }
|
tasks.withType<KotlinCompile> {
|
||||||
|
kotlinOptions.jvmTarget = "11"
|
||||||
|
}
|
||||||
|
@ -9,11 +9,11 @@ import kscience.kmath.operations.RealField
|
|||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
class ExpressionsInterpretersBenchmark {
|
internal class ExpressionsInterpretersBenchmark {
|
||||||
private val algebra: Field<Double> = RealField
|
private val algebra: Field<Double> = RealField
|
||||||
fun functionalExpression() {
|
fun functionalExpression() {
|
||||||
val expr = algebra.expressionInField {
|
val expr = algebra.expressionInField {
|
||||||
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
@ -47,6 +47,16 @@ class ExpressionsInterpretersBenchmark {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
|
||||||
|
* core FunctionalExpressions API.
|
||||||
|
*
|
||||||
|
* The expected rating is:
|
||||||
|
*
|
||||||
|
* 1. ASM.
|
||||||
|
* 2. MST.
|
||||||
|
* 3. FE.
|
||||||
|
*/
|
||||||
fun main() {
|
fun main() {
|
||||||
val benchmark = ExpressionsInterpretersBenchmark()
|
val benchmark = ExpressionsInterpretersBenchmark()
|
||||||
|
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
package kscience.kmath.ast
|
||||||
|
|
||||||
|
import kscience.kmath.asm.compile
|
||||||
|
import kscience.kmath.expressions.derivative
|
||||||
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.expressions.symbol
|
||||||
|
import kscience.kmath.kotlingrad.differentiable
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with
|
||||||
|
* valid derivative.
|
||||||
|
*/
|
||||||
|
fun main() {
|
||||||
|
val x by symbol
|
||||||
|
|
||||||
|
val actualDerivative = MstExpression(RealField, "x^2-4*x-44".parseMath())
|
||||||
|
.differentiable()
|
||||||
|
.derivative(x)
|
||||||
|
.compile()
|
||||||
|
|
||||||
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
|
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||||
|
}
|
@ -1,8 +1,10 @@
|
|||||||
package kscience.kmath.structures
|
package kscience.kmath.structures
|
||||||
|
|
||||||
import kotlinx.coroutines.GlobalScope
|
import kotlinx.coroutines.GlobalScope
|
||||||
|
import kscience.kmath.nd4j.Nd4jArrayField
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
@ -14,6 +16,8 @@ internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
|
// initializing Nd4j
|
||||||
|
Nd4j.zeros(0)
|
||||||
val dim = 1000
|
val dim = 1000
|
||||||
val n = 1000
|
val n = 1000
|
||||||
|
|
||||||
@ -23,6 +27,8 @@ fun main() {
|
|||||||
val specializedField = NDField.real(dim, dim)
|
val specializedField = NDField.real(dim, dim)
|
||||||
//A generic boxing field. It should be used for objects, not primitives.
|
//A generic boxing field. It should be used for objects, not primitives.
|
||||||
val genericField = NDField.boxing(RealField, dim, dim)
|
val genericField = NDField.boxing(RealField, dim, dim)
|
||||||
|
// Nd4j specialized field.
|
||||||
|
val nd4jField = Nd4jArrayField.real(dim, dim)
|
||||||
|
|
||||||
measureAndPrint("Automatic field addition") {
|
measureAndPrint("Automatic field addition") {
|
||||||
autoField {
|
autoField {
|
||||||
@ -43,6 +49,13 @@ fun main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
measureAndPrint("Nd4j specialized addition") {
|
||||||
|
nd4jField {
|
||||||
|
var res = one
|
||||||
|
repeat(n) { res += 1.0 as Number }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
measureAndPrint("Lazy addition") {
|
measureAndPrint("Lazy addition") {
|
||||||
val res = specializedField.one.mapAsync(GlobalScope) {
|
val res = specializedField.one.mapAsync(GlobalScope) {
|
||||||
var c = 0.0
|
var c = 0.0
|
||||||
|
@ -6,14 +6,14 @@ import kscience.kmath.operations.*
|
|||||||
* [Algebra] over [MST] nodes.
|
* [Algebra] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstAlgebra : NumericAlgebra<MST> {
|
public object MstAlgebra : NumericAlgebra<MST> {
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST =
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
||||||
MST.Unary(operation, arg)
|
MST.Unary(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MST.Binary(operation, left, right)
|
MST.Binary(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra<MST> {
|
|||||||
* [Space] over [MST] nodes.
|
* [Space] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST = number(0.0)
|
override val zero: MST.Numeric by lazy { number(0.0) }
|
||||||
|
|
||||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstAlgebra.binaryOperation(operation, left, right)
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Ring] over [MST] nodes.
|
* [Ring] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST
|
override val zero: MST.Numeric
|
||||||
get() = MstSpace.zero
|
get() = MstSpace.zero
|
||||||
override val one: MST = number(1.0)
|
|
||||||
|
|
||||||
override fun number(value: Number): MST = MstSpace.number(value)
|
override val one: MST.Numeric by lazy { number(1.0) }
|
||||||
override fun symbol(value: String): MST = MstSpace.symbol(value)
|
|
||||||
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
|
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||||
|
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||||
|
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||||
|
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||||
|
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
|
||||||
MstSpace.binaryOperation(operation, left, right)
|
MstSpace.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Field] over [MST] nodes.
|
* [Field] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstField : Field<MST> {
|
public object MstField : Field<MST> {
|
||||||
public override val zero: MST
|
public override val zero: MST.Numeric
|
||||||
get() = MstRing.zero
|
get() = MstRing.zero
|
||||||
|
|
||||||
public override val one: MST
|
public override val one: MST.Numeric
|
||||||
get() = MstRing.one
|
get() = MstRing.one
|
||||||
|
|
||||||
public override fun symbol(value: String): MST = MstRing.symbol(value)
|
public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
|
||||||
public override fun number(value: Number): MST = MstRing.number(value)
|
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||||
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||||
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||||
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||||
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstRing.binaryOperation(operation, left, right)
|
MstRing.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [ExtendedField] over [MST] nodes.
|
* [ExtendedField] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstExtendedField : ExtendedField<MST> {
|
public object MstExtendedField : ExtendedField<MST> {
|
||||||
override val zero: MST
|
override val zero: MST.Numeric
|
||||||
get() = MstField.zero
|
get() = MstField.zero
|
||||||
|
|
||||||
override val one: MST
|
override val one: MST.Numeric
|
||||||
get() = MstField.one
|
get() = MstField.one
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MstField.symbol(value)
|
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||||
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
override fun number(value: Number): MST.Numeric = MstField.number(value)
|
||||||
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||||
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||||
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||||
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||||
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||||
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||||
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||||
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||||
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||||
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||||
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||||
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||||
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||||
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||||
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
|
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||||
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||||
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
|
||||||
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||||
|
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||||
|
|
||||||
|
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||||
|
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstField.binaryOperation(operation, left, right)
|
MstField.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ import kotlin.contracts.contract
|
|||||||
* @property mst the [MST] node.
|
* @property mst the [MST] node.
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
@ -21,8 +21,9 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
algebra.binaryOperation(operation, left, right)
|
algebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
@Suppress("UNCHECKED_CAST")
|
||||||
algebra.number(value)
|
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||||
|
(algebra as NumericAlgebra<T>).number(value)
|
||||||
else
|
else
|
||||||
error("Numeric nodes are not supported by $this")
|
error("Numeric nodes are not supported by $this")
|
||||||
}
|
}
|
||||||
@ -38,14 +39,14 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
mstAlgebra: E,
|
mstAlgebra: E,
|
||||||
block: E.() -> MST,
|
block: E.() -> MST,
|
||||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [Space].
|
* Builds [MstExpression] over [Space].
|
||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Space<T>> A.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstSpace.block())
|
return MstExpression(this, MstSpace.block())
|
||||||
}
|
}
|
||||||
@ -55,7 +56,7 @@ public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MS
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstRing.block())
|
return MstExpression(this, MstRing.block())
|
||||||
}
|
}
|
||||||
@ -65,7 +66,7 @@ public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST):
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstField.block())
|
return MstExpression(this, MstField.block())
|
||||||
}
|
}
|
||||||
@ -75,7 +76,7 @@ public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MS
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstExtendedField.block())
|
return MstExpression(this, MstExtendedField.block())
|
||||||
}
|
}
|
||||||
@ -85,7 +86,7 @@ public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtend
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInSpace(block)
|
return algebra.mstInSpace(block)
|
||||||
}
|
}
|
||||||
@ -95,7 +96,7 @@ public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInRing(block)
|
return algebra.mstInRing(block)
|
||||||
}
|
}
|
||||||
@ -105,7 +106,7 @@ public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInField(block)
|
return algebra.mstInField(block)
|
||||||
}
|
}
|
||||||
@ -117,7 +118,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||||
block: MstExtendedField.() -> MST,
|
block: MstExtendedField.() -> MST,
|
||||||
): MstExpression<T> {
|
): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInExtendedField(block)
|
return algebra.mstInExtendedField(block)
|
||||||
}
|
}
|
||||||
|
@ -69,4 +69,5 @@ public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik.
|
* @author Alexander Nozik.
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra)
|
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
||||||
|
mst.compileWith(T::class.java, algebra)
|
||||||
|
@ -12,16 +12,22 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
|||||||
*/
|
*/
|
||||||
public class DerivativeStructureField(
|
public class DerivativeStructureField(
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
private val bindings: Map<Symbol, Double>
|
bindings: Map<Symbol, Double>,
|
||||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
|
public val numberOfVariables: Int = bindings.size
|
||||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
|
|
||||||
|
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||||
|
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A class that implements both [DerivativeStructure] and a [Symbol]
|
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||||
*/
|
*/
|
||||||
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
public inner class DerivativeStructureSymbol(
|
||||||
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
size: Int,
|
||||||
|
index: Int,
|
||||||
|
symbol: Symbol,
|
||||||
|
value: Double,
|
||||||
|
) : DerivativeStructure(size, order, index, value), Symbol {
|
||||||
override val identity: String = symbol.identity
|
override val identity: String = symbol.identity
|
||||||
override fun toString(): String = identity
|
override fun toString(): String = identity
|
||||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
@ -31,27 +37,26 @@ public class DerivativeStructureField(
|
|||||||
/**
|
/**
|
||||||
* Identity-based symbol bindings map
|
* Identity-based symbol bindings map
|
||||||
*/
|
*/
|
||||||
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
|
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
|
||||||
key.identity to DerivativeStructureSymbol(key, value)
|
key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
|
||||||
}
|
}.toMap()
|
||||||
|
|
||||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
|
override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value)
|
||||||
|
|
||||||
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||||
|
|
||||||
//public fun Number.const(): DerivativeStructure = const(toDouble())
|
override fun symbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value))
|
||||||
|
|
||||||
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
public fun DerivativeStructure.derivative(symbols: List<Symbol>): Double {
|
||||||
return derivative(mapOf(parameter to order))
|
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||||
|
val ordersCount = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
|
public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList())
|
||||||
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
|
|
||||||
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||||
|
|
||||||
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||||
@ -90,26 +95,27 @@ public class DerivativeStructureField(
|
|||||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||||
|
|
||||||
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
public companion object :
|
||||||
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
|
||||||
return DerivativeStructureExpression(function)
|
public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> =
|
||||||
}
|
DerivativeStructureExpression(function)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constructs that creates a derivative structure with required order on-demand
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
*/
|
*/
|
||||||
public class DerivativeStructureExpression(
|
public class DerivativeStructureExpression(
|
||||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||||
) : DifferentiableExpression<Double> {
|
) : DifferentiableExpression<Double, Expression<Double>> {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
DerivativeStructureField(0, arguments).function().value
|
DerivativeStructureField(0, arguments).function().value
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
*/
|
*/
|
||||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
||||||
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,9 +19,8 @@ import kotlin.reflect.KClass
|
|||||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||||
public operator fun PointValuePair.component2(): Double = value
|
public operator fun PointValuePair.component2(): Double = value
|
||||||
|
|
||||||
public class CMOptimizationProblem(
|
public class CMOptimizationProblem(override val symbols: List<Symbol>, ) :
|
||||||
override val symbols: List<Symbol>,
|
OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
||||||
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
|
||||||
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||||
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||||
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||||
@ -49,7 +48,7 @@ public class CMOptimizationProblem(
|
|||||||
addOptimizationData(objectiveFunction)
|
addOptimizationData(objectiveFunction)
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
|
public override fun diffExpression(expression: DifferentiableExpression<Double, Expression<Double>>) {
|
||||||
expression(expression)
|
expression(expression)
|
||||||
val gradientFunction = ObjectiveFunctionGradient {
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
val args = it.toMap()
|
val args = it.toMap()
|
||||||
|
@ -12,7 +12,6 @@ import kscience.kmath.structures.asBuffer
|
|||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
@ -21,7 +20,7 @@ public fun Fitting.chiSquared(
|
|||||||
y: Buffer<Double>,
|
y: Buffer<Double>,
|
||||||
yErr: Buffer<Double>,
|
yErr: Buffer<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
@ -31,7 +30,7 @@ public fun Fitting.chiSquared(
|
|||||||
y: Iterable<Double>,
|
y: Iterable<Double>,
|
||||||
yErr: Iterable<Double>,
|
yErr: Iterable<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double> = chiSquared(
|
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(
|
||||||
DerivativeStructureField,
|
DerivativeStructureField,
|
||||||
x.toList().asBuffer(),
|
x.toList().asBuffer(),
|
||||||
y.toList().asBuffer(),
|
y.toList().asBuffer(),
|
||||||
@ -39,7 +38,6 @@ public fun Fitting.chiSquared(
|
|||||||
model
|
model
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize expression without derivatives
|
* Optimize expression without derivatives
|
||||||
*/
|
*/
|
||||||
@ -48,16 +46,15 @@ public fun Expression<Double>.optimize(
|
|||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize differentiable expression
|
* Optimize differentiable expression
|
||||||
*/
|
*/
|
||||||
public fun DifferentiableExpression<Double>.optimize(
|
public fun DifferentiableExpression<Double, Expression<Double>>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
|
|
||||||
public fun DifferentiableExpression<Double>.minimize(
|
public fun DifferentiableExpression<Double, Expression<Double>>.minimize(
|
||||||
vararg startPoint: Pair<Symbol, Double>,
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
configuration: CMOptimizationProblem.() -> Unit = {},
|
configuration: CMOptimizationProblem.() -> Unit = {},
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> {
|
||||||
|
@ -5,14 +5,15 @@ import kotlin.contracts.InvocationKind
|
|||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFails
|
||||||
|
|
||||||
internal inline fun <R> diff(
|
internal inline fun diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<Symbol, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField.() -> R,
|
block: DerivativeStructureField.() -> Unit,
|
||||||
): R {
|
): Unit {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
@ -21,13 +22,16 @@ internal class AutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun derivativeStructureFieldTest() {
|
fun derivativeStructureFieldTest() {
|
||||||
val res: Double = diff(3, x to 1.0, y to 1.0) {
|
diff(2, x to 1.0, y to 1.0) {
|
||||||
val x = bind(x)//by binding()
|
val x = bind(x)//by binding()
|
||||||
val y = symbol("y")
|
val y = symbol("y")
|
||||||
val z = x * (-sin(x * y) + y)
|
val z = x * (-sin(x * y) + y) + 2.0
|
||||||
z.derivative(x)
|
println(z.derivative(x))
|
||||||
|
println(z.derivative(y,x))
|
||||||
|
assertEquals(z.derivative(x, y), z.derivative(y, x))
|
||||||
|
//check that improper order cause failure
|
||||||
|
assertFails { z.derivative(x,x,y) }
|
||||||
}
|
}
|
||||||
println(res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -40,5 +44,7 @@ internal class AutoDiffTest {
|
|||||||
|
|
||||||
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||||
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||||
|
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
|
||||||
|
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,6 @@ import kscience.kmath.stat.Distribution
|
|||||||
import kscience.kmath.stat.Fitting
|
import kscience.kmath.stat.Fitting
|
||||||
import kscience.kmath.stat.RandomGenerator
|
import kscience.kmath.stat.RandomGenerator
|
||||||
import kscience.kmath.stat.normal
|
import kscience.kmath.stat.normal
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
@ -48,14 +47,17 @@ internal class OptimizeTest {
|
|||||||
val sigma = 1.0
|
val sigma = 1.0
|
||||||
val generator = Distribution.normal(0.0, sigma)
|
val generator = Distribution.normal(0.0, sigma)
|
||||||
val chain = generator.sample(RandomGenerator.default(112667))
|
val chain = generator.sample(RandomGenerator.default(112667))
|
||||||
val x = (1..100).map { it.toDouble() }
|
val x = (1..100).map(Int::toDouble)
|
||||||
val y = x.map { it ->
|
|
||||||
|
val y = x.map {
|
||||||
it.pow(2) + it + 1 + chain.nextDouble()
|
it.pow(2) + it + 1 + chain.nextDouble()
|
||||||
}
|
}
|
||||||
val yErr = x.map { sigma }
|
|
||||||
val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
val yErr = List(x.size) { sigma }
|
||||||
|
|
||||||
|
val chi2 = Fitting.chiSquared(x, y, yErr) { x1 ->
|
||||||
val cWithDefault = bindOrNull(c) ?: one
|
val cWithDefault = bindOrNull(c) ?: one
|
||||||
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
||||||
|
@ -12,7 +12,7 @@ The core features of KMath:
|
|||||||
|
|
||||||
> #### Artifact:
|
> #### Artifact:
|
||||||
>
|
>
|
||||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
|
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-3`.
|
||||||
>
|
>
|
||||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
||||||
>
|
>
|
||||||
@ -30,7 +30,7 @@ The core features of KMath:
|
|||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
|
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-3'
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
> **Gradle Kotlin DSL:**
|
> **Gradle Kotlin DSL:**
|
||||||
@ -44,6 +44,6 @@ The core features of KMath:
|
|||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
|
> implementation("kscience.kmath:kmath-core:0.2.0-dev-3")
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import ru.mipt.npm.gradle.Maturity
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("ru.mipt.npm.mpp")
|
id("ru.mipt.npm.mpp")
|
||||||
id("ru.mipt.npm.native")
|
id("ru.mipt.npm.native")
|
||||||
@ -11,33 +13,39 @@ kotlin.sourceSets.commonMain {
|
|||||||
|
|
||||||
readme {
|
readme {
|
||||||
description = "Core classes, algebra definitions, basic linear algebra"
|
description = "Core classes, algebra definitions, basic linear algebra"
|
||||||
maturity = ru.mipt.npm.gradle.Maturity.DEVELOPMENT
|
maturity = Maturity.DEVELOPMENT
|
||||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "algebras",
|
id = "algebras",
|
||||||
description = "Algebraic structures: contexts and elements",
|
description = "Algebraic structures: contexts and elements",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt"
|
ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt"
|
||||||
)
|
)
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "nd",
|
id = "nd",
|
||||||
description = "Many-dimensional structures",
|
description = "Many-dimensional structures",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt"
|
ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt"
|
||||||
)
|
)
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "buffers",
|
id = "buffers",
|
||||||
description = "One-dimensional structure",
|
description = "One-dimensional structure",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt"
|
ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt"
|
||||||
)
|
)
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "expressions",
|
id = "expressions",
|
||||||
description = "Functional Expressions",
|
description = "Functional Expressions",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/expressions"
|
ref = "src/commonMain/kotlin/kscience/kmath/expressions"
|
||||||
)
|
)
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "domains",
|
id = "domains",
|
||||||
description = "Domains",
|
description = "Domains",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/domains"
|
ref = "src/commonMain/kotlin/kscience/kmath/domains"
|
||||||
)
|
)
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "autodif",
|
id = "autodif",
|
||||||
description = "Automatic differentiation",
|
description = "Automatic differentiation",
|
||||||
|
@ -1,32 +1,41 @@
|
|||||||
package kscience.kmath.expressions
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An expression that provides derivatives
|
* Represents expression which structure can be differentiated.
|
||||||
|
*
|
||||||
|
* @param T the type this expression takes as argument and returns.
|
||||||
|
* @param R the type of expression this expression can be differentiated to.
|
||||||
*/
|
*/
|
||||||
public interface DifferentiableExpression<T> : Expression<T>{
|
public interface DifferentiableExpression<T, out R : Expression<T>> : Expression<T> {
|
||||||
public fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>?
|
/**
|
||||||
|
* Differentiates this expression by ordered collection of [symbols].
|
||||||
|
*
|
||||||
|
* @param symbols the symbols.
|
||||||
|
* @return the derivative or `null`.
|
||||||
|
*/
|
||||||
|
public fun derivativeOrNull(symbols: List<Symbol>): R?
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(orders: Map<Symbol, Int>): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
|
||||||
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
|
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
|
||||||
derivative(mapOf(*orders))
|
derivative(symbols.toList())
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R =
|
||||||
|
derivative(StringSymbol(name))
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
|
||||||
derivative(StringSymbol(name) to 1)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A [DifferentiableExpression] that defines only first derivatives
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
*/
|
*/
|
||||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T,R> {
|
||||||
|
/**
|
||||||
|
* Returns first derivative of this expression by given [symbol].
|
||||||
|
*/
|
||||||
|
public abstract fun derivativeOrNull(symbol: Symbol): R?
|
||||||
|
|
||||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
public final override fun derivativeOrNull(symbols: List<Symbol>): R? {
|
||||||
|
val dSymbol = symbols.firstOrNull() ?: return null
|
||||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? {
|
|
||||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
|
||||||
return derivativeOrNull(dSymbol)
|
return derivativeOrNull(dSymbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -34,6 +43,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
|||||||
/**
|
/**
|
||||||
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||||
*/
|
*/
|
||||||
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
|
||||||
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
public fun process(function: A.() -> I): DifferentiableExpression<T, R>
|
||||||
}
|
}
|
@ -3,6 +3,7 @@ package kscience.kmath.expressions
|
|||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
|
import kotlin.reflect.KProperty
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A marker interface for a symbol. A symbol mus have an identity
|
* A marker interface for a symbol. A symbol mus have an identity
|
||||||
@ -12,6 +13,13 @@ public interface Symbol {
|
|||||||
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
|
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
|
||||||
*/
|
*/
|
||||||
public val identity: String
|
public val identity: String
|
||||||
|
|
||||||
|
public companion object : ReadOnlyProperty<Any?, Symbol> {
|
||||||
|
//TODO deprecate and replace by top level function after fix of https://youtrack.jetbrains.com/issue/KT-40121
|
||||||
|
override fun getValue(thisRef: Any?, property: KProperty<*>): Symbol {
|
||||||
|
return StringSymbol(property.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -22,7 +30,9 @@ public inline class StringSymbol(override val identity: String) : Symbol {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments.
|
||||||
|
*
|
||||||
|
* @param T the type this expression takes as argument and returns.
|
||||||
*/
|
*/
|
||||||
public fun interface Expression<T> {
|
public fun interface Expression<T> {
|
||||||
/**
|
/**
|
||||||
@ -35,20 +45,27 @@ public fun interface Expression<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Invoke an expression without parameters
|
* Calls this expression without providing any arguments.
|
||||||
|
*
|
||||||
|
* @return a value.
|
||||||
*/
|
*/
|
||||||
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||||
//This method exists to avoid resolution ambiguity of vararg methods
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression from arguments.
|
* Calls this expression from arguments.
|
||||||
*
|
*
|
||||||
* @param pairs the pair of arguments' names to values.
|
* @param pairs the pairs of arguments to values.
|
||||||
* @return the value.
|
* @return a value.
|
||||||
*/
|
*/
|
||||||
@JvmName("callBySymbol")
|
@JvmName("callBySymbol")
|
||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param pairs the pairs of arguments' names to values.
|
||||||
|
* @return a value.
|
||||||
|
*/
|
||||||
@JvmName("callByString")
|
@JvmName("callByString")
|
||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||||
@ -61,7 +78,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
|||||||
* @param E type of the actual expression state
|
* @param E type of the actual expression state
|
||||||
*/
|
*/
|
||||||
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
|
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
|
||||||
*/
|
*/
|
||||||
@ -87,9 +103,9 @@ public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
|
|||||||
/**
|
/**
|
||||||
* A delegate to create a symbol with a string identity in this scope
|
* A delegate to create a symbol with a string identity in this scope
|
||||||
*/
|
*/
|
||||||
public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty { thisRef, property ->
|
public val symbol: ReadOnlyProperty<Any?, Symbol> get() = Symbol
|
||||||
StringSymbol(property.name)
|
//TODO does not work directly on native due to https://youtrack.jetbrains.com/issue/KT-40121
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bind a symbol by name inside the [ExpressionAlgebra]
|
* Bind a symbol by name inside the [ExpressionAlgebra]
|
||||||
|
@ -68,7 +68,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
|||||||
): DerivationResult<T> {
|
): DerivationResult<T> {
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
return SimpleAutoDiffField(this, bindings).derivate(body)
|
return SimpleAutoDiffField(this, bindings).differentiate(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
@ -83,12 +83,21 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
public val context: F,
|
public val context: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||||
|
public override val zero: AutoDiffValue<T>
|
||||||
|
get() = const(context.zero)
|
||||||
|
|
||||||
|
public override val one: AutoDiffValue<T>
|
||||||
|
get() = const(context.one)
|
||||||
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
private var sp: Int = 0
|
private var sp: Int = 0
|
||||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||||
|
|
||||||
|
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||||
|
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
@ -106,11 +115,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
override fun hashCode(): Int = identity.hashCode()
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
public override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||||
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
|
||||||
|
|
||||||
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||||
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||||
@ -119,7 +124,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
private fun runBackwardPass() {
|
private fun runBackwardPass() {
|
||||||
while (sp > 0) {
|
while (sp > 0) {
|
||||||
@ -129,9 +133,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
|
||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
|
||||||
|
|
||||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -165,7 +166,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||||
val result = function()
|
val result = function()
|
||||||
result.d = context.one // computing derivative w.r.t result
|
result.d = context.one // computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
@ -174,41 +175,41 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||||
b.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
||||||
|
|
||||||
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||||
|
|
||||||
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value + b.value }) { z ->
|
derive(const { a.value + b.value }) { z ->
|
||||||
a.d += z.d
|
a.d += z.d
|
||||||
b.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value * b.value }) { z ->
|
derive(const { a.value * b.value }) { z ->
|
||||||
a.d += z.d * b.value
|
a.d += z.d * b.value
|
||||||
b.d += z.d * a.value
|
b.d += z.d * a.value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value / b.value }) { z ->
|
derive(const { a.value / b.value }) { z ->
|
||||||
a.d += z.d / b.value
|
a.d += z.d / b.value
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
||||||
derive(const { k.toDouble() * a.value }) { z ->
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
a.d += z.d * k.toDouble()
|
a.d += z.d * k.toDouble()
|
||||||
}
|
}
|
||||||
@ -220,15 +221,15 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : FirstDerivativeExpression<T>() {
|
) : FirstDerivativeExpression<T, Expression<T>>() {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return SimpleAutoDiffField(field, arguments).function().value
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
|
||||||
derivationResult.derivative(symbol)
|
derivationResult.derivative(symbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -236,13 +237,10 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
|||||||
/**
|
/**
|
||||||
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
|
||||||
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
AutoDiffProcessor { function ->
|
||||||
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
SimpleAutoDiffExpression(field, function)
|
||||||
return SimpleAutoDiffExpression(field, function)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
|
@ -74,9 +74,9 @@ public interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : Math
|
|||||||
/**
|
/**
|
||||||
* The element of [Ring].
|
* The element of [Ring].
|
||||||
*
|
*
|
||||||
* @param T the type of space operation results.
|
* @param T the type of ring operation results.
|
||||||
* @param I self type of the element. Needed for static type checking.
|
* @param I self type of the element. Needed for static type checking.
|
||||||
* @param R the type of space.
|
* @param R the type of ring.
|
||||||
*/
|
*/
|
||||||
public interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
public interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
||||||
/**
|
/**
|
||||||
@ -91,7 +91,7 @@ public interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceEl
|
|||||||
/**
|
/**
|
||||||
* The element of [Field].
|
* The element of [Field].
|
||||||
*
|
*
|
||||||
* @param T the type of space operation results.
|
* @param T the type of field operation results.
|
||||||
* @param I self type of the element. Needed for static type checking.
|
* @param I self type of the element. Needed for static type checking.
|
||||||
* @param F the type of field.
|
* @param F the type of field.
|
||||||
*/
|
*/
|
||||||
|
@ -73,7 +73,7 @@ public interface NDAlgebra<T, C, N : NDStructure<T>> {
|
|||||||
public fun check(vararg elements: N): Array<out N> = elements
|
public fun check(vararg elements: N): Array<out N> = elements
|
||||||
.map(NDStructure<T>::shape)
|
.map(NDStructure<T>::shape)
|
||||||
.singleOrNull { !shape.contentEquals(it) }
|
.singleOrNull { !shape.contentEquals(it) }
|
||||||
?.let { throw ShapeMismatchException(shape, it) }
|
?.let<IntArray, Array<out N>> { throw ShapeMismatchException(shape, it) }
|
||||||
?: elements
|
?: elements
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5,19 +5,19 @@ package kscience.kmath.structures
|
|||||||
*
|
*
|
||||||
* @property array the underlying array.
|
* @property array the underlying array.
|
||||||
*/
|
*/
|
||||||
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
public inline class RealBuffer(public val array: DoubleArray) : MutableBuffer<Double> {
|
public inline class RealBuffer(public val array: DoubleArray) : MutableBuffer<Double> {
|
||||||
override val size: Int get() = array.size
|
override val size: Int get() = array.size
|
||||||
|
|
||||||
override operator fun get(index: Int): Double = array[index]
|
override inline operator fun get(index: Int): Double = array[index]
|
||||||
|
|
||||||
override operator fun set(index: Int, value: Double) {
|
override inline operator fun set(index: Int, value: Double) {
|
||||||
array[index] = value
|
array[index] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun iterator(): DoubleIterator = array.iterator()
|
override operator fun iterator(): DoubleIterator = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Double> =
|
override fun copy(): RealBuffer = RealBuffer(array.copyOf())
|
||||||
RealBuffer(array.copyOf())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -34,6 +34,11 @@ public inline fun RealBuffer(size: Int, init: (Int) -> Double): RealBuffer = Rea
|
|||||||
*/
|
*/
|
||||||
public fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
public fun RealBuffer(vararg doubles: Double): RealBuffer = RealBuffer(doubles)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simplified [RealBuffer] to array comparison
|
||||||
|
*/
|
||||||
|
public fun RealBuffer.contentEquals(vararg doubles: Double): Boolean = array.contentEquals(doubles)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
|
* Returns a [DoubleArray] containing all of the elements of this [MutableBuffer].
|
||||||
*/
|
*/
|
||||||
|
@ -14,8 +14,9 @@ import kotlin.math.sqrt
|
|||||||
|
|
||||||
public typealias RealPoint = Point<Double>
|
public typealias RealPoint = Point<Double>
|
||||||
|
|
||||||
public fun DoubleArray.asVector(): RealVector = RealVector(asBuffer())
|
public fun RealPoint.asVector(): RealVector = RealVector(this)
|
||||||
public fun List<Double>.asVector(): RealVector = RealVector(asBuffer())
|
public fun DoubleArray.asVector(): RealVector = asBuffer().asVector()
|
||||||
|
public fun List<Double>.asVector(): RealVector = asBuffer().asVector()
|
||||||
|
|
||||||
public object VectorL2Norm : Norm<Point<out Number>, Double> {
|
public object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||||
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble(Number::toDouble))
|
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble(Number::toDouble))
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package kscience.kmath.misc
|
package kscience.kmath.real
|
||||||
|
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -32,6 +34,9 @@ public fun ClosedFloatingPointRange<Double>.toSequenceWithStep(step: Double): Se
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public infix fun ClosedFloatingPointRange<Double>.step(step: Double): Point<Double> =
|
||||||
|
toSequenceWithStep(step).toList().asBuffer()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert double range to sequence with the fixed number of points
|
* Convert double range to sequence with the fixed number of points
|
||||||
*/
|
*/
|
||||||
@ -39,12 +44,3 @@ public fun ClosedFloatingPointRange<Double>.toSequenceWithPoints(numPoints: Int)
|
|||||||
require(numPoints > 1) { "The number of points should be more than 2" }
|
require(numPoints > 1) { "The number of points should be more than 2" }
|
||||||
return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1))
|
return toSequenceWithStep(abs(endInclusive - start) / (numPoints - 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
|
|
||||||
*/
|
|
||||||
@Deprecated("Replace by 'toSequenceWithPoints'")
|
|
||||||
public fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
|
||||||
require(numPoints >= 2) { "Can't create generic grid with less than two points" }
|
|
||||||
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
|
||||||
}
|
|
@ -1,8 +0,0 @@
|
|||||||
package kscience.kmath.real
|
|
||||||
|
|
||||||
import kscience.kmath.structures.RealBuffer
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Simplified [RealBuffer] to array comparison
|
|
||||||
*/
|
|
||||||
public fun RealBuffer.contentEquals(vararg doubles: Double): Boolean = array.contentEquals(doubles)
|
|
@ -0,0 +1,13 @@
|
|||||||
|
package kaceince.kmath.real
|
||||||
|
|
||||||
|
import kscience.kmath.real.step
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class GridTest {
|
||||||
|
@Test
|
||||||
|
fun testStepGrid(){
|
||||||
|
val grid = 0.0..1.0 step 0.2
|
||||||
|
assertEquals(6, grid.size)
|
||||||
|
}
|
||||||
|
}
|
@ -1,9 +1,10 @@
|
|||||||
package scientific.kmath.real
|
package kaceince.kmath.real
|
||||||
|
|
||||||
import kscience.kmath.linear.VirtualMatrix
|
import kscience.kmath.linear.VirtualMatrix
|
||||||
import kscience.kmath.linear.build
|
import kscience.kmath.linear.build
|
||||||
import kscience.kmath.real.*
|
import kscience.kmath.real.*
|
||||||
import kscience.kmath.structures.Matrix
|
import kscience.kmath.structures.Matrix
|
||||||
|
import kscience.kmath.structures.contentEquals
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
@ -1,11 +1,14 @@
|
|||||||
package kscience.kmath.linear
|
package kaceince.kmath.real
|
||||||
|
|
||||||
|
import kscience.kmath.linear.MatrixContext
|
||||||
|
import kscience.kmath.linear.asMatrix
|
||||||
|
import kscience.kmath.linear.transpose
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kscience.kmath.real.RealVector
|
import kscience.kmath.real.RealVector
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class VectorTest {
|
internal class RealVectorTest {
|
||||||
@Test
|
@Test
|
||||||
fun testSum() {
|
fun testSum() {
|
||||||
val vector1 = RealVector(5) { it.toDouble() }
|
val vector1 = RealVector(5) { it.toDouble() }
|
9
kmath-kotlingrad/build.gradle.kts
Normal file
9
kmath-kotlingrad/build.gradle.kts
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("com.github.breandan:kaliningraph:0.1.2")
|
||||||
|
implementation("com.github.breandan:kotlingrad:0.3.7")
|
||||||
|
api(project(":kmath-ast"))
|
||||||
|
}
|
@ -0,0 +1,53 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.SFun
|
||||||
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExpression
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents wrapper of [MstExpression] implementing [DifferentiableExpression].
|
||||||
|
*
|
||||||
|
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting
|
||||||
|
* [SFun] back to [MST].
|
||||||
|
*
|
||||||
|
* @param T the type of number.
|
||||||
|
* @param A the [NumericAlgebra] of [T].
|
||||||
|
* @property expr the underlying [MstExpression].
|
||||||
|
*/
|
||||||
|
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
|
||||||
|
DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T>, T : Number {
|
||||||
|
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The [MstExpression.algebra] of [expr].
|
||||||
|
*/
|
||||||
|
public val algebra: A
|
||||||
|
get() = expr.algebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The [MstExpression.mst] of [expr].
|
||||||
|
*/
|
||||||
|
public val mst: MST
|
||||||
|
get() = expr.mst
|
||||||
|
|
||||||
|
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||||
|
|
||||||
|
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
|
||||||
|
algebra,
|
||||||
|
symbols.map(Symbol::identity)
|
||||||
|
.map(MstAlgebra::symbol)
|
||||||
|
.map { it.toSVar<KMathNumber<T, A>>() }
|
||||||
|
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
||||||
|
.toMst(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
||||||
|
*/
|
||||||
|
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
|
||||||
|
DifferentiableMstExpression(this)
|
@ -0,0 +1,18 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.RealNumber
|
||||||
|
import edu.umontreal.kotlingrad.experimental.SConst
|
||||||
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implements [RealNumber] by delegating its functionality to [NumericAlgebra].
|
||||||
|
*
|
||||||
|
* @param T the type of number.
|
||||||
|
* @param A the [NumericAlgebra] of [T].
|
||||||
|
* @property algebra the algebra.
|
||||||
|
* @param value the value of this number.
|
||||||
|
*/
|
||||||
|
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||||
|
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||||
|
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||||
|
}
|
@ -0,0 +1,124 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExtendedField
|
||||||
|
import kscience.kmath.ast.MstExtendedField.unaryMinus
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [SVar] to [MST.Symbolic] directly.
|
||||||
|
*
|
||||||
|
* @receiver the variable.
|
||||||
|
* @return a node.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> SVar<X>.toMst(): MST.Symbolic = MstAlgebra.symbol(name)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [SVar] to [MST.Numeric] directly.
|
||||||
|
*
|
||||||
|
* @receiver the constant.
|
||||||
|
* @return a node.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> SConst<X>.toMst(): MST.Numeric = MstAlgebra.number(doubleValue)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
|
||||||
|
* [Power] operation is limited to constant right-hand side arguments.
|
||||||
|
*
|
||||||
|
* Detailed mapping is:
|
||||||
|
*
|
||||||
|
* - [SVar] -> [MstExtendedField.symbol];
|
||||||
|
* - [SConst] -> [MstExtendedField.number];
|
||||||
|
* - [Sum] -> [MstExtendedField.add];
|
||||||
|
* - [Prod] -> [MstExtendedField.multiply];
|
||||||
|
* - [Power] -> [MstExtendedField.power] (limited to constant exponents only);
|
||||||
|
* - [Negative] -> [MstExtendedField.unaryMinus];
|
||||||
|
* - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right);
|
||||||
|
* - [Sine] -> [MstExtendedField.sin];
|
||||||
|
* - [Cosine] -> [MstExtendedField.cos];
|
||||||
|
* - [Tangent] -> [MstExtendedField.tan];
|
||||||
|
* - [DProd] is vector operation, and it is requested to be evaluated;
|
||||||
|
* - [SComposition] is also requested to be evaluated eagerly;
|
||||||
|
* - [VSumAll] is requested to be evaluated;
|
||||||
|
* - [Derivative] is requested to be evaluated.
|
||||||
|
*
|
||||||
|
* @receiver the scalar function.
|
||||||
|
* @return a node.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
|
||||||
|
when (this@toMst) {
|
||||||
|
is SVar -> toMst()
|
||||||
|
is SConst -> toMst()
|
||||||
|
is Sum -> left.toMst() + right.toMst()
|
||||||
|
is Prod -> left.toMst() * right.toMst()
|
||||||
|
is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue)
|
||||||
|
is Negative -> -input.toMst()
|
||||||
|
is Log -> ln(left.toMst()) / ln(right.toMst())
|
||||||
|
is Sine -> sin(input.toMst())
|
||||||
|
is Cosine -> cos(input.toMst())
|
||||||
|
is Tangent -> tan(input.toMst())
|
||||||
|
is DProd -> this@toMst().toMst()
|
||||||
|
is SComposition -> this@toMst().toMst()
|
||||||
|
is VSumAll<X, *> -> this@toMst().toMst()
|
||||||
|
is Derivative -> this@toMst().toMst()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [MST.Numeric] to [SConst] directly.
|
||||||
|
*
|
||||||
|
* @receiver the node.
|
||||||
|
* @return a new constant.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [MST.Symbolic] to [SVar] directly.
|
||||||
|
*
|
||||||
|
* @receiver the node.
|
||||||
|
* @param proto the prototype instance.
|
||||||
|
* @return a new variable.
|
||||||
|
*/
|
||||||
|
internal fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||||
|
*
|
||||||
|
* Detailed mapping is:
|
||||||
|
*
|
||||||
|
* - [MST.Numeric] -> [SConst];
|
||||||
|
* - [MST.Symbolic] -> [SVar];
|
||||||
|
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
|
||||||
|
* - [MST.Binary] -> [Sum], [Prod], [Power].
|
||||||
|
*
|
||||||
|
* @receiver the node.
|
||||||
|
* @param proto the prototype instance.
|
||||||
|
* @return a scalar function.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||||
|
is MST.Numeric -> toSConst()
|
||||||
|
is MST.Symbolic -> toSVar()
|
||||||
|
|
||||||
|
is MST.Unary -> when (operation) {
|
||||||
|
SpaceOperations.PLUS_OPERATION -> +value.toSFun<X>()
|
||||||
|
SpaceOperations.MINUS_OPERATION -> -value.toSFun<X>()
|
||||||
|
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||||
|
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||||
|
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||||
|
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
|
||||||
|
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||||
|
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||||
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Binary -> when (operation) {
|
||||||
|
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||||
|
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||||
|
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||||
|
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||||
|
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
||||||
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,64 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
|
import kscience.kmath.asm.compile
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExpression
|
||||||
|
import kscience.kmath.ast.parseMath
|
||||||
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
internal class AdaptingTests {
|
||||||
|
@Test
|
||||||
|
fun symbol() {
|
||||||
|
val c1 = MstAlgebra.symbol("x")
|
||||||
|
assertTrue(c1.toSVar<KMathNumber<Double, RealField>>().name == "x")
|
||||||
|
val c2 = "kitten".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun number() {
|
||||||
|
val c1 = MstAlgebra.number(12354324)
|
||||||
|
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
|
||||||
|
val c2 = "0.234".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||||
|
val c3 = "1e-3".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun simpleFunctionShape() {
|
||||||
|
val linear = "2*x+16".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (linear !is Sum) fail()
|
||||||
|
if (linear.left !is Prod) fail()
|
||||||
|
if (linear.right !is SConst) fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun simpleFunctionDerivative() {
|
||||||
|
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
|
||||||
|
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||||
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
|
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun moreComplexDerivative() {
|
||||||
|
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
|
||||||
|
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
|
||||||
|
|
||||||
|
val expectedDerivative = MstExpression(
|
||||||
|
RealField,
|
||||||
|
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
|
||||||
|
).compile()
|
||||||
|
|
||||||
|
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
|
||||||
|
}
|
||||||
|
}
|
82
kmath-nd4j/README.md
Normal file
82
kmath-nd4j/README.md
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
# ND4J NDStructure implementation (`kmath-nd4j`)
|
||||||
|
|
||||||
|
This subproject implements the following features:
|
||||||
|
|
||||||
|
- [nd4jarraystrucure](src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt) : NDStructure wrapper for INDArray
|
||||||
|
- [nd4jarrayrings](src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt) : Rings over Nd4jArrayStructure of Int and Long
|
||||||
|
- [nd4jarrayfields](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : Fields over Nd4jArrayStructure of Float and Double
|
||||||
|
|
||||||
|
|
||||||
|
> #### Artifact:
|
||||||
|
>
|
||||||
|
> This module artifact: `kscience.kmath:kmath-nd4j:0.2.0-dev-3`.
|
||||||
|
>
|
||||||
|
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-nd4j/_latestVersion)
|
||||||
|
>
|
||||||
|
> Bintray development version: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-nd4j/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-nd4j/_latestVersion)
|
||||||
|
>
|
||||||
|
> **Gradle:**
|
||||||
|
>
|
||||||
|
> ```gradle
|
||||||
|
> repositories {
|
||||||
|
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||||
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
|
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||||
|
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {
|
||||||
|
> implementation 'kscience.kmath:kmath-nd4j:0.2.0-dev-3'
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
> **Gradle Kotlin DSL:**
|
||||||
|
>
|
||||||
|
> ```kotlin
|
||||||
|
> repositories {
|
||||||
|
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
> }
|
||||||
|
>
|
||||||
|
> dependencies {
|
||||||
|
> implementation("kscience.kmath:kmath-nd4j:0.2.0-dev-3")
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
NDStructure wrapper for INDArray:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
import org.nd4j.linalg.factory.*
|
||||||
|
import scientifik.kmath.nd4j.*
|
||||||
|
import scientifik.kmath.structures.*
|
||||||
|
|
||||||
|
val array = Nd4j.ones(2, 2).asRealStructure()
|
||||||
|
println(array[0, 0]) // 1.0
|
||||||
|
array[intArrayOf(0, 0)] = 24.0
|
||||||
|
println(array[0, 0]) // 24.0
|
||||||
|
```
|
||||||
|
|
||||||
|
Fast element-wise and in-place arithmetics for INDArray:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
import org.nd4j.linalg.factory.*
|
||||||
|
import scientifik.kmath.nd4j.*
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
val field = RealNd4jArrayField(intArrayOf(2, 2))
|
||||||
|
val array = Nd4j.rand(2, 2).asRealStructure()
|
||||||
|
|
||||||
|
val res = field {
|
||||||
|
(25.0 / array + 20) * 4
|
||||||
|
}
|
||||||
|
|
||||||
|
println(res.ndArray)
|
||||||
|
// [[ 250.6449, 428.5840],
|
||||||
|
// [ 269.7913, 202.2077]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).
|
37
kmath-nd4j/build.gradle.kts
Normal file
37
kmath-nd4j/build.gradle.kts
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import ru.mipt.npm.gradle.Maturity
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
api("org.nd4j:nd4j-api:1.0.0-beta7")
|
||||||
|
testImplementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
|
||||||
|
testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||||
|
testImplementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
|
}
|
||||||
|
|
||||||
|
readme {
|
||||||
|
description = "ND4J NDStructure implementation and according NDAlgebra classes"
|
||||||
|
maturity = Maturity.EXPERIMENTAL
|
||||||
|
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||||
|
|
||||||
|
feature(
|
||||||
|
id = "nd4jarraystructure",
|
||||||
|
description = "NDStructure wrapper for INDArray",
|
||||||
|
ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt"
|
||||||
|
)
|
||||||
|
|
||||||
|
feature(
|
||||||
|
id = "nd4jarrayrings",
|
||||||
|
description = "Rings over Nd4jArrayStructure of Int and Long",
|
||||||
|
ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt"
|
||||||
|
)
|
||||||
|
|
||||||
|
feature(
|
||||||
|
id = "nd4jarrayfields",
|
||||||
|
description = "Fields over Nd4jArrayStructure of Float and Double",
|
||||||
|
ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt"
|
||||||
|
)
|
||||||
|
}
|
43
kmath-nd4j/docs/README-TEMPLATE.md
Normal file
43
kmath-nd4j/docs/README-TEMPLATE.md
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# ND4J NDStructure implementation (`kmath-nd4j`)
|
||||||
|
|
||||||
|
This subproject implements the following features:
|
||||||
|
|
||||||
|
${features}
|
||||||
|
|
||||||
|
${artifact}
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
NDStructure wrapper for INDArray:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
import org.nd4j.linalg.factory.*
|
||||||
|
import scientifik.kmath.nd4j.*
|
||||||
|
import scientifik.kmath.structures.*
|
||||||
|
|
||||||
|
val array = Nd4j.ones(2, 2).asRealStructure()
|
||||||
|
println(array[0, 0]) // 1.0
|
||||||
|
array[intArrayOf(0, 0)] = 24.0
|
||||||
|
println(array[0, 0]) // 24.0
|
||||||
|
```
|
||||||
|
|
||||||
|
Fast element-wise and in-place arithmetics for INDArray:
|
||||||
|
|
||||||
|
```kotlin
|
||||||
|
import org.nd4j.linalg.factory.*
|
||||||
|
import scientifik.kmath.nd4j.*
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
val field = RealNd4jArrayField(intArrayOf(2, 2))
|
||||||
|
val array = Nd4j.rand(2, 2).asRealStructure()
|
||||||
|
|
||||||
|
val res = field {
|
||||||
|
(25.0 / array + 20) * 4
|
||||||
|
}
|
||||||
|
|
||||||
|
println(res.ndArray)
|
||||||
|
// [[ 250.6449, 428.5840],
|
||||||
|
// [ 269.7913, 202.2077]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis).
|
@ -0,0 +1,349 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
import kscience.kmath.structures.NDAlgebra
|
||||||
|
import kscience.kmath.structures.NDField
|
||||||
|
import kscience.kmath.structures.NDRing
|
||||||
|
import kscience.kmath.structures.NDSpace
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDAlgebra] over [Nd4jArrayAlgebra].
|
||||||
|
*
|
||||||
|
* @param T the type of ND-structure element.
|
||||||
|
* @param C the type of the element context.
|
||||||
|
*/
|
||||||
|
public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>> {
|
||||||
|
/**
|
||||||
|
* Wraps [INDArray] to [N].
|
||||||
|
*/
|
||||||
|
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||||
|
|
||||||
|
public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||||
|
val struct = Nd4j.create(*shape)!!.wrap()
|
||||||
|
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
|
||||||
|
return struct
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun map(arg: Nd4jArrayStructure<T>, transform: C.(T) -> T): Nd4jArrayStructure<T> {
|
||||||
|
check(arg)
|
||||||
|
val newStruct = arg.ndArray.dup().wrap()
|
||||||
|
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
||||||
|
return newStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun mapIndexed(
|
||||||
|
arg: Nd4jArrayStructure<T>,
|
||||||
|
transform: C.(index: IntArray, T) -> T
|
||||||
|
): Nd4jArrayStructure<T> {
|
||||||
|
check(arg)
|
||||||
|
val new = Nd4j.create(*shape).wrap()
|
||||||
|
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun combine(
|
||||||
|
a: Nd4jArrayStructure<T>,
|
||||||
|
b: Nd4jArrayStructure<T>,
|
||||||
|
transform: C.(T, T) -> T
|
||||||
|
): Nd4jArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
val new = Nd4j.create(*shape).wrap()
|
||||||
|
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDSpace] over [Nd4jArrayStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param S the type of space of structure elements.
|
||||||
|
*/
|
||||||
|
public interface Nd4jArraySpace<T, S> : NDSpace<T, S, Nd4jArrayStructure<T>>,
|
||||||
|
Nd4jArrayAlgebra<T, S> where S : Space<T> {
|
||||||
|
public override val zero: Nd4jArrayStructure<T>
|
||||||
|
get() = Nd4j.zeros(*shape).wrap()
|
||||||
|
|
||||||
|
public override fun add(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
return a.ndArray.add(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<T>.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||||
|
check(this, b)
|
||||||
|
return ndArray.sub(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<T>.unaryMinus(): Nd4jArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.neg().wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun multiply(a: Nd4jArrayStructure<T>, k: Number): Nd4jArrayStructure<T> {
|
||||||
|
check(a)
|
||||||
|
return a.ndArray.mul(k).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<T>.div(k: Number): Nd4jArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.div(k).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<T>.times(k: Number): Nd4jArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(k).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDRing] over [Nd4jArrayStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param R the type of ring of structure elements.
|
||||||
|
*/
|
||||||
|
public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4jArraySpace<T, R> where R : Ring<T> {
|
||||||
|
public override val one: Nd4jArrayStructure<T>
|
||||||
|
get() = Nd4j.ones(*shape).wrap()
|
||||||
|
|
||||||
|
public override fun multiply(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
return a.ndArray.mul(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(b).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<T>.plus(b: Number): Nd4jArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(b).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Number.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||||
|
check(b)
|
||||||
|
return b.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
private val longNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, LongNd4jArrayRing>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDRing] for [Int] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
|
||||||
|
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDRing] for [Long] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
|
||||||
|
longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a most suitable implementation of [NDRing] using reified class.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, out Ring<T>> = when {
|
||||||
|
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, out Ring<T>>
|
||||||
|
T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, out Ring<T>>
|
||||||
|
else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDField] over [Nd4jArrayStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param N the type of ND structure.
|
||||||
|
* @param F the type field of structure elements.
|
||||||
|
*/
|
||||||
|
public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd4jArrayRing<T, F> where F : Field<T> {
|
||||||
|
public override fun divide(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
return a.ndArray.div(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Number.div(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||||
|
check(b)
|
||||||
|
return b.ndArray.rdiv(this).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
private val realNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, RealNd4jArrayField>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDField] for [Float] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
|
||||||
|
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDField] for [Double] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun real(vararg shape: Int): Nd4jArrayRing<Double, RealField> =
|
||||||
|
realNd4jArrayFieldCache.get().getOrPut(shape) { RealNd4jArrayField(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a most suitable implementation of [NDRing] using reified class.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, out Field<T>> = when {
|
||||||
|
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, out Field<T>>
|
||||||
|
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, out Field<T>>
|
||||||
|
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDField] over [Nd4jArrayRealStructure].
|
||||||
|
*/
|
||||||
|
public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, RealField> {
|
||||||
|
public override val elementContext: RealField
|
||||||
|
get() = RealField
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = check(asRealStructure())
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.div(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Double>.plus(arg: Double): Nd4jArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Double>.minus(arg: Double): Nd4jArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Double>.times(arg: Double): Nd4jArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Double.div(arg: Nd4jArrayStructure<Double>): Nd4jArrayStructure<Double> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rdiv(this).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Double.minus(arg: Nd4jArrayStructure<Double>): Nd4jArrayStructure<Double> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDField] over [Nd4jArrayStructure] of [Float].
|
||||||
|
*/
|
||||||
|
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField> {
|
||||||
|
public override val elementContext: FloatField
|
||||||
|
get() = FloatField
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = check(asFloatStructure())
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.div(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Float>.plus(arg: Float): Nd4jArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Float>.minus(arg: Float): Nd4jArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Float>.times(arg: Float): Nd4jArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Float.div(arg: Nd4jArrayStructure<Float>): Nd4jArrayStructure<Float> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rdiv(this).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Float.minus(arg: Nd4jArrayStructure<Float>): Nd4jArrayStructure<Float> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDRing] over [Nd4jArrayIntStructure].
|
||||||
|
*/
|
||||||
|
public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> {
|
||||||
|
public override val elementContext: IntRing
|
||||||
|
get() = IntRing
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = check(asIntStructure())
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Int>.minus(arg: Int): Nd4jArrayStructure<Int> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Int>.times(arg: Int): Nd4jArrayStructure<Int> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Int.minus(arg: Nd4jArrayStructure<Int>): Nd4jArrayStructure<Int> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDRing] over [Nd4jArrayStructure] of [Long].
|
||||||
|
*/
|
||||||
|
public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Long, LongRing> {
|
||||||
|
public override val elementContext: LongRing
|
||||||
|
get() = LongRing
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = check(asLongStructure())
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Long>.minus(arg: Long): Nd4jArrayStructure<Long> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Nd4jArrayStructure<Long>.times(arg: Long): Nd4jArrayStructure<Long> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Long.minus(arg: Nd4jArrayStructure<Long>): Nd4jArrayStructure<Long> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,62 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import org.nd4j.linalg.api.shape.Shape
|
||||||
|
|
||||||
|
private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
|
||||||
|
private var i: Int = 0
|
||||||
|
|
||||||
|
override fun hasNext(): Boolean = i < iterateOver.length()
|
||||||
|
|
||||||
|
override fun next(): IntArray {
|
||||||
|
val la = if (iterateOver.ordering() == 'c')
|
||||||
|
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||||
|
else
|
||||||
|
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||||
|
|
||||||
|
return la.toIntArray()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun INDArray.indicesIterator(): Iterator<IntArray> = Nd4jArrayIndicesIterator(this)
|
||||||
|
|
||||||
|
private sealed class Nd4jArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
|
||||||
|
private var i: Int = 0
|
||||||
|
|
||||||
|
final override fun hasNext(): Boolean = i < iterateOver.length()
|
||||||
|
|
||||||
|
abstract fun getSingle(indices: LongArray): T
|
||||||
|
|
||||||
|
final override fun next(): Pair<IntArray, T> {
|
||||||
|
val la = if (iterateOver.ordering() == 'c')
|
||||||
|
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||||
|
else
|
||||||
|
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||||
|
|
||||||
|
return la.toIntArray() to getSingle(la)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class Nd4jArrayRealIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Double>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = Nd4jArrayRealIterator(this)
|
||||||
|
|
||||||
|
private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Long>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = Nd4jArrayLongIterator(this)
|
||||||
|
|
||||||
|
private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Int>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> = Nd4jArrayIntIterator(this)
|
||||||
|
|
||||||
|
private class Nd4jArrayFloatIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Float>(iterateOver) {
|
||||||
|
override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> = Nd4jArrayFloatIterator(this)
|
@ -0,0 +1,68 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import kscience.kmath.structures.MutableNDStructure
|
||||||
|
import kscience.kmath.structures.NDStructure
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a [NDStructure] wrapping an [INDArray] object.
|
||||||
|
*
|
||||||
|
* @param T the type of items.
|
||||||
|
*/
|
||||||
|
public sealed class Nd4jArrayStructure<T> : MutableNDStructure<T> {
|
||||||
|
/**
|
||||||
|
* The wrapped [INDArray].
|
||||||
|
*/
|
||||||
|
public abstract val ndArray: INDArray
|
||||||
|
|
||||||
|
public override val shape: IntArray
|
||||||
|
get() = ndArray.shape().toIntArray()
|
||||||
|
|
||||||
|
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||||
|
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
|
||||||
|
public override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||||
|
}
|
||||||
|
|
||||||
|
private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Int>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
|
||||||
|
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||||
|
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asIntStructure(): Nd4jArrayStructure<Int> = Nd4jArrayIntStructure(this)
|
||||||
|
|
||||||
|
private data class Nd4jArrayLongStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Long>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
|
||||||
|
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
|
||||||
|
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asLongStructure(): Nd4jArrayStructure<Long> = Nd4jArrayLongStructure(this)
|
||||||
|
|
||||||
|
private data class Nd4jArrayRealStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Double>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
|
||||||
|
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||||
|
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asRealStructure(): Nd4jArrayStructure<Double> = Nd4jArrayRealStructure(this)
|
||||||
|
|
||||||
|
private data class Nd4jArrayFloatStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Float>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
|
||||||
|
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||||
|
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asFloatStructure(): Nd4jArrayStructure<Float> = Nd4jArrayFloatStructure(this)
|
4
kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt
Normal file
4
kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/arrays.kt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
|
||||||
|
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }
|
@ -0,0 +1,42 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
internal class Nd4jArrayAlgebraTest {
|
||||||
|
@Test
|
||||||
|
fun testProduce() {
|
||||||
|
val res = (RealNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
|
||||||
|
val expected = (Nd4j.create(2, 2) ?: fail()).asRealStructure()
|
||||||
|
expected[intArrayOf(0, 0)] = 0.0
|
||||||
|
expected[intArrayOf(0, 1)] = 1.0
|
||||||
|
expected[intArrayOf(1, 0)] = 1.0
|
||||||
|
expected[intArrayOf(1, 1)] = 2.0
|
||||||
|
assertEquals(expected, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMap() {
|
||||||
|
val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } }
|
||||||
|
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||||
|
expected[intArrayOf(0, 0)] = 3
|
||||||
|
expected[intArrayOf(0, 1)] = 3
|
||||||
|
expected[intArrayOf(1, 0)] = 3
|
||||||
|
expected[intArrayOf(1, 1)] = 3
|
||||||
|
assertEquals(expected, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAdd() {
|
||||||
|
val res = (IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 }
|
||||||
|
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||||
|
expected[intArrayOf(0, 0)] = 26
|
||||||
|
expected[intArrayOf(0, 1)] = 26
|
||||||
|
expected[intArrayOf(1, 0)] = 26
|
||||||
|
expected[intArrayOf(1, 1)] = 26
|
||||||
|
assertEquals(expected, res)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,72 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import kscience.kmath.structures.get
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertNotEquals
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
internal class Nd4jArrayStructureTest {
|
||||||
|
@Test
|
||||||
|
fun testElements() {
|
||||||
|
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
|
val struct = nd.asRealStructure()
|
||||||
|
val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
|
||||||
|
assertEquals(listOf(1.0, 2.0, 3.0), res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testShape() {
|
||||||
|
val nd = Nd4j.rand(10, 2, 3, 6) ?: fail()
|
||||||
|
val struct = nd.asRealStructure()
|
||||||
|
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testEquals() {
|
||||||
|
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||||
|
val struct1 = nd1.asRealStructure()
|
||||||
|
assertEquals(struct1, struct1)
|
||||||
|
assertNotEquals(struct1 as Any?, null)
|
||||||
|
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||||
|
val struct2 = nd2.asRealStructure()
|
||||||
|
assertEquals(struct1, struct2)
|
||||||
|
assertEquals(struct2, struct1)
|
||||||
|
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||||
|
val struct3 = nd3.asRealStructure()
|
||||||
|
assertEquals(struct2, struct3)
|
||||||
|
assertEquals(struct1, struct3)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testHashCode() {
|
||||||
|
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
|
||||||
|
val struct1 = nd1.asRealStructure()
|
||||||
|
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
|
||||||
|
val struct2 = nd2.asRealStructure()
|
||||||
|
assertEquals(struct1.hashCode(), struct2.hashCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDimension() {
|
||||||
|
val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
|
||||||
|
val struct = nd.asFloatStructure()
|
||||||
|
assertEquals(5, struct.dimension)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testGet() {
|
||||||
|
val nd = Nd4j.rand(10, 2, 3, 6)?:fail()
|
||||||
|
val struct = nd.asIntStructure()
|
||||||
|
assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSet() {
|
||||||
|
val nd = Nd4j.rand(17, 12, 4, 8)!!
|
||||||
|
val struct = nd.asLongStructure()
|
||||||
|
struct[intArrayOf(1, 2, 3, 4)] = 777
|
||||||
|
assertEquals(777, struct[1, 2, 3, 4])
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,6 @@
|
|||||||
plugins { id("ru.mipt.npm.mpp") }
|
plugins {
|
||||||
|
id("ru.mipt.npm.mpp")
|
||||||
|
}
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
all {
|
all {
|
||||||
|
@ -12,16 +12,18 @@ public object Fitting {
|
|||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, I : Any, A> chiSquared(
|
public fun <T : Any, I : Any, A> chiSquared(
|
||||||
autoDiff: AutoDiffProcessor<T, I, A>,
|
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
x: Buffer<T>,
|
x: Buffer<T>,
|
||||||
y: Buffer<T>,
|
y: Buffer<T>,
|
||||||
yErr: Buffer<T>,
|
yErr: Buffer<T>,
|
||||||
model: A.(I) -> I,
|
model: A.(I) -> I,
|
||||||
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
|
||||||
return autoDiff.process {
|
return autoDiff.process {
|
||||||
var sum = zero
|
var sum = zero
|
||||||
|
|
||||||
x.indices.forEach {
|
x.indices.forEach {
|
||||||
val xValue = const(x[it])
|
val xValue = const(x[it])
|
||||||
val yValue = const(y[it])
|
val yValue = const(y[it])
|
||||||
@ -29,6 +31,7 @@ public object Fitting {
|
|||||||
val modelValue = model(xValue)
|
val modelValue = model(xValue)
|
||||||
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
sum
|
sum
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -45,6 +48,7 @@ public object Fitting {
|
|||||||
): Expression<Double> {
|
): Expression<Double> {
|
||||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
|
||||||
return Expression { arguments ->
|
return Expression { arguments ->
|
||||||
x.indices.sumByDouble {
|
x.indices.sumByDouble {
|
||||||
val xValue = x[it]
|
val xValue = x[it]
|
||||||
|
@ -0,0 +1,58 @@
|
|||||||
|
package kscience.kmath.stat
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import kotlin.coroutines.CoroutineContext
|
||||||
|
import kotlin.coroutines.EmptyCoroutineContext
|
||||||
|
import kotlin.coroutines.coroutineContext
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A scope for a Monte-Carlo computations or multi-coroutine random number generation.
|
||||||
|
* The scope preserves the order of random generator calls as long as all concurrency calls is done via [launch] and [async]
|
||||||
|
* functions.
|
||||||
|
*/
|
||||||
|
public class MCScope(
|
||||||
|
public val coroutineContext: CoroutineContext,
|
||||||
|
public val random: RandomGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Launches a supervised Monte-Carlo scope
|
||||||
|
*/
|
||||||
|
public suspend inline fun <T> mcScope(generator: RandomGenerator, block: MCScope.() -> T): T =
|
||||||
|
MCScope(coroutineContext, generator).block()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Launch mc scope with a given seed
|
||||||
|
*/
|
||||||
|
public suspend inline fun <T> mcScope(seed: Long, block: MCScope.() -> T): T =
|
||||||
|
mcScope(RandomGenerator.default(seed), block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized launch for [MCScope]. Behaves the same way as regular [CoroutineScope.launch], but also stores the generator fork.
|
||||||
|
* The method itself is not thread safe.
|
||||||
|
*/
|
||||||
|
public inline fun MCScope.launch(
|
||||||
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
|
crossinline block: suspend MCScope.() -> Unit,
|
||||||
|
): Job {
|
||||||
|
val newRandom = random.fork()
|
||||||
|
return CoroutineScope(coroutineContext).launch(context, start) {
|
||||||
|
MCScope(coroutineContext, newRandom).block()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized async for [MCScope]. Behaves the same way as regular [CoroutineScope.async], but also stores the generator fork.
|
||||||
|
* The method itself is not thread safe.
|
||||||
|
*/
|
||||||
|
public inline fun <T> MCScope.async(
|
||||||
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
|
crossinline block: suspend MCScope.() -> T,
|
||||||
|
): Deferred<T> {
|
||||||
|
val newRandom = random.fork()
|
||||||
|
return CoroutineScope(coroutineContext).async(context, start) {
|
||||||
|
MCScope(coroutineContext, newRandom).block()
|
||||||
|
}
|
||||||
|
}
|
@ -27,17 +27,17 @@ public interface OptimizationProblem<T : Any> {
|
|||||||
/**
|
/**
|
||||||
* Define the initial guess for the optimization problem
|
* Define the initial guess for the optimization problem
|
||||||
*/
|
*/
|
||||||
public fun initialGuess(map: Map<Symbol, T>): Unit
|
public fun initialGuess(map: Map<Symbol, T>)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set an objective function expression
|
* Set an objective function expression
|
||||||
*/
|
*/
|
||||||
public fun expression(expression: Expression<T>): Unit
|
public fun expression(expression: Expression<T>)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set a differentiable expression as objective function as function and gradient provider
|
* Set a differentiable expression as objective function as function and gradient provider
|
||||||
*/
|
*/
|
||||||
public fun diffExpression(expression: DifferentiableExpression<T>): Unit
|
public fun diffExpression(expression: DifferentiableExpression<T, Expression<T>>)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the problem from previous optimization run
|
* Update the problem from previous optimization run
|
||||||
@ -50,9 +50,8 @@ public interface OptimizationProblem<T : Any> {
|
|||||||
public fun optimize(): OptimizationResult<T>
|
public fun optimize(): OptimizationResult<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
public fun interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
||||||
public fun build(symbols: List<Symbol>): P
|
public fun build(symbols: List<Symbol>): P
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
|
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
|
||||||
@ -60,7 +59,6 @@ public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFac
|
|||||||
block: P.() -> Unit,
|
block: P.() -> Unit,
|
||||||
): P = build(symbols).apply(block)
|
): P = build(symbols).apply(block)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
||||||
*/
|
*/
|
||||||
@ -78,7 +76,7 @@ public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
|
|||||||
/**
|
/**
|
||||||
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith(
|
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
||||||
factory: OptimizationProblemFactory<T, F>,
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: F.() -> Unit,
|
configuration: F.() -> Unit,
|
||||||
@ -88,4 +86,3 @@ public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.op
|
|||||||
problem.diffExpression(this)
|
problem.diffExpression(this)
|
||||||
return problem.optimize()
|
return problem.optimize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,85 @@
|
|||||||
|
package kscience.kmath.stat
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import java.util.*
|
||||||
|
import kotlin.collections.HashSet
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
data class RandomResult(val branch: String, val order: Int, val value: Int)
|
||||||
|
|
||||||
|
typealias ATest = suspend CoroutineScope.() -> Set<RandomResult>
|
||||||
|
|
||||||
|
class MCScopeTest {
|
||||||
|
val simpleTest: ATest = {
|
||||||
|
mcScope(1111) {
|
||||||
|
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||||
|
|
||||||
|
launch {
|
||||||
|
//println(random)
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("first", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
launch {
|
||||||
|
//empty fork
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
launch {
|
||||||
|
//println(random)
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("second", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val testWithJoin: ATest = {
|
||||||
|
mcScope(1111) {
|
||||||
|
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||||
|
|
||||||
|
val job = launch {
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("first", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
launch {
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
if (it == 4) job.join()
|
||||||
|
res.add(RandomResult("second", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun compareResult(test: ATest) {
|
||||||
|
val res1 = runBlocking(Dispatchers.Default) { test() }
|
||||||
|
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
||||||
|
assertEquals(
|
||||||
|
res1.find { it.branch == "first" && it.order == 7 }?.value,
|
||||||
|
res2.find { it.branch == "first" && it.order == 7 }?.value
|
||||||
|
)
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testParallel() {
|
||||||
|
compareResult(simpleTest)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testConditionalJoin() {
|
||||||
|
compareResult(testWithJoin)
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,6 @@
|
|||||||
plugins { id("ru.mipt.npm.jvm") }
|
plugins {
|
||||||
|
id("ru.mipt.npm.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
description = "Binding for https://github.com/JetBrains-Research/viktor"
|
description = "Binding for https://github.com/JetBrains-Research/viktor"
|
||||||
|
|
||||||
|
@ -1,17 +1,15 @@
|
|||||||
pluginManagement {
|
pluginManagement {
|
||||||
repositories {
|
repositories {
|
||||||
mavenLocal()
|
|
||||||
jcenter()
|
|
||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
|
jcenter()
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val toolsVersion = "0.6.4-dev-1.4.20-M2"
|
val toolsVersion = "0.7.0"
|
||||||
val kotlinVersion = "1.4.20-M2"
|
val kotlinVersion = "1.4.20"
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
||||||
@ -35,11 +33,13 @@ include(
|
|||||||
":kmath-commons",
|
":kmath-commons",
|
||||||
":kmath-viktor",
|
":kmath-viktor",
|
||||||
":kmath-stat",
|
":kmath-stat",
|
||||||
|
":kmath-nd4j",
|
||||||
":kmath-dimensions",
|
":kmath-dimensions",
|
||||||
":kmath-for-real",
|
":kmath-for-real",
|
||||||
":kmath-geometry",
|
":kmath-geometry",
|
||||||
":kmath-ast",
|
":kmath-ast",
|
||||||
":kmath-ejml",
|
":kmath-ejml",
|
||||||
|
":kmath-kotlingrad",
|
||||||
":kmath-gsl",
|
":kmath-gsl",
|
||||||
":examples"
|
":examples"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user