forked from kscience/kmath
commit
debff5357b
27
.github/workflows/build.yml
vendored
27
.github/workflows/build.yml
vendored
@ -13,25 +13,13 @@ jobs:
|
|||||||
runs-on: ${{matrix.os}}
|
runs-on: ${{matrix.os}}
|
||||||
timeout-minutes: 40
|
timeout-minutes: 40
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout the repo
|
- uses: actions/checkout@v3.0.0
|
||||||
uses: actions/checkout@v2
|
- uses: actions/setup-java@v3.0.0
|
||||||
- name: Set up JDK 11
|
|
||||||
uses: DeLaGuardo/setup-graalvm@4.0
|
|
||||||
with:
|
with:
|
||||||
graalvm: 21.2.0
|
java-version: 11
|
||||||
java: java11
|
distribution: liberica
|
||||||
arch: amd64
|
|
||||||
- name: Cache gradle
|
|
||||||
uses: actions/cache@v2
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
~/.gradle/caches
|
|
||||||
~/.gradle/wrapper
|
|
||||||
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-gradle-
|
|
||||||
- name: Cache konan
|
- name: Cache konan
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3.0.1
|
||||||
with:
|
with:
|
||||||
path: ~/.konan
|
path: ~/.konan
|
||||||
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
||||||
@ -39,5 +27,6 @@ jobs:
|
|||||||
${{ runner.os }}-gradle-
|
${{ runner.os }}-gradle-
|
||||||
- name: Gradle Wrapper Validation
|
- name: Gradle Wrapper Validation
|
||||||
uses: gradle/wrapper-validation-action@v1.0.4
|
uses: gradle/wrapper-validation-action@v1.0.4
|
||||||
- name: Build
|
- uses: gradle/gradle-build-action@v2.1.5
|
||||||
run: ./gradlew build --build-cache --no-daemon --stacktrace
|
with:
|
||||||
|
arguments: build
|
||||||
|
25
.github/workflows/pages.yml
vendored
25
.github/workflows/pages.yml
vendored
@ -1,28 +1,31 @@
|
|||||||
name: Dokka publication
|
name: Dokka publication
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
workflow_dispatch:
|
||||||
branches: [ master ]
|
release:
|
||||||
|
types: [ created ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
timeout-minutes: 40
|
timeout-minutes: 40
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3.0.0
|
||||||
- uses: DeLaGuardo/setup-graalvm@4.0
|
- uses: actions/setup-java@v3.0.0
|
||||||
with:
|
with:
|
||||||
graalvm: 21.2.0
|
java-version: 11
|
||||||
java: java11
|
distribution: liberica
|
||||||
arch: amd64
|
- name: Cache konan
|
||||||
- uses: actions/cache@v2
|
uses: actions/cache@v3.0.1
|
||||||
with:
|
with:
|
||||||
path: ~/.gradle/caches
|
path: ~/.konan
|
||||||
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gradle-
|
${{ runner.os }}-gradle-
|
||||||
- run: ./gradlew dokkaHtmlMultiModule --build-cache --no-daemon --no-parallel --stacktrace
|
- uses: gradle/gradle-build-action@v2.1.5
|
||||||
- uses: JamesIves/github-pages-deploy-action@4.1.0
|
with:
|
||||||
|
arguments: dokkaHtmlMultiModule --no-parallel
|
||||||
|
- uses: JamesIves/github-pages-deploy-action@v4.3.0
|
||||||
with:
|
with:
|
||||||
branch: gh-pages
|
branch: gh-pages
|
||||||
folder: build/dokka/htmlMultiModule
|
folder: build/dokka/htmlMultiModule
|
||||||
|
44
.github/workflows/publish.yml
vendored
44
.github/workflows/publish.yml
vendored
@ -14,42 +14,38 @@ jobs:
|
|||||||
os: [ macOS-latest, windows-latest ]
|
os: [ macOS-latest, windows-latest ]
|
||||||
runs-on: ${{matrix.os}}
|
runs-on: ${{matrix.os}}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout the repo
|
- uses: actions/checkout@v3.0.0
|
||||||
uses: actions/checkout@v2
|
- uses: actions/setup-java@v3.0.0
|
||||||
- name: Set up JDK 11
|
|
||||||
uses: DeLaGuardo/setup-graalvm@4.0
|
|
||||||
with:
|
with:
|
||||||
graalvm: 21.2.0
|
java-version: 11
|
||||||
java: java11
|
distribution: liberica
|
||||||
arch: amd64
|
|
||||||
- name: Cache gradle
|
|
||||||
uses: actions/cache@v2
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
~/.gradle/caches
|
|
||||||
~/.gradle/wrapper
|
|
||||||
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-gradle-
|
|
||||||
- name: Cache konan
|
- name: Cache konan
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3.0.1
|
||||||
with:
|
with:
|
||||||
path: ~/.konan
|
path: ~/.konan
|
||||||
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gradle-
|
${{ runner.os }}-gradle-
|
||||||
- name: Gradle Wrapper Validation
|
- uses: gradle/wrapper-validation-action@v1.0.4
|
||||||
uses: gradle/wrapper-validation-action@v1.0.4
|
|
||||||
- name: Publish Windows Artifacts
|
- name: Publish Windows Artifacts
|
||||||
if: matrix.os == 'windows-latest'
|
if: matrix.os == 'windows-latest'
|
||||||
shell: cmd
|
uses: gradle/gradle-build-action@v2.1.5
|
||||||
run: >
|
with:
|
||||||
./gradlew release --no-daemon --build-cache -Ppublishing.enabled=true
|
arguments: |
|
||||||
|
releaseAll
|
||||||
|
-Ppublishing.enabled=true
|
||||||
|
-Ppublishing.sonatype=false
|
||||||
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }}
|
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }}
|
||||||
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }}
|
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }}
|
||||||
- name: Publish Mac Artifacts
|
- name: Publish Mac Artifacts
|
||||||
if: matrix.os == 'macOS-latest'
|
if: matrix.os == 'macOS-latest'
|
||||||
run: >
|
uses: gradle/gradle-build-action@v2.1.5
|
||||||
./gradlew release --no-daemon --build-cache -Ppublishing.enabled=true -Ppublishing.platform=macosX64
|
with:
|
||||||
|
arguments: |
|
||||||
|
releaseMacosX64
|
||||||
|
releaseIosArm64
|
||||||
|
releaseIosX64
|
||||||
|
-Ppublishing.enabled=true
|
||||||
|
-Ppublishing.sonatype=false
|
||||||
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }}
|
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }}
|
||||||
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }}
|
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }}
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -19,3 +19,4 @@ out/
|
|||||||
|
|
||||||
!/.idea/copyright/
|
!/.idea/copyright/
|
||||||
!/.idea/scopes/
|
!/.idea/scopes/
|
||||||
|
/kotlin-js-store/yarn.lock
|
||||||
|
36
CHANGELOG.md
36
CHANGELOG.md
@ -2,6 +2,21 @@
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- Kotlin 1.7
|
||||||
|
- `LazyStructure` `deffered` -> `async` to comply with coroutines code style
|
||||||
|
|
||||||
|
### Deprecated
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
## [0.3.0]
|
||||||
|
### Added
|
||||||
- `ScaleOperations` interface
|
- `ScaleOperations` interface
|
||||||
- `Field` extends `ScaleOperations`
|
- `Field` extends `ScaleOperations`
|
||||||
- Basic integration API
|
- Basic integration API
|
||||||
@ -19,6 +34,12 @@
|
|||||||
- Complex power
|
- Complex power
|
||||||
- Separate methods for UInt, Int and Number powers. NaN safety.
|
- Separate methods for UInt, Int and Number powers. NaN safety.
|
||||||
- Tensorflow prototype
|
- Tensorflow prototype
|
||||||
|
- `ValueAndErrorField`
|
||||||
|
- MST compilation to WASM: #286
|
||||||
|
- Jafama integration: #176
|
||||||
|
- `contentEquals` with tolerance: #364
|
||||||
|
- Compilation to TeX for MST: #254
|
||||||
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Exponential operations merged with hyperbolic functions
|
- Exponential operations merged with hyperbolic functions
|
||||||
@ -48,10 +69,15 @@
|
|||||||
- Operations -> Ops
|
- Operations -> Ops
|
||||||
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
|
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
|
||||||
- Tensor algebra takes read-only structures as input and inherits AlgebraND
|
- Tensor algebra takes read-only structures as input and inherits AlgebraND
|
||||||
|
- `UnivariateDistribution` renamed to `Distribution1D`
|
||||||
|
- Rework of histograms.
|
||||||
|
- `UnivariateFunction` -> `Function1D`, `MultivariateFunction` -> `FunctionND`
|
||||||
|
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
- Specialized `DoubleBufferAlgebra`
|
- Specialized `DoubleBufferAlgebra`
|
||||||
|
|
||||||
|
|
||||||
### Removed
|
### Removed
|
||||||
- Nearest in Domain. To be implemented in geometry package.
|
- Nearest in Domain. To be implemented in geometry package.
|
||||||
- Number multiplication and division in main Algebra chain
|
- Number multiplication and division in main Algebra chain
|
||||||
@ -62,10 +88,12 @@
|
|||||||
- Second generic from DifferentiableExpression
|
- Second generic from DifferentiableExpression
|
||||||
- Algebra elements are completely removed. Use algebra contexts instead.
|
- Algebra elements are completely removed. Use algebra contexts instead.
|
||||||
|
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
- Ring inherits RingOperations, not GroupOperations
|
- Ring inherits RingOperations, not GroupOperations
|
||||||
- Univariate histogram filling
|
- Univariate histogram filling
|
||||||
|
|
||||||
|
|
||||||
### Security
|
### Security
|
||||||
|
|
||||||
## [0.2.0]
|
## [0.2.0]
|
||||||
@ -88,6 +116,7 @@
|
|||||||
- New `MatrixFeature` interfaces for matrix decompositions
|
- New `MatrixFeature` interfaces for matrix decompositions
|
||||||
- Basic Quaternion vector support in `kmath-complex`.
|
- Basic Quaternion vector support in `kmath-complex`.
|
||||||
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Package changed from `scientifik` to `space.kscience`
|
- Package changed from `scientifik` to `space.kscience`
|
||||||
- Gradle version: 6.6 -> 6.8.2
|
- Gradle version: 6.6 -> 6.8.2
|
||||||
@ -112,7 +141,6 @@
|
|||||||
- `symbol` method in `Algebra` renamed to `bindSymbol` to avoid ambiguity
|
- `symbol` method in `Algebra` renamed to `bindSymbol` to avoid ambiguity
|
||||||
- Add `out` projection to `Buffer` generic
|
- Add `out` projection to `Buffer` generic
|
||||||
|
|
||||||
### Deprecated
|
|
||||||
|
|
||||||
### Removed
|
### Removed
|
||||||
- `kmath-koma` module because it doesn't support Kotlin 1.4.
|
- `kmath-koma` module because it doesn't support Kotlin 1.4.
|
||||||
@ -122,13 +150,11 @@
|
|||||||
- `Real` class
|
- `Real` class
|
||||||
- StructureND identity and equals
|
- StructureND identity and equals
|
||||||
|
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140)
|
- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140)
|
||||||
|
|
||||||
### Security
|
|
||||||
|
|
||||||
## [0.1.4]
|
## [0.1.4]
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
- Functional Expressions API
|
- Functional Expressions API
|
||||||
- Mathematical Syntax Tree, its interpreter and API
|
- Mathematical Syntax Tree, its interpreter and API
|
||||||
@ -146,6 +172,7 @@
|
|||||||
- Full hyperbolic functions support and default implementations within `ExtendedField`
|
- Full hyperbolic functions support and default implementations within `ExtendedField`
|
||||||
- Norm support for `Complex`
|
- Norm support for `Complex`
|
||||||
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- `readAsMemory` now has `throws IOException` in JVM signature.
|
- `readAsMemory` now has `throws IOException` in JVM signature.
|
||||||
- Several functions taking functional types were made `inline`.
|
- Several functions taking functional types were made `inline`.
|
||||||
@ -157,6 +184,7 @@
|
|||||||
- Gradle version: 6.3 -> 6.6
|
- Gradle version: 6.3 -> 6.6
|
||||||
- Moved probability distributions to commons-rng and to `kmath-prob`
|
- Moved probability distributions to commons-rng and to `kmath-prob`
|
||||||
|
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
- Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106)
|
- Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106)
|
||||||
- D3.dim value in `kmath-dimensions`
|
- D3.dim value in `kmath-dimensions`
|
||||||
|
82
README.md
82
README.md
@ -52,21 +52,18 @@ module definitions below. The module stability could have the following levels:
|
|||||||
|
|
||||||
## Modules
|
## Modules
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [benchmarks](benchmarks)
|
### [benchmarks](benchmarks)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [examples](examples)
|
### [examples](examples)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-ast](kmath-ast)
|
### [kmath-ast](kmath-ast)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
@ -77,15 +74,13 @@ module definitions below. The module stability could have the following levels:
|
|||||||
> - [mst-js-codegen](kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler
|
> - [mst-js-codegen](kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler
|
||||||
> - [rendering](kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt) : Extendable MST rendering
|
> - [rendering](kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt) : Extendable MST rendering
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-commons](kmath-commons)
|
### [kmath-commons](kmath-commons)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-complex](kmath-complex)
|
### [kmath-complex](kmath-complex)
|
||||||
> Complex numbers and quaternions.
|
> Complex numbers and quaternions.
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
@ -94,9 +89,8 @@ module definitions below. The module stability could have the following levels:
|
|||||||
> - [complex](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt) : Complex Numbers
|
> - [complex](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt) : Complex Numbers
|
||||||
> - [quaternion](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt) : Quaternions
|
> - [quaternion](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt) : Quaternions
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-core](kmath-core)
|
### [kmath-core](kmath-core)
|
||||||
> Core classes, algebra definitions, basic linear algebra
|
> Core classes, algebra definitions, basic linear algebra
|
||||||
>
|
>
|
||||||
> **Maturity**: DEVELOPMENT
|
> **Maturity**: DEVELOPMENT
|
||||||
@ -112,21 +106,18 @@ performance calculations to code generation.
|
|||||||
> - [domains](kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains) : Domains
|
> - [domains](kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains) : Domains
|
||||||
> - [autodiff](kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
|
> - [autodiff](kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-coroutines](kmath-coroutines)
|
### [kmath-coroutines](kmath-coroutines)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-dimensions](kmath-dimensions)
|
### [kmath-dimensions](kmath-dimensions)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-ejml](kmath-ejml)
|
### [kmath-ejml](kmath-ejml)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
@ -136,9 +127,8 @@ performance calculations to code generation.
|
|||||||
> - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation.
|
> - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation.
|
||||||
> - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations.
|
> - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations.
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-for-real](kmath-for-real)
|
### [kmath-for-real](kmath-for-real)
|
||||||
> Extension module that should be used to achieve numpy-like behavior.
|
> Extension module that should be used to achieve numpy-like behavior.
|
||||||
All operations are specialized to work with `Double` numbers without declaring algebraic contexts.
|
All operations are specialized to work with `Double` numbers without declaring algebraic contexts.
|
||||||
One can still use generic algebras though.
|
One can still use generic algebras though.
|
||||||
@ -150,9 +140,8 @@ One can still use generic algebras though.
|
|||||||
> - [DoubleMatrix](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt) : Numpy-like operations for 2d real structures
|
> - [DoubleMatrix](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt) : Numpy-like operations for 2d real structures
|
||||||
> - [grids](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt) : Uniform grid generators
|
> - [grids](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt) : Uniform grid generators
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-functions](kmath-functions)
|
### [kmath-functions](kmath-functions)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
@ -164,21 +153,18 @@ One can still use generic algebras though.
|
|||||||
> - [spline interpolation](kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt) : Cubic spline XY interpolator.
|
> - [spline interpolation](kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt) : Cubic spline XY interpolator.
|
||||||
> - [integration](kmath-functions/#) : Univariate and multivariate quadratures
|
> - [integration](kmath-functions/#) : Univariate and multivariate quadratures
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-geometry](kmath-geometry)
|
### [kmath-geometry](kmath-geometry)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-histograms](kmath-histograms)
|
### [kmath-histograms](kmath-histograms)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-jafama](kmath-jafama)
|
### [kmath-jafama](kmath-jafama)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
@ -186,15 +172,13 @@ One can still use generic algebras though.
|
|||||||
> **Features:**
|
> **Features:**
|
||||||
> - [jafama-double](kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/) : Double ExtendedField implementations based on Jafama
|
> - [jafama-double](kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/) : Double ExtendedField implementations based on Jafama
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-jupyter](kmath-jupyter)
|
### [kmath-jupyter](kmath-jupyter)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-kotlingrad](kmath-kotlingrad)
|
### [kmath-kotlingrad](kmath-kotlingrad)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
@ -203,21 +187,18 @@ One can still use generic algebras though.
|
|||||||
> - [differentiable-mst-expression](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt) : MST based DifferentiableExpression.
|
> - [differentiable-mst-expression](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt) : MST based DifferentiableExpression.
|
||||||
> - [scalars-adapters](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt) : Conversions between Kotlin∇'s SFun and MST
|
> - [scalars-adapters](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt) : Conversions between Kotlin∇'s SFun and MST
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-memory](kmath-memory)
|
### [kmath-memory](kmath-memory)
|
||||||
> An API and basic implementation for arranging objects in a continuous memory block.
|
> An API and basic implementation for arranging objects in a continuous memory block.
|
||||||
>
|
>
|
||||||
> **Maturity**: DEVELOPMENT
|
> **Maturity**: DEVELOPMENT
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-multik](kmath-multik)
|
### [kmath-multik](kmath-multik)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-nd4j](kmath-nd4j)
|
### [kmath-nd4j](kmath-nd4j)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
@ -227,27 +208,28 @@ One can still use generic algebras though.
|
|||||||
> - [nd4jarrayrings](kmath-nd4j/#) : Rings over Nd4jArrayStructure of Int and Long
|
> - [nd4jarrayrings](kmath-nd4j/#) : Rings over Nd4jArrayStructure of Int and Long
|
||||||
> - [nd4jarrayfields](kmath-nd4j/#) : Fields over Nd4jArrayStructure of Float and Double
|
> - [nd4jarrayfields](kmath-nd4j/#) : Fields over Nd4jArrayStructure of Float and Double
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-optimization](kmath-optimization)
|
### [kmath-optimization](kmath-optimization)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-stat](kmath-stat)
|
### [kmath-stat](kmath-stat)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-symja](kmath-symja)
|
### [kmath-symja](kmath-symja)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-tensors](kmath-tensors)
|
### [kmath-tensorflow](kmath-tensorflow)
|
||||||
|
>
|
||||||
|
>
|
||||||
|
> **Maturity**: PROTOTYPE
|
||||||
|
|
||||||
|
### [kmath-tensors](kmath-tensors)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: PROTOTYPE
|
> **Maturity**: PROTOTYPE
|
||||||
@ -257,13 +239,11 @@ One can still use generic algebras though.
|
|||||||
> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
|
> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
|
||||||
> - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
|
> - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
|
||||||
|
|
||||||
<hr/>
|
|
||||||
|
|
||||||
* ### [kmath-viktor](kmath-viktor)
|
### [kmath-viktor](kmath-viktor)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: DEVELOPMENT
|
> **Maturity**: DEVELOPMENT
|
||||||
<hr/>
|
|
||||||
|
|
||||||
|
|
||||||
## Multi-platform support
|
## Multi-platform support
|
||||||
@ -302,8 +282,8 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api("space.kscience:kmath-core:0.3.0-dev-17")
|
api("space.kscience:kmath-core:$version")
|
||||||
// api("space.kscience:kmath-core-jvm:0.3.0-dev-17") for jvm-specific version
|
// api("space.kscience:kmath-core-jvm:$version") for jvm-specific version
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
4
benchmarks/README.md
Normal file
4
benchmarks/README.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Module benchmarks
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -52,6 +52,8 @@ kotlin {
|
|||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-jafama"))
|
implementation(project(":kmath-jafama"))
|
||||||
implementation(project(":kmath-multik"))
|
implementation(project(":kmath-multik"))
|
||||||
|
implementation(projects.kmath.kmathTensorflow)
|
||||||
|
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
||||||
// uncomment if your system supports AVX2
|
// uncomment if your system supports AVX2
|
||||||
// val os = System.getProperty("os.name")
|
// val os = System.getProperty("os.name")
|
||||||
@ -122,6 +124,11 @@ benchmark {
|
|||||||
include("JafamaBenchmark")
|
include("JafamaBenchmark")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
configurations.register("tensorAlgebra") {
|
||||||
|
commonConfiguration()
|
||||||
|
include("TensorAlgebraBenchmark")
|
||||||
|
}
|
||||||
|
|
||||||
configurations.register("viktor") {
|
configurations.register("viktor") {
|
||||||
commonConfiguration()
|
commonConfiguration()
|
||||||
include("ViktorBenchmark")
|
include("ViktorBenchmark")
|
||||||
@ -148,7 +155,7 @@ kotlin.sourceSets.all {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
tasks.withType<org.jetbrains.kotlin.gradle.dsl.KotlinJvmCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = "11"
|
jvmTarget = "11"
|
||||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"
|
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"
|
||||||
|
@ -15,7 +15,11 @@ import space.kscience.kmath.linear.invoke
|
|||||||
import space.kscience.kmath.linear.linearSpace
|
import space.kscience.kmath.linear.linearSpace
|
||||||
import space.kscience.kmath.multik.multikAlgebra
|
import space.kscience.kmath.multik.multikAlgebra
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.tensorflow.produceWithTF
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
@ -39,6 +43,16 @@ internal class DotBenchmark {
|
|||||||
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun tfDot(blackhole: Blackhole) {
|
||||||
|
blackhole.consume(
|
||||||
|
DoubleField.produceWithTF {
|
||||||
|
matrix1 dot matrix1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
|
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
@ -59,13 +73,13 @@ internal class DotBenchmark {
|
|||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Benchmark
|
@Benchmark
|
||||||
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
fun tensorDot(blackhole: Blackhole) = with(DoubleField.tensorAlgebra) {
|
||||||
// blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
// }
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
|
fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,4 +92,9 @@ internal class DotBenchmark {
|
|||||||
fun doubleDot(blackhole: Blackhole) = with(DoubleField.linearSpace) {
|
fun doubleDot(blackhole: Blackhole) = with(DoubleField.linearSpace) {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun doubleTensorDot(blackhole: Blackhole) = DoubleTensorAlgebra.invoke {
|
||||||
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.benchmarks
|
||||||
|
|
||||||
|
import kotlinx.benchmark.Benchmark
|
||||||
|
import kotlinx.benchmark.Blackhole
|
||||||
|
import kotlinx.benchmark.Scope
|
||||||
|
import kotlinx.benchmark.State
|
||||||
|
import space.kscience.kmath.linear.linearSpace
|
||||||
|
import space.kscience.kmath.linear.matrix
|
||||||
|
import space.kscience.kmath.linear.symmetric
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||||
|
import kotlin.random.Random
|
||||||
|
|
||||||
|
@State(Scope.Benchmark)
|
||||||
|
internal class TensorAlgebraBenchmark {
|
||||||
|
companion object {
|
||||||
|
private val random = Random(12224)
|
||||||
|
private const val dim = 30
|
||||||
|
|
||||||
|
private val matrix = DoubleField.linearSpace.matrix(dim, dim).symmetric { _, _ -> random.nextDouble() }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun tensorSymEigSvd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
|
blackhole.consume(matrix.symEigSvd(1e-10))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun tensorSymEigJacobi(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
|
blackhole.consume(matrix.symEigJacobi(50, 1e-10))
|
||||||
|
}
|
||||||
|
}
|
@ -1,16 +1,17 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id("ru.mipt.npm.gradle.project")
|
id("ru.mipt.npm.gradle.project")
|
||||||
id("org.jetbrains.kotlinx.kover") version "0.5.0-RC"
|
id("org.jetbrains.kotlinx.kover") version "0.5.0"
|
||||||
}
|
}
|
||||||
|
|
||||||
allprojects {
|
allprojects {
|
||||||
repositories {
|
repositories {
|
||||||
|
maven("https://repo.kotlin.link")
|
||||||
maven("https://oss.sonatype.org/content/repositories/snapshots")
|
maven("https://oss.sonatype.org/content/repositories/snapshots")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
group = "space.kscience"
|
group = "space.kscience"
|
||||||
version = "0.3.0-dev-17"
|
version = "0.3.1-dev-1"
|
||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
@ -50,12 +51,24 @@ subprojects {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
plugins.withId("org.jetbrains.kotlin.multiplatform") {
|
||||||
|
configure<org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension> {
|
||||||
|
sourceSets {
|
||||||
|
val commonTest by getting {
|
||||||
|
dependencies {
|
||||||
|
implementation(projects.testUtils)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
readme.readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
readme.readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||||
|
|
||||||
ksciencePublish {
|
ksciencePublish {
|
||||||
github("kmath")
|
github("kmath", addToRelease = false)
|
||||||
space()
|
space()
|
||||||
sonatype()
|
sonatype()
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
plugins {
|
plugins {
|
||||||
|
kotlin("jvm") version "1.7.0"
|
||||||
`kotlin-dsl`
|
`kotlin-dsl`
|
||||||
`version-catalog`
|
`version-catalog`
|
||||||
alias(npmlibs.plugins.kotlin.plugin.serialization)
|
alias(npmlibs.plugins.kotlin.plugin.serialization)
|
||||||
@ -7,17 +8,19 @@ plugins {
|
|||||||
java.targetCompatibility = JavaVersion.VERSION_11
|
java.targetCompatibility = JavaVersion.VERSION_11
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
|
mavenLocal()
|
||||||
maven("https://repo.kotlin.link")
|
maven("https://repo.kotlin.link")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
}
|
}
|
||||||
|
|
||||||
val toolsVersion: String by extra
|
val toolsVersion = npmlibs.versions.tools.get()
|
||||||
val kotlinVersion = npmlibs.versions.kotlin.asProvider().get()
|
val kotlinVersion = npmlibs.versions.kotlin.asProvider().get()
|
||||||
val benchmarksVersion = "0.4.2"
|
val benchmarksVersion = npmlibs.versions.kotlinx.benchmark.get()
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api("ru.mipt.npm:gradle-tools:$toolsVersion")
|
api("ru.mipt.npm:gradle-tools:$toolsVersion")
|
||||||
|
api(npmlibs.atomicfu.gradle)
|
||||||
//plugins form benchmarks
|
//plugins form benchmarks
|
||||||
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
|
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
|
||||||
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
||||||
|
@ -1,14 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright 2018-2021 KMath contributors.
|
|
||||||
# Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
||||||
#
|
|
||||||
|
|
||||||
kotlin.code.style=official
|
|
||||||
kotlin.mpp.stability.nowarn=true
|
|
||||||
|
|
||||||
kotlin.jupyter.add.scanner=false
|
|
||||||
|
|
||||||
org.gradle.configureondemand=true
|
|
||||||
org.gradle.parallel=true
|
|
||||||
|
|
||||||
toolsVersion=0.10.9-kotlin-1.6.10
|
|
@ -3,17 +3,26 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
|
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
|
||||||
enableFeaturePreview("VERSION_CATALOGS")
|
|
||||||
|
|
||||||
dependencyResolutionManagement {
|
dependencyResolutionManagement {
|
||||||
|
val projectProperties = java.util.Properties()
|
||||||
|
file("../gradle.properties").inputStream().use {
|
||||||
|
projectProperties.load(it)
|
||||||
|
}
|
||||||
|
|
||||||
val toolsVersion: String by extra
|
projectProperties.forEach { key, value ->
|
||||||
|
extra.set(key.toString(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val toolsVersion: String = projectProperties["toolsVersion"].toString()
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
|
mavenLocal()
|
||||||
maven("https://repo.kotlin.link")
|
maven("https://repo.kotlin.link")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
|
gradlePluginPortal()
|
||||||
}
|
}
|
||||||
|
|
||||||
versionCatalogs {
|
versionCatalogs {
|
||||||
|
@ -319,7 +319,9 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra},
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -52,7 +52,7 @@ module definitions below. The module stability could have the following levels:
|
|||||||
|
|
||||||
## Modules
|
## Modules
|
||||||
|
|
||||||
$modules
|
${modules}
|
||||||
|
|
||||||
## Multi-platform support
|
## Multi-platform support
|
||||||
|
|
||||||
|
4
examples/README.md
Normal file
4
examples/README.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Module examples
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -58,7 +58,7 @@ kotlin.sourceSets.all {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
tasks.withType<org.jetbrains.kotlin.gradle.dsl.KotlinJvmCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = "11"
|
jvmTarget = "11"
|
||||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
|
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
|
||||||
|
@ -13,7 +13,7 @@ import kotlin.math.pow
|
|||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
//Define a function
|
//Define a function
|
||||||
val function: UnivariateFunction<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 }
|
val function: Function1D<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 }
|
||||||
|
|
||||||
//get the result of the integration
|
//get the result of the integration
|
||||||
val result = DoubleField.gaussIntegrator.integrate(0.0..10.0, function = function)
|
val result = DoubleField.gaussIntegrator.integrate(0.0..10.0, function = function)
|
||||||
|
@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.functions
|
package space.kscience.kmath.functions
|
||||||
|
|
||||||
import space.kscience.kmath.interpolation.SplineInterpolator
|
|
||||||
import space.kscience.kmath.interpolation.interpolatePolynomials
|
import space.kscience.kmath.interpolation.interpolatePolynomials
|
||||||
|
import space.kscience.kmath.interpolation.splineInterpolator
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.real.map
|
import space.kscience.kmath.real.map
|
||||||
import space.kscience.kmath.real.step
|
import space.kscience.kmath.real.step
|
||||||
@ -18,7 +18,7 @@ import space.kscience.plotly.scatter
|
|||||||
|
|
||||||
@OptIn(UnstablePlotlyAPI::class)
|
@OptIn(UnstablePlotlyAPI::class)
|
||||||
fun main() {
|
fun main() {
|
||||||
val function: UnivariateFunction<Double> = { x ->
|
val function: Function1D<Double> = { x ->
|
||||||
if (x in 30.0..50.0) {
|
if (x in 30.0..50.0) {
|
||||||
1.0
|
1.0
|
||||||
} else {
|
} else {
|
||||||
@ -28,7 +28,7 @@ fun main() {
|
|||||||
val xs = 0.0..100.0 step 0.5
|
val xs = 0.0..100.0 step 0.5
|
||||||
val ys = xs.map(function)
|
val ys = xs.map(function)
|
||||||
|
|
||||||
val polynomial: PiecewisePolynomial<Double> = SplineInterpolator.double.interpolatePolynomials(xs, ys)
|
val polynomial: PiecewisePolynomial<Double> = DoubleField.splineInterpolator.interpolatePolynomials(xs, ys)
|
||||||
|
|
||||||
val polyFunction = polynomial.asFunction(DoubleField, 0.0)
|
val polyFunction = polynomial.asFunction(DoubleField, 0.0)
|
||||||
|
|
||||||
|
@ -2,14 +2,14 @@
|
|||||||
# Copyright 2018-2021 KMath contributors.
|
# Copyright 2018-2021 KMath contributors.
|
||||||
# Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
# Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
#
|
#
|
||||||
|
|
||||||
kotlin.code.style=official
|
kotlin.code.style=official
|
||||||
kotlin.mpp.stability.nowarn=true
|
|
||||||
|
|
||||||
kotlin.jupyter.add.scanner=false
|
kotlin.jupyter.add.scanner=false
|
||||||
|
kotlin.mpp.stability.nowarn=true
|
||||||
|
kotlin.native.ignoreDisabledTargets=true
|
||||||
|
//kotlin.incremental.js.ir=true
|
||||||
|
|
||||||
org.gradle.configureondemand=true
|
org.gradle.configureondemand=true
|
||||||
org.gradle.parallel=true
|
org.gradle.parallel=true
|
||||||
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1G
|
org.gradle.jvmargs=-Xmx4096m
|
||||||
|
|
||||||
toolsVersion=0.11.1-kotlin-1.6.10
|
toolsVersion=0.11.7-kotlin-1.7.0
|
||||||
|
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
BIN
gradle/wrapper/gradle-wrapper.jar
vendored
Binary file not shown.
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
|||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.3.3-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.2-bin.zip
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Module kmath-ast
|
# Module kmath-ast
|
||||||
|
|
||||||
Performance and visualization extensions to MST API.
|
Extensions to MST API: transformations, dynamic compilation and visualization.
|
||||||
|
|
||||||
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
|
||||||
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
|
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
|
||||||
@ -10,17 +10,17 @@ Performance and visualization extensions to MST API.
|
|||||||
|
|
||||||
## Artifact:
|
## Artifact:
|
||||||
|
|
||||||
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-17`.
|
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0`.
|
||||||
|
|
||||||
**Gradle:**
|
**Gradle Groovy:**
|
||||||
```gradle
|
```groovy
|
||||||
repositories {
|
repositories {
|
||||||
maven { url 'https://repo.kotlin.link' }
|
maven { url 'https://repo.kotlin.link' }
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'space.kscience:kmath-ast:0.3.0-dev-17'
|
implementation 'space.kscience:kmath-ast:0.3.0'
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
**Gradle Kotlin DSL:**
|
**Gradle Kotlin DSL:**
|
||||||
@ -31,10 +31,30 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("space.kscience:kmath-ast:0.3.0-dev-17")
|
implementation("space.kscience:kmath-ast:0.3.0")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Parsing expressions
|
||||||
|
|
||||||
|
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
|
||||||
|
|
||||||
|
Supported literals:
|
||||||
|
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
|
||||||
|
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
|
||||||
|
|
||||||
|
Supported binary operators (from the highest precedence to the lowest one):
|
||||||
|
1. `^`
|
||||||
|
2. `*`, `/`
|
||||||
|
3. `+`, `-`
|
||||||
|
|
||||||
|
Supported unary operator:
|
||||||
|
1. `-`, e. g. `-x`
|
||||||
|
|
||||||
|
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
|
||||||
|
1. `sin(x)`
|
||||||
|
2. `add(x, y)`
|
||||||
|
|
||||||
## Dynamic expression code generation
|
## Dynamic expression code generation
|
||||||
|
|
||||||
### On JVM
|
### On JVM
|
||||||
@ -42,48 +62,66 @@ dependencies {
|
|||||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
||||||
special implementation of `Expression<T>` with implemented `invoke` function.
|
special implementation of `Expression<T>` with implemented `invoke` function.
|
||||||
|
|
||||||
For example, the following builder:
|
For example, the following code:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.*
|
|
||||||
import space.kscience.kmath.asm.*
|
|
||||||
|
|
||||||
MstField { x + 2 }.compileToExpression(DoubleField)
|
"x^3-x+3".parseMath().compileToExpression(DoubleField)
|
||||||
```
|
```
|
||||||
|
|
||||||
... leads to generation of bytecode, which can be decompiled to the following Java class:
|
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||||
|
|
||||||
```java
|
```java
|
||||||
package space.kscience.kmath.asm.generated;
|
import java.util.*;
|
||||||
|
import kotlin.jvm.functions.*;
|
||||||
|
import space.kscience.kmath.asm.internal.*;
|
||||||
|
import space.kscience.kmath.complex.*;
|
||||||
|
import space.kscience.kmath.expressions.*;
|
||||||
|
|
||||||
import java.util.Map;
|
public final class CompiledExpression_45045_0 implements Expression<Complex> {
|
||||||
|
|
||||||
import kotlin.jvm.functions.Function2;
|
|
||||||
import space.kscience.kmath.asm.internal.MapIntrinsics;
|
|
||||||
import space.kscience.kmath.expressions.Expression;
|
|
||||||
import space.kscience.kmath.expressions.Symbol;
|
|
||||||
|
|
||||||
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
|
||||||
private final Object[] constants;
|
private final Object[] constants;
|
||||||
|
|
||||||
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
|
||||||
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
|
Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
|
||||||
}
|
return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
|
||||||
|
|
||||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
|
||||||
this.constants = constants;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
|
||||||
|
|
||||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
|
```java
|
||||||
loading overhead.
|
import java.util.*;
|
||||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
import space.kscience.kmath.asm.internal.*;
|
||||||
|
import space.kscience.kmath.expressions.*;
|
||||||
|
|
||||||
|
public final class CompiledExpression_-386104628_0 implements DoubleExpression {
|
||||||
|
private final SymbolIndexer indexer;
|
||||||
|
|
||||||
|
public SymbolIndexer getIndexer() {
|
||||||
|
return this.indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double invoke(double[] arguments) {
|
||||||
|
double var2 = arguments[0];
|
||||||
|
return Math.pow(var2, 3.0D) - var2 + 3.0D;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
||||||
|
double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
|
||||||
|
return Math.pow(var2, 3.0D) - var2 + 3.0D;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
|
||||||
|
|
||||||
|
#### Limitations
|
||||||
|
|
||||||
|
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
|
||||||
|
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
|
||||||
|
|
||||||
### On JS
|
### On JS
|
||||||
|
|
||||||
@ -121,15 +159,15 @@ MstField { x + 2 }.compileToExpression(DoubleField)
|
|||||||
An example of emitted Wasm IR in the form of WAT:
|
An example of emitted Wasm IR in the form of WAT:
|
||||||
|
|
||||||
```lisp
|
```lisp
|
||||||
(func $executable (param $0 f64) (result f64)
|
(func \$executable (param \$0 f64) (result f64)
|
||||||
(f64.add
|
(f64.add
|
||||||
(local.get $0)
|
(local.get \$0)
|
||||||
(f64.const 2)
|
(f64.const 2)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
#### Limitations
|
||||||
|
|
||||||
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
||||||
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
||||||
@ -161,10 +199,7 @@ public fun main() {
|
|||||||
|
|
||||||
Result LaTeX:
|
Result LaTeX:
|
||||||
|
|
||||||
<div style="background-color:white;">
|
$$\operatorname{exp}\\,\left(\sqrt{x}\right)-\frac{\frac{\operatorname{arcsin}\\,\left(2\\,x\right)}{2\times10^{10}+x^{3}}}{12}+x^{2/3}$$
|
||||||
|
|
||||||
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
|
|
||||||
</div>
|
|
||||||
|
|
||||||
Result MathML (can be used with MathJax or other renderers):
|
Result MathML (can be used with MathJax or other renderers):
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ kotlin.sourceSets {
|
|||||||
|
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api("com.github.h0tk3y.betterParse:better-parse:0.4.2")
|
api("com.github.h0tk3y.betterParse:better-parse:0.4.4")
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -57,7 +57,7 @@ tasks.dokkaHtml {
|
|||||||
|
|
||||||
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
||||||
tasks.jvmTest {
|
tasks.jvmTest {
|
||||||
jvmArgs = (jvmArgs ?: emptyList()) + listOf("-Dspace.kscience.kmath.ast.dump.generated.classes=1")
|
jvmArgs("-Dspace.kscience.kmath.ast.dump.generated.classes=1")
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
@ -1,11 +1,31 @@
|
|||||||
# Module kmath-ast
|
# Module kmath-ast
|
||||||
|
|
||||||
Performance and visualization extensions to MST API.
|
Extensions to MST API: transformations, dynamic compilation and visualization.
|
||||||
|
|
||||||
${features}
|
${features}
|
||||||
|
|
||||||
${artifact}
|
${artifact}
|
||||||
|
|
||||||
|
## Parsing expressions
|
||||||
|
|
||||||
|
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
|
||||||
|
|
||||||
|
Supported literals:
|
||||||
|
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
|
||||||
|
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`—all parsed either as `kotlin.Long` or `kotlin.Double`.
|
||||||
|
|
||||||
|
Supported binary operators (from the highest precedence to the lowest one):
|
||||||
|
1. `^`
|
||||||
|
2. `*`, `/`
|
||||||
|
3. `+`, `-`
|
||||||
|
|
||||||
|
Supported unary operator:
|
||||||
|
1. `-`, e. g. `-x`
|
||||||
|
|
||||||
|
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
|
||||||
|
1. `sin(x)`
|
||||||
|
2. `add(x, y)`
|
||||||
|
|
||||||
## Dynamic expression code generation
|
## Dynamic expression code generation
|
||||||
|
|
||||||
### On JVM
|
### On JVM
|
||||||
@ -13,48 +33,66 @@ ${artifact}
|
|||||||
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
|
||||||
special implementation of `Expression<T>` with implemented `invoke` function.
|
special implementation of `Expression<T>` with implemented `invoke` function.
|
||||||
|
|
||||||
For example, the following builder:
|
For example, the following code:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.*
|
|
||||||
import space.kscience.kmath.asm.*
|
|
||||||
|
|
||||||
MstField { x + 2 }.compileToExpression(DoubleField)
|
"x^3-x+3".parseMath().compileToExpression(DoubleField)
|
||||||
```
|
```
|
||||||
|
|
||||||
... leads to generation of bytecode, which can be decompiled to the following Java class:
|
… leads to generation of bytecode, which can be decompiled to the following Java class:
|
||||||
|
|
||||||
```java
|
```java
|
||||||
package space.kscience.kmath.asm.generated;
|
import java.util.*;
|
||||||
|
import kotlin.jvm.functions.*;
|
||||||
|
import space.kscience.kmath.asm.internal.*;
|
||||||
|
import space.kscience.kmath.complex.*;
|
||||||
|
import space.kscience.kmath.expressions.*;
|
||||||
|
|
||||||
import java.util.Map;
|
public final class CompiledExpression_45045_0 implements Expression<Complex> {
|
||||||
|
|
||||||
import kotlin.jvm.functions.Function2;
|
|
||||||
import space.kscience.kmath.asm.internal.MapIntrinsics;
|
|
||||||
import space.kscience.kmath.expressions.Expression;
|
|
||||||
import space.kscience.kmath.expressions.Symbol;
|
|
||||||
|
|
||||||
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
|
|
||||||
private final Object[] constants;
|
private final Object[] constants;
|
||||||
|
|
||||||
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
|
||||||
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
|
Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
|
||||||
}
|
return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
|
||||||
|
|
||||||
public AsmCompiledExpression_45045_0(Object[] constants) {
|
|
||||||
this.constants = constants;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
|
||||||
|
|
||||||
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class
|
```java
|
||||||
loading overhead.
|
import java.util.*;
|
||||||
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders.
|
import space.kscience.kmath.asm.internal.*;
|
||||||
|
import space.kscience.kmath.expressions.*;
|
||||||
|
|
||||||
|
public final class CompiledExpression_-386104628_0 implements DoubleExpression {
|
||||||
|
private final SymbolIndexer indexer;
|
||||||
|
|
||||||
|
public SymbolIndexer getIndexer() {
|
||||||
|
return this.indexer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double invoke(double[] arguments) {
|
||||||
|
double var2 = arguments[0];
|
||||||
|
return Math.pow(var2, 3.0D) - var2 + 3.0D;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
|
||||||
|
double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
|
||||||
|
return Math.pow(var2, 3.0D) - var2 + 3.0D;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
|
||||||
|
|
||||||
|
#### Limitations
|
||||||
|
|
||||||
|
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
|
||||||
|
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
|
||||||
|
|
||||||
### On JS
|
### On JS
|
||||||
|
|
||||||
@ -100,7 +138,7 @@ An example of emitted Wasm IR in the form of WAT:
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Known issues
|
#### Limitations
|
||||||
|
|
||||||
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
- ESTree expression compilation uses `eval` which can be unavailable in several environments.
|
||||||
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
|
||||||
@ -132,10 +170,7 @@ public fun main() {
|
|||||||
|
|
||||||
Result LaTeX:
|
Result LaTeX:
|
||||||
|
|
||||||
<div style="background-color:white;">
|
$$\operatorname{exp}\\,\left(\sqrt{x}\right)-\frac{\frac{\operatorname{arcsin}\\,\left(2\\,x\right)}{2\times10^{10}+x^{3}}}{12}+x^{2/3}$$
|
||||||
|
|
||||||
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
|
|
||||||
</div>
|
|
||||||
|
|
||||||
Result MathML (can be used with MathJax or other renderers):
|
Result MathML (can be used with MathJax or other renderers):
|
||||||
|
|
||||||
|
@ -0,0 +1,177 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MST form where all values belong to the type [T]. It is optimal for constant folding, dynamic compilation, etc.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public sealed interface TypedMst<T> {
|
||||||
|
/**
|
||||||
|
* A node containing a unary operation.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property operation The identifier of operation.
|
||||||
|
* @property function The function implementing this operation.
|
||||||
|
* @property value The argument of this operation.
|
||||||
|
*/
|
||||||
|
public class Unary<T>(public val operation: String, public val function: (T) -> T, public val value: TypedMst<T>) :
|
||||||
|
TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
other as Unary<*>
|
||||||
|
if (operation != other.operation) return false
|
||||||
|
if (value != other.value) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = operation.hashCode()
|
||||||
|
result = 31 * result + value.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Unary(operation=$operation, value=$value)"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A node containing binary operation.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property operation The identifier of operation.
|
||||||
|
* @property function The binary function implementing this operation.
|
||||||
|
* @property left The left operand.
|
||||||
|
* @property right The right operand.
|
||||||
|
*/
|
||||||
|
public class Binary<T>(
|
||||||
|
public val operation: String,
|
||||||
|
public val function: Function<T>,
|
||||||
|
public val left: TypedMst<T>,
|
||||||
|
public val right: TypedMst<T>,
|
||||||
|
) : TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
|
other as Binary<*>
|
||||||
|
|
||||||
|
if (operation != other.operation) return false
|
||||||
|
if (left != other.left) return false
|
||||||
|
if (right != other.right) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = operation.hashCode()
|
||||||
|
result = 31 * result + left.hashCode()
|
||||||
|
result = 31 * result + right.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Binary(operation=$operation, left=$left, right=$right)"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The non-numeric constant value.
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property value The held value.
|
||||||
|
* @property number The number this value corresponds.
|
||||||
|
*/
|
||||||
|
public class Constant<T>(public val value: T, public val number: Number?) : TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
other as Constant<*>
|
||||||
|
if (value != other.value) return false
|
||||||
|
if (number != other.number) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = value?.hashCode() ?: 0
|
||||||
|
result = 31 * result + (number?.hashCode() ?: 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Constant(value=$value, number=$number)"
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The node containing a variable
|
||||||
|
*
|
||||||
|
* @param T the type.
|
||||||
|
* @property symbol The symbol of the variable.
|
||||||
|
*/
|
||||||
|
public class Variable<T>(public val symbol: Symbol) : TypedMst<T> {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
other as Variable<*>
|
||||||
|
if (symbol != other.symbol) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int = symbol.hashCode()
|
||||||
|
override fun toString(): String = "Variable(symbol=$symbol)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interprets the [TypedMst] node with this [Algebra] and [arguments].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
|
||||||
|
is TypedMst.Unary -> algebra.unaryOperation(operation, interpret(algebra, arguments))
|
||||||
|
|
||||||
|
is TypedMst.Binary -> when {
|
||||||
|
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null ->
|
||||||
|
algebra.leftSideNumberOperation(operation, left.number, right.interpret(algebra, arguments))
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null ->
|
||||||
|
algebra.rightSideNumberOperation(operation, left.interpret(algebra, arguments), right.number)
|
||||||
|
|
||||||
|
else -> algebra.binaryOperation(
|
||||||
|
operation,
|
||||||
|
left.interpret(algebra, arguments),
|
||||||
|
right.interpret(algebra, arguments),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
is TypedMst.Constant -> value
|
||||||
|
is TypedMst.Variable -> arguments.getValue(symbol)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interprets the [TypedMst] node with this [Algebra] and optional [arguments].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
|
||||||
|
algebra,
|
||||||
|
when (arguments.size) {
|
||||||
|
0 -> emptyMap()
|
||||||
|
1 -> mapOf(arguments[0])
|
||||||
|
else -> hashMapOf(*arguments)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interpret this [TypedMst] node as expression.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
|
||||||
|
interpret(algebra, arguments)
|
||||||
|
}
|
@ -0,0 +1,93 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.expressions.MST
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
import space.kscience.kmath.operations.bindSymbolOrNull
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluates constants in given [MST] for given [algebra] at the same time with converting to [TypedMst].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) {
|
||||||
|
is MST.Numeric -> TypedMst.Constant(
|
||||||
|
(algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
is MST.Unary -> when (val arg = value.evaluateConstants(algebra)) {
|
||||||
|
is TypedMst.Constant<T> -> {
|
||||||
|
val value = algebra.unaryOperation(
|
||||||
|
operation,
|
||||||
|
arg.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
TypedMst.Constant(value, if (value is Number) value else null)
|
||||||
|
}
|
||||||
|
|
||||||
|
else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Binary -> {
|
||||||
|
val left = left.evaluateConstants(algebra)
|
||||||
|
val right = right.evaluateConstants(algebra)
|
||||||
|
|
||||||
|
when {
|
||||||
|
left is TypedMst.Constant<T> && right is TypedMst.Constant<T> -> {
|
||||||
|
val value = when {
|
||||||
|
algebra is NumericAlgebra && left.number != null -> algebra.leftSideNumberOperation(
|
||||||
|
operation,
|
||||||
|
left.number,
|
||||||
|
right.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && right.number != null -> algebra.rightSideNumberOperation(
|
||||||
|
operation,
|
||||||
|
left.value,
|
||||||
|
right.number,
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> algebra.binaryOperation(
|
||||||
|
operation,
|
||||||
|
left.value,
|
||||||
|
right.value,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
TypedMst.Constant(value, if (value is Number) value else null)
|
||||||
|
}
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
|
||||||
|
operation,
|
||||||
|
algebra.leftSideNumberOperationFunction(operation),
|
||||||
|
left,
|
||||||
|
right,
|
||||||
|
)
|
||||||
|
|
||||||
|
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> TypedMst.Binary(
|
||||||
|
operation,
|
||||||
|
algebra.rightSideNumberOperationFunction(operation),
|
||||||
|
left,
|
||||||
|
right,
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> TypedMst.Binary(operation, algebra.binaryOperationFunction(operation), left, right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
is Symbol -> {
|
||||||
|
val boundSymbol = algebra.bindSymbolOrNull(this)
|
||||||
|
|
||||||
|
if (boundSymbol != null)
|
||||||
|
TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null)
|
||||||
|
else
|
||||||
|
TypedMst.Variable(this)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.operations.ByteRing
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
|
import space.kscience.kmath.operations.pi
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
internal class TestFolding {
|
||||||
|
@Test
|
||||||
|
fun foldUnary() = assertEquals(
|
||||||
|
-1,
|
||||||
|
("-(1)".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldDeepUnary() = assertEquals(
|
||||||
|
1,
|
||||||
|
("-(-(1))".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldBinary() = assertEquals(
|
||||||
|
2,
|
||||||
|
("1*2".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldDeepBinary() = assertEquals(
|
||||||
|
10,
|
||||||
|
("1*2*5".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldSymbol() = assertEquals(
|
||||||
|
DoubleField.pi,
|
||||||
|
("pi".parseMath().evaluateConstants(DoubleField) as? TypedMst.Constant<Double> ?: fail()).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun foldNumeric() = assertEquals(
|
||||||
|
42.toByte(),
|
||||||
|
("42".parseMath().evaluateConstants(ByteRing) as? TypedMst.Constant<Byte> ?: fail()).value,
|
||||||
|
)
|
||||||
|
}
|
@ -5,87 +5,48 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.estree
|
package space.kscience.kmath.estree
|
||||||
|
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
|
import space.kscience.kmath.ast.evaluateConstants
|
||||||
import space.kscience.kmath.estree.internal.ESTreeBuilder
|
import space.kscience.kmath.estree.internal.ESTreeBuilder
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.expressions.MST
|
import space.kscience.kmath.expressions.MST
|
||||||
import space.kscience.kmath.expressions.MST.*
|
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.internal.estree.BaseExpression
|
import space.kscience.kmath.internal.estree.BaseExpression
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
|
||||||
import space.kscience.kmath.operations.bindSymbolOrNull
|
|
||||||
|
|
||||||
@PublishedApi
|
|
||||||
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
|
||||||
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
|
||||||
is Symbol -> {
|
|
||||||
val symbol = algebra.bindSymbolOrNull(node)
|
|
||||||
|
|
||||||
if (symbol != null)
|
|
||||||
constant(symbol)
|
|
||||||
else
|
|
||||||
variable(node.identity)
|
|
||||||
}
|
|
||||||
|
|
||||||
is Numeric -> constant(
|
|
||||||
(algebra as? NumericAlgebra<T>)?.number(node.value) ?: error("Numeric nodes are not supported by $this")
|
|
||||||
)
|
|
||||||
|
|
||||||
is Unary -> when {
|
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> constant(
|
|
||||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
|
||||||
}
|
|
||||||
|
|
||||||
is Binary -> when {
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
|
|
||||||
algebra.binaryOperationFunction(node.operation).invoke(
|
|
||||||
algebra.number((node.left as Numeric).value),
|
|
||||||
algebra.number((node.right as Numeric).value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric -> call(
|
|
||||||
algebra.leftSideNumberOperationFunction(node.operation),
|
|
||||||
visit(node.left),
|
|
||||||
visit(node.right),
|
|
||||||
)
|
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.right is Numeric -> call(
|
|
||||||
algebra.rightSideNumberOperationFunction(node.operation),
|
|
||||||
visit(node.left),
|
|
||||||
visit(node.right),
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> call(
|
|
||||||
algebra.binaryOperationFunction(node.operation),
|
|
||||||
visit(node.left),
|
|
||||||
visit(node.right),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a compiled expression with given [MST] and given [algebra].
|
* Create a compiled expression with given [MST] and given [algebra].
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
|
||||||
|
|
||||||
|
fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) {
|
||||||
|
is TypedMst.Constant -> constant(node.value)
|
||||||
|
is TypedMst.Variable -> variable(node.symbol)
|
||||||
|
is TypedMst.Unary -> call(node.function, visit(node.value))
|
||||||
|
|
||||||
|
is TypedMst.Binary -> call(
|
||||||
|
node.function,
|
||||||
|
visit(node.left),
|
||||||
|
visit(node.right),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ESTreeBuilder { visit(typed) }.instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
public fun <T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
public fun <T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
@ -22,28 +22,20 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suppress("UNUSED_VARIABLE")
|
||||||
val instance: Expression<T> by lazy {
|
val instance: Expression<T> by lazy {
|
||||||
val node = Program(
|
val node = Program(
|
||||||
sourceType = "script",
|
sourceType = "script",
|
||||||
VariableDeclaration(
|
ReturnStatement(bodyCallback())
|
||||||
kind = "var",
|
|
||||||
VariableDeclarator(
|
|
||||||
id = Identifier("executable"),
|
|
||||||
init = FunctionExpression(
|
|
||||||
params = arrayOf(Identifier("constants"), Identifier("arguments")),
|
|
||||||
body = BlockStatement(ReturnStatement(bodyCallback())),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
eval(generate(node))
|
val code = generate(node)
|
||||||
GeneratedExpression(js("executable"), constants.toTypedArray())
|
GeneratedExpression(js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
private val constants = mutableListOf<Any>()
|
private val constants = mutableListOf<Any>()
|
||||||
|
|
||||||
fun constant(value: Any?) = when {
|
fun constant(value: Any?): BaseExpression = when {
|
||||||
value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" ->
|
value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" ->
|
||||||
SimpleLiteral(value)
|
SimpleLiteral(value)
|
||||||
|
|
||||||
@ -61,7 +53,8 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name))
|
fun variable(name: Symbol): BaseExpression =
|
||||||
|
call(getOrFail, Identifier("arguments_0"), SimpleLiteral(name.identity))
|
||||||
|
|
||||||
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
|
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
|
||||||
optional = false,
|
optional = false,
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:Suppress("unused")
|
||||||
|
|
||||||
package space.kscience.kmath.internal.estree
|
package space.kscience.kmath.internal.estree
|
||||||
|
|
||||||
internal fun Program(sourceType: String, vararg body: dynamic) = object : Program {
|
internal fun Program(sourceType: String, vararg body: dynamic) = object : Program {
|
||||||
@ -28,9 +30,10 @@ internal fun Identifier(name: String) = object : Identifier {
|
|||||||
override var name = name
|
override var name = name
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun FunctionExpression(params: Array<dynamic>, body: BlockStatement) = object : FunctionExpression {
|
internal fun FunctionExpression(id: Identifier?, params: Array<dynamic>, body: BlockStatement) = object : FunctionExpression {
|
||||||
override var params = params
|
override var params = params
|
||||||
override var type = "FunctionExpression"
|
override var type = "FunctionExpression"
|
||||||
|
override var id: Identifier? = id
|
||||||
override var body = body
|
override var body = body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.wasm.internal
|
package space.kscience.kmath.wasm.internal
|
||||||
|
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST.*
|
|
||||||
import space.kscience.kmath.internal.binaryen.*
|
import space.kscience.kmath.internal.binaryen.*
|
||||||
import space.kscience.kmath.internal.webassembly.Instance
|
import space.kscience.kmath.internal.webassembly.Instance
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
@ -16,11 +16,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
|
|||||||
|
|
||||||
private val spreader = eval("(obj, args) => obj(...args)")
|
private val spreader = eval("(obj, args) => obj(...args)")
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
@Suppress("UnsafeCastFromDynamic")
|
@Suppress("UnsafeCastFromDynamic")
|
||||||
internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
||||||
protected val binaryenType: Type,
|
protected val binaryenType: Type,
|
||||||
protected val algebra: Algebra<T>,
|
protected val algebra: Algebra<T>,
|
||||||
protected val target: MST,
|
protected val target: TypedMst<T>,
|
||||||
) {
|
) {
|
||||||
protected val keys: MutableList<Symbol> = mutableListOf()
|
protected val keys: MutableList<Symbol> = mutableListOf()
|
||||||
protected lateinit var ctx: BinaryenModule
|
protected lateinit var ctx: BinaryenModule
|
||||||
@ -51,59 +52,41 @@ internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
Instance(c, js("{}")).exports.executable
|
Instance(c, js("{}")).exports.executable
|
||||||
}
|
}
|
||||||
|
|
||||||
protected open fun visitSymbol(node: Symbol): ExpressionRef {
|
protected abstract fun visitNumber(number: Number): ExpressionRef
|
||||||
algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
|
|
||||||
|
|
||||||
var idx = keys.indexOf(node)
|
protected open fun visitVariable(node: TypedMst.Variable<T>): ExpressionRef {
|
||||||
|
var idx = keys.indexOf(node.symbol)
|
||||||
|
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
keys += node
|
keys += node.symbol
|
||||||
idx = keys.lastIndex
|
idx = keys.lastIndex
|
||||||
}
|
}
|
||||||
|
|
||||||
return ctx.local.get(idx, binaryenType)
|
return ctx.local.get(idx, binaryenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract fun visitNumeric(node: Numeric): ExpressionRef
|
protected open fun visitUnary(node: TypedMst.Unary<T>): ExpressionRef =
|
||||||
|
|
||||||
protected open fun visitUnary(node: Unary): ExpressionRef =
|
|
||||||
error("Unary operation ${node.operation} not defined in $this")
|
error("Unary operation ${node.operation} not defined in $this")
|
||||||
|
|
||||||
protected open fun visitBinary(mst: Binary): ExpressionRef =
|
protected open fun visitBinary(mst: TypedMst.Binary<T>): ExpressionRef =
|
||||||
error("Binary operation ${mst.operation} not defined in $this")
|
error("Binary operation ${mst.operation} not defined in $this")
|
||||||
|
|
||||||
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
|
protected open fun createModule(): BinaryenModule = space.kscience.kmath.internal.binaryen.Module()
|
||||||
|
|
||||||
protected fun visit(node: MST): ExpressionRef = when (node) {
|
protected fun visit(node: TypedMst<T>): ExpressionRef = when (node) {
|
||||||
is Symbol -> visitSymbol(node)
|
is TypedMst.Constant -> visitNumber(
|
||||||
is Numeric -> visitNumeric(node)
|
node.number ?: error("Object constants are not supported by pritimive ASM builder"),
|
||||||
|
|
||||||
is Unary -> when {
|
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> visitNumeric(
|
|
||||||
Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else -> visitUnary(node)
|
is TypedMst.Variable -> visitVariable(node)
|
||||||
}
|
is TypedMst.Unary -> visitUnary(node)
|
||||||
|
is TypedMst.Binary -> visitBinary(node)
|
||||||
is Binary -> when {
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
|
|
||||||
Numeric(
|
|
||||||
algebra.binaryOperationFunction(node.operation)
|
|
||||||
.invoke(
|
|
||||||
algebra.number((node.left as Numeric).value),
|
|
||||||
algebra.number((node.right as Numeric).value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> visitBinary(node)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
|
internal class DoubleWasmBuilder(target: TypedMst<Double>) :
|
||||||
|
WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
|
||||||
override val instance by lazy {
|
override val instance by lazy {
|
||||||
object : DoubleExpression {
|
object : DoubleExpression {
|
||||||
override val indexer = SimpleSymbolIndexer(keys)
|
override val indexer = SimpleSymbolIndexer(keys)
|
||||||
@ -114,9 +97,9 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
|
|||||||
|
|
||||||
override fun createModule() = readBinary(f64StandardFunctions)
|
override fun createModule() = readBinary(f64StandardFunctions)
|
||||||
|
|
||||||
override fun visitNumeric(node: Numeric) = ctx.f64.const(node.value.toDouble())
|
override fun visitNumber(number: Number) = ctx.f64.const(number.toDouble())
|
||||||
|
|
||||||
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
|
override fun visitUnary(node: TypedMst.Unary<Double>): ExpressionRef = when (node.operation) {
|
||||||
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
|
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
|
||||||
GroupOps.PLUS_OPERATION -> visit(node.value)
|
GroupOps.PLUS_OPERATION -> visit(node.value)
|
||||||
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
|
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
|
||||||
@ -137,7 +120,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
|
|||||||
else -> super.visitUnary(node)
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: TypedMst.Binary<Double>): ExpressionRef = when (mst.operation) {
|
||||||
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
||||||
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
||||||
@ -148,7 +131,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
|
internal class IntWasmBuilder(target: TypedMst<Int>) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
|
||||||
override val instance by lazy {
|
override val instance by lazy {
|
||||||
object : IntExpression {
|
object : IntExpression {
|
||||||
override val indexer = SimpleSymbolIndexer(keys)
|
override val indexer = SimpleSymbolIndexer(keys)
|
||||||
@ -157,15 +140,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value.toInt())
|
override fun visitNumber(number: Number) = ctx.i32.const(number.toInt())
|
||||||
|
|
||||||
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
|
override fun visitUnary(node: TypedMst.Unary<Int>): ExpressionRef = when (node.operation) {
|
||||||
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
|
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
|
||||||
GroupOps.PLUS_OPERATION -> visit(node.value)
|
GroupOps.PLUS_OPERATION -> visit(node.value)
|
||||||
else -> super.visitUnary(node)
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: TypedMst.Binary<Int>): ExpressionRef = when (mst.operation) {
|
||||||
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
||||||
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
||||||
|
@ -7,7 +7,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.wasm
|
package space.kscience.kmath.wasm
|
||||||
|
|
||||||
import space.kscience.kmath.estree.compileWith
|
import space.kscience.kmath.ast.TypedMst
|
||||||
|
import space.kscience.kmath.ast.evaluateConstants
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
@ -21,8 +22,16 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: IntRing): IntExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : IntExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: IntArray): Int = typed.value
|
||||||
|
} else
|
||||||
|
IntWasmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -31,7 +40,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBui
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -49,7 +58,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleWasmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : DoubleExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: DoubleArray): Double = typed.value
|
||||||
|
} else
|
||||||
|
DoubleWasmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -59,7 +77,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = D
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -69,4 +87,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
@ -8,10 +8,14 @@
|
|||||||
package space.kscience.kmath.asm
|
package space.kscience.kmath.asm
|
||||||
|
|
||||||
import space.kscience.kmath.asm.internal.*
|
import space.kscience.kmath.asm.internal.*
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
|
import space.kscience.kmath.ast.evaluateConstants
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST.*
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
|
import space.kscience.kmath.operations.LongRing
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compiles given MST to an Expression using AST compiler.
|
* Compiles given MST to an Expression using AST compiler.
|
||||||
@ -21,102 +25,64 @@ import space.kscience.kmath.operations.*
|
|||||||
* @return the compiled expression.
|
* @return the compiled expression.
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun GenericAsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
|
val typed = evaluateConstants(algebra)
|
||||||
is Symbol -> prepareVariable(node.identity)
|
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
|
||||||
is Unary -> variablesVisitor(node.value)
|
|
||||||
|
|
||||||
is Binary -> {
|
fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) {
|
||||||
|
is TypedMst.Unary -> variablesVisitor(node.value)
|
||||||
|
|
||||||
|
is TypedMst.Binary -> {
|
||||||
variablesVisitor(node.left)
|
variablesVisitor(node.left)
|
||||||
variablesVisitor(node.right)
|
variablesVisitor(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> Unit
|
is TypedMst.Variable -> prepareVariable(node.symbol)
|
||||||
|
is TypedMst.Constant -> Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
fun GenericAsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
|
fun GenericAsmBuilder<T>.expressionVisitor(node: TypedMst<T>): Unit = when (node) {
|
||||||
is Symbol -> {
|
is TypedMst.Constant -> if (node.number != null)
|
||||||
val symbol = algebra.bindSymbolOrNull(node)
|
loadNumberConstant(node.number)
|
||||||
|
|
||||||
if (symbol != null)
|
|
||||||
loadObjectConstant(symbol as Any)
|
|
||||||
else
|
else
|
||||||
loadVariable(node.identity)
|
loadObjectConstant(node.value)
|
||||||
}
|
|
||||||
|
|
||||||
is Numeric -> if (algebra is NumericAlgebra) {
|
is TypedMst.Variable -> loadVariable(node.symbol)
|
||||||
if (Number::class.java.isAssignableFrom(type))
|
is TypedMst.Unary -> buildCall(node.function) { expressionVisitor(node.value) }
|
||||||
loadNumberConstant(algebra.number(node.value) as Number)
|
|
||||||
else
|
|
||||||
loadObjectConstant(algebra.number(node.value))
|
|
||||||
} else
|
|
||||||
error("Numeric nodes are not supported by $this")
|
|
||||||
|
|
||||||
is Unary -> when {
|
is TypedMst.Binary -> buildCall(node.function) {
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
|
||||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
is Binary -> when {
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
|
||||||
algebra.binaryOperationFunction(node.operation).invoke(
|
|
||||||
algebra.number((node.left as Numeric).value),
|
|
||||||
algebra.number((node.right as Numeric).value),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
|
||||||
algebra.leftSideNumberOperationFunction(node.operation),
|
|
||||||
) {
|
|
||||||
expressionVisitor(node.left)
|
expressionVisitor(node.left)
|
||||||
expressionVisitor(node.right)
|
expressionVisitor(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
|
||||||
algebra.rightSideNumberOperationFunction(node.operation),
|
|
||||||
) {
|
|
||||||
expressionVisitor(node.left)
|
|
||||||
expressionVisitor(node.right)
|
|
||||||
}
|
|
||||||
|
|
||||||
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
|
|
||||||
expressionVisitor(node.left)
|
|
||||||
expressionVisitor(node.right)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return GenericAsmBuilder<T>(
|
return GenericAsmBuilder<T>(
|
||||||
type,
|
type,
|
||||||
buildName(this),
|
buildName("${typed.hashCode()}_${type.simpleName}"),
|
||||||
{ variablesVisitor(this@compileWith) },
|
{ variablesVisitor(typed) },
|
||||||
{ expressionVisitor(this@compileWith) },
|
{ expressionVisitor(typed) },
|
||||||
).instance
|
).instance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a compiled expression with given [MST] and given [algebra].
|
* Create a compiled expression with given [MST] and given [algebra].
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
|
public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
|
||||||
compileWith(T::class.java, algebra)
|
compileWith(T::class.java, algebra)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments]
|
* Compile given MST to expression and evaluate it against [arguments]
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -125,7 +91,16 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg argu
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: IntRing): IntExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : IntExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: IntArray): Int = typed.value
|
||||||
|
} else
|
||||||
|
IntAsmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -134,7 +109,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuil
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -152,8 +127,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: LongRing): LongExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant<Long>) object : LongExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: LongArray): Long = typed.value
|
||||||
|
} else
|
||||||
|
LongAsmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -162,7 +145,7 @@ public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmB
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
|
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -181,7 +164,17 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>):
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance
|
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression {
|
||||||
|
val typed = evaluateConstants(algebra)
|
||||||
|
|
||||||
|
return if (typed is TypedMst.Constant) object : DoubleExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(emptyList())
|
||||||
|
|
||||||
|
override fun invoke(arguments: DoubleArray): Double = typed.value
|
||||||
|
} else
|
||||||
|
DoubleAsmBuilder(typed).instance
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -190,7 +183,7 @@ public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = Dou
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra)(arguments)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
@ -199,4 +192,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
@ -49,5 +49,7 @@ internal abstract class AsmBuilder {
|
|||||||
* ASM Type for [space.kscience.kmath.expressions.Symbol].
|
* ASM Type for [space.kscience.kmath.expressions.Symbol].
|
||||||
*/
|
*/
|
||||||
val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol")
|
val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol")
|
||||||
|
|
||||||
|
const val ARGUMENTS_NAME = "args"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
/**
|
/**
|
||||||
* Local variables indices are indices of symbols in this list.
|
* Local variables indices are indices of symbols in this list.
|
||||||
*/
|
*/
|
||||||
private val argumentsLocals = mutableListOf<String>()
|
private val argumentsLocals = mutableListOf<Symbol>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||||
@ -253,10 +253,10 @@ internal class GenericAsmBuilder<T>(
|
|||||||
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
|
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
|
||||||
* [loadVariable].
|
* [loadVariable].
|
||||||
*/
|
*/
|
||||||
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
|
fun prepareVariable(name: Symbol): Unit = invokeMethodVisitor.run {
|
||||||
if (name in argumentsLocals) return@run
|
if (name in argumentsLocals) return@run
|
||||||
load(1, MAP_TYPE)
|
load(1, MAP_TYPE)
|
||||||
aconst(name)
|
aconst(name.identity)
|
||||||
|
|
||||||
invokestatic(
|
invokestatic(
|
||||||
MAP_INTRINSICS_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
@ -280,7 +280,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
||||||
* with [prepareVariable] first.
|
* with [prepareVariable] first.
|
||||||
*/
|
*/
|
||||||
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
|
fun loadVariable(name: Symbol): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
|
||||||
|
|
||||||
inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) {
|
inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) {
|
||||||
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
@ -11,6 +11,7 @@ import org.objectweb.asm.Opcodes.*
|
|||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
import org.objectweb.asm.Type.*
|
import org.objectweb.asm.Type.*
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
|
import space.kscience.kmath.ast.TypedMst
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
@ -25,9 +26,9 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
classOfT: Class<*>,
|
classOfT: Class<*>,
|
||||||
protected val classOfTPrimitive: Class<*>,
|
protected val classOfTPrimitive: Class<*>,
|
||||||
expressionParent: Class<E>,
|
expressionParent: Class<E>,
|
||||||
protected val target: MST,
|
protected val target: TypedMst<T>,
|
||||||
) : AsmBuilder() {
|
) : AsmBuilder() {
|
||||||
private val className: String = buildName(target)
|
private val className: String = buildName("${target.hashCode()}_${classOfT.simpleName}")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [tType].
|
* ASM type for [tType].
|
||||||
@ -329,63 +330,39 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun visitVariables(
|
private fun visitVariables(
|
||||||
node: MST,
|
node: TypedMst<T>,
|
||||||
arrayMode: Boolean,
|
arrayMode: Boolean,
|
||||||
alreadyLoaded: MutableList<Symbol> = mutableListOf()
|
alreadyLoaded: MutableList<Symbol> = mutableListOf()
|
||||||
): Unit = when (node) {
|
): Unit = when (node) {
|
||||||
is Symbol -> when (node) {
|
is TypedMst.Variable -> if (node.symbol !in alreadyLoaded) {
|
||||||
!in alreadyLoaded -> {
|
alreadyLoaded += node.symbol
|
||||||
alreadyLoaded += node
|
prepareVariable(node.symbol, arrayMode)
|
||||||
prepareVariable(node, arrayMode)
|
} else Unit
|
||||||
}
|
|
||||||
else -> {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
|
is TypedMst.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
|
||||||
|
|
||||||
is MST.Binary -> {
|
is TypedMst.Binary -> {
|
||||||
visitVariables(node.left, arrayMode, alreadyLoaded)
|
visitVariables(node.left, arrayMode, alreadyLoaded)
|
||||||
visitVariables(node.right, arrayMode, alreadyLoaded)
|
visitVariables(node.right, arrayMode, alreadyLoaded)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> Unit
|
is TypedMst.Constant -> Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun visitExpression(node: MST): Unit = when (node) {
|
private fun visitExpression(node: TypedMst<T>): Unit = when (node) {
|
||||||
is Symbol -> {
|
is TypedMst.Variable -> loadVariable(node.symbol)
|
||||||
val symbol = algebra.bindSymbolOrNull(node)
|
|
||||||
|
|
||||||
if (symbol != null)
|
is TypedMst.Constant -> loadNumberConstant(
|
||||||
loadNumberConstant(symbol)
|
node.number ?: error("Object constants are not supported by pritimive ASM builder"),
|
||||||
else
|
|
||||||
loadVariable(node)
|
|
||||||
}
|
|
||||||
|
|
||||||
is MST.Numeric -> loadNumberConstant(algebra.number(node.value))
|
|
||||||
|
|
||||||
is MST.Unary -> if (node.value is MST.Numeric)
|
|
||||||
loadNumberConstant(
|
|
||||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
|
|
||||||
)
|
|
||||||
else
|
|
||||||
visitUnary(node)
|
|
||||||
|
|
||||||
is MST.Binary -> when {
|
|
||||||
node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
|
|
||||||
algebra.binaryOperationFunction(node.operation)(
|
|
||||||
algebra.number((node.left as MST.Numeric).value),
|
|
||||||
algebra.number((node.right as MST.Numeric).value),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else -> visitBinary(node)
|
is TypedMst.Unary -> visitUnary(node)
|
||||||
}
|
is TypedMst.Binary -> visitBinary(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value)
|
protected open fun visitUnary(node: TypedMst.Unary<T>) = visitExpression(node.value)
|
||||||
|
|
||||||
protected open fun visitBinary(node: MST.Binary) {
|
protected open fun visitBinary(node: TypedMst.Binary<T>) {
|
||||||
visitExpression(node.left)
|
visitExpression(node.left)
|
||||||
visitExpression(node.right)
|
visitExpression(node.right)
|
||||||
}
|
}
|
||||||
@ -404,14 +381,13 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, DoubleExpression>(
|
internal class DoubleAsmBuilder(target: TypedMst<Double>) : PrimitiveAsmBuilder<Double, DoubleExpression>(
|
||||||
DoubleField,
|
DoubleField,
|
||||||
java.lang.Double::class.java,
|
java.lang.Double::class.java,
|
||||||
java.lang.Double.TYPE,
|
java.lang.Double.TYPE,
|
||||||
DoubleExpression::class.java,
|
DoubleExpression::class.java,
|
||||||
target,
|
target,
|
||||||
) {
|
) {
|
||||||
|
|
||||||
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
|
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
|
||||||
MATH_TYPE.internalName,
|
MATH_TYPE.internalName,
|
||||||
name,
|
name,
|
||||||
@ -434,7 +410,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
|
|||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun visitUnary(node: MST.Unary) {
|
override fun visitUnary(node: TypedMst.Unary<Double>) {
|
||||||
super.visitUnary(node)
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -459,7 +435,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(node: MST.Binary) {
|
override fun visitBinary(node: TypedMst.Binary<Double>) {
|
||||||
super.visitBinary(node)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -479,7 +455,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class IntAsmBuilder(target: MST) :
|
internal class IntAsmBuilder(target: TypedMst<Int>) :
|
||||||
PrimitiveAsmBuilder<Int, IntExpression>(
|
PrimitiveAsmBuilder<Int, IntExpression>(
|
||||||
IntRing,
|
IntRing,
|
||||||
Integer::class.java,
|
Integer::class.java,
|
||||||
@ -487,7 +463,7 @@ internal class IntAsmBuilder(target: MST) :
|
|||||||
IntExpression::class.java,
|
IntExpression::class.java,
|
||||||
target
|
target
|
||||||
) {
|
) {
|
||||||
override fun visitUnary(node: MST.Unary) {
|
override fun visitUnary(node: TypedMst.Unary<Int>) {
|
||||||
super.visitUnary(node)
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -497,7 +473,7 @@ internal class IntAsmBuilder(target: MST) :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(node: MST.Binary) {
|
override fun visitBinary(node: TypedMst.Binary<Int>) {
|
||||||
super.visitBinary(node)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -510,14 +486,14 @@ internal class IntAsmBuilder(target: MST) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpression>(
|
internal class LongAsmBuilder(target: TypedMst<Long>) : PrimitiveAsmBuilder<Long, LongExpression>(
|
||||||
LongRing,
|
LongRing,
|
||||||
java.lang.Long::class.java,
|
java.lang.Long::class.java,
|
||||||
java.lang.Long.TYPE,
|
java.lang.Long.TYPE,
|
||||||
LongExpression::class.java,
|
LongExpression::class.java,
|
||||||
target,
|
target,
|
||||||
) {
|
) {
|
||||||
override fun visitUnary(node: MST.Unary) {
|
override fun visitUnary(node: TypedMst.Unary<Long>) {
|
||||||
super.visitUnary(node)
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
@ -527,7 +503,7 @@ internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpre
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(node: MST.Binary) {
|
override fun visitBinary(node: TypedMst.Binary<Long>) {
|
||||||
super.visitBinary(node)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (node.operation) {
|
when (node.operation) {
|
||||||
|
@ -55,15 +55,15 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
|
|||||||
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
|
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a class name for [Expression] subclassed to implement [mst] provided.
|
* Creates a class name for [Expression] based with appending [marker] to reduce collisions.
|
||||||
*
|
*
|
||||||
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
|
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
|
||||||
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
|
||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
internal tailrec fun buildName(marker: String, collision: Int = 0): String {
|
||||||
val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision"
|
val name = "space.kscience.kmath.asm.generated.CompiledExpression_${marker}_$collision"
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Class.forName(name)
|
Class.forName(name)
|
||||||
@ -71,7 +71,7 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
|
|||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
return buildName(mst, collision + 1)
|
return buildName(marker, collision + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
|
32
kmath-commons/README.md
Normal file
32
kmath-commons/README.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Module kmath-commons
|
||||||
|
|
||||||
|
Commons math binding for kmath
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
## Artifact:
|
||||||
|
|
||||||
|
The Maven coordinates of this project are `space.kscience:kmath-commons:0.3.0`.
|
||||||
|
|
||||||
|
**Gradle Groovy:**
|
||||||
|
```groovy
|
||||||
|
repositories {
|
||||||
|
maven { url 'https://repo.kotlin.link' }
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation 'space.kscience:kmath-commons:0.3.0'
|
||||||
|
}
|
||||||
|
```
|
||||||
|
**Gradle Kotlin DSL:**
|
||||||
|
```kotlin
|
||||||
|
repositories {
|
||||||
|
maven("https://repo.kotlin.link")
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("space.kscience:kmath-commons:0.3.0")
|
||||||
|
}
|
||||||
|
```
|
@ -8,17 +8,17 @@ Complex and hypercomplex number systems in KMath.
|
|||||||
|
|
||||||
## Artifact:
|
## Artifact:
|
||||||
|
|
||||||
The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-17`.
|
The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0`.
|
||||||
|
|
||||||
**Gradle:**
|
**Gradle Groovy:**
|
||||||
```gradle
|
```groovy
|
||||||
repositories {
|
repositories {
|
||||||
maven { url 'https://repo.kotlin.link' }
|
maven { url 'https://repo.kotlin.link' }
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'space.kscience:kmath-complex:0.3.0-dev-17'
|
implementation 'space.kscience:kmath-complex:0.3.0'
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
**Gradle Kotlin DSL:**
|
**Gradle Kotlin DSL:**
|
||||||
@ -29,6 +29,6 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("space.kscience:kmath-complex:0.3.0-dev-17")
|
implementation("space.kscience:kmath-complex:0.3.0")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -19,13 +19,15 @@ readme {
|
|||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "complex",
|
id = "complex",
|
||||||
description = "Complex Numbers",
|
|
||||||
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt"
|
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt"
|
||||||
)
|
){
|
||||||
|
"Complex numbers operations"
|
||||||
|
}
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "quaternion",
|
id = "quaternion",
|
||||||
description = "Quaternions",
|
|
||||||
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt"
|
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt"
|
||||||
)
|
){
|
||||||
|
"Quaternions and their composition"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,52 +16,132 @@ import space.kscience.kmath.structures.MutableBuffer
|
|||||||
import space.kscience.kmath.structures.MutableMemoryBuffer
|
import space.kscience.kmath.structures.MutableMemoryBuffer
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents `double`-based quaternion.
|
||||||
|
*
|
||||||
|
* @property w The first component.
|
||||||
|
* @property x The second component.
|
||||||
|
* @property y The third component.
|
||||||
|
* @property z The fourth component.
|
||||||
|
*/
|
||||||
|
public class Quaternion(
|
||||||
|
public val w: Double,
|
||||||
|
public val x: Double,
|
||||||
|
public val y: Double,
|
||||||
|
public val z: Double,
|
||||||
|
) : Buffer<Double> {
|
||||||
|
init {
|
||||||
|
require(!w.isNaN()) { "w-component of quaternion is not-a-number" }
|
||||||
|
require(!x.isNaN()) { "x-component of quaternion is not-a-number" }
|
||||||
|
require(!y.isNaN()) { "y-component of quaternion is not-a-number" }
|
||||||
|
require(!z.isNaN()) { "z-component of quaternion is not-a-number" }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a string representation of this quaternion.
|
||||||
|
*/
|
||||||
|
override fun toString(): String = "($w + $x * i + $y * j + $z * k)"
|
||||||
|
|
||||||
|
override val size: Int get() = 4
|
||||||
|
|
||||||
|
override fun get(index: Int): Double = when (index) {
|
||||||
|
0 -> w
|
||||||
|
1 -> x
|
||||||
|
2 -> y
|
||||||
|
3 -> z
|
||||||
|
else -> error("Index $index out of bounds [0,3]")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<Double> = listOf(w, x, y, z).iterator()
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
|
other as Quaternion
|
||||||
|
|
||||||
|
if (w != other.w) return false
|
||||||
|
if (x != other.x) return false
|
||||||
|
if (y != other.y) return false
|
||||||
|
if (z != other.z) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = w.hashCode()
|
||||||
|
result = 31 * result + x.hashCode()
|
||||||
|
result = 31 * result + y.hashCode()
|
||||||
|
result = 31 * result + z.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
public companion object : MemorySpec<Quaternion> {
|
||||||
|
override val objectSize: Int get() = 32
|
||||||
|
|
||||||
|
override fun MemoryReader.read(offset: Int): Quaternion = Quaternion(
|
||||||
|
readDouble(offset),
|
||||||
|
readDouble(offset + 8),
|
||||||
|
readDouble(offset + 16),
|
||||||
|
readDouble(offset + 24)
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun MemoryWriter.write(offset: Int, value: Quaternion) {
|
||||||
|
writeDouble(offset, value.w)
|
||||||
|
writeDouble(offset + 8, value.x)
|
||||||
|
writeDouble(offset + 16, value.y)
|
||||||
|
writeDouble(offset + 24, value.z)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun Quaternion(w: Number, x: Number = 0.0, y: Number = 0.0, z: Number = 0.0): Quaternion = Quaternion(
|
||||||
|
w.toDouble(),
|
||||||
|
x.toDouble(),
|
||||||
|
y.toDouble(),
|
||||||
|
z.toDouble(),
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This quaternion's conjugate.
|
* This quaternion's conjugate.
|
||||||
*/
|
*/
|
||||||
public val Quaternion.conjugate: Quaternion
|
public val Quaternion.conjugate: Quaternion
|
||||||
get() = QuaternionField { z - x * i - y * j - z * k }
|
get() = Quaternion(w, -x, -y, -z)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This quaternion's reciprocal.
|
* This quaternion's reciprocal.
|
||||||
*/
|
*/
|
||||||
public val Quaternion.reciprocal: Quaternion
|
public val Quaternion.reciprocal: Quaternion
|
||||||
get() {
|
get() {
|
||||||
QuaternionField {
|
val norm2 = (w * w + x * x + y * y + z * z)
|
||||||
val n = norm(this@reciprocal)
|
return Quaternion(w / norm2, -x / norm2, -y / norm2, -z / norm2)
|
||||||
return conjugate / (n * n)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
public fun Quaternion.normalized(): Quaternion = with(QuaternionField){ this@normalized / norm(this@normalized) }
|
||||||
* Absolute value of the quaternion.
|
|
||||||
*/
|
|
||||||
public val Quaternion.r: Double
|
|
||||||
get() = sqrt(w * w + x * x + y * y + z * z)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field of [Quaternion].
|
* A field of [Quaternion].
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
|
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Double>, PowerOperations<Quaternion>,
|
||||||
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
|
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
|
||||||
override val zero: Quaternion = 0.toQuaternion()
|
override val zero: Quaternion = Quaternion(0.0)
|
||||||
override val one: Quaternion = 1.toQuaternion()
|
override val one: Quaternion = Quaternion(1.0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The `i` quaternion unit.
|
* The `i` quaternion unit.
|
||||||
*/
|
*/
|
||||||
public val i: Quaternion = Quaternion(0, 1)
|
public val i: Quaternion = Quaternion(0.0, 1.0, 0.0, 0.0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The `j` quaternion unit.
|
* The `j` quaternion unit.
|
||||||
*/
|
*/
|
||||||
public val j: Quaternion = Quaternion(0, 0, 1)
|
public val j: Quaternion = Quaternion(0.0, 0.0, 1.0, 0.0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The `k` quaternion unit.
|
* The `k` quaternion unit.
|
||||||
*/
|
*/
|
||||||
public val k: Quaternion = Quaternion(0, 0, 0, 1)
|
public val k: Quaternion = Quaternion(0.0, 0.0, 0.0, 1.0)
|
||||||
|
|
||||||
override fun add(left: Quaternion, right: Quaternion): Quaternion =
|
override fun add(left: Quaternion, right: Quaternion): Quaternion =
|
||||||
Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
|
Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
|
||||||
@ -133,7 +213,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
|
|
||||||
override fun exp(arg: Quaternion): Quaternion {
|
override fun exp(arg: Quaternion): Quaternion {
|
||||||
val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z
|
val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z
|
||||||
if (un == 0.0) return exp(arg.w).toQuaternion()
|
if (un == 0.0) return Quaternion(exp(arg.w))
|
||||||
val n1 = sqrt(un)
|
val n1 = sqrt(un)
|
||||||
val ea = exp(arg.w)
|
val ea = exp(arg.w)
|
||||||
val n2 = ea * sin(n1) / n1
|
val n2 = ea * sin(n1) / n1
|
||||||
@ -158,7 +238,8 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z)
|
return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Number.plus(other: Quaternion): Quaternion = Quaternion(toDouble() + other.w, other.x, other.y, other.z)
|
override operator fun Number.plus(other: Quaternion): Quaternion =
|
||||||
|
Quaternion(toDouble() + other.w, other.x, other.y, other.z)
|
||||||
|
|
||||||
override operator fun Number.minus(other: Quaternion): Quaternion =
|
override operator fun Number.minus(other: Quaternion): Quaternion =
|
||||||
Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z)
|
Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z)
|
||||||
@ -170,7 +251,12 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z)
|
Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z)
|
||||||
|
|
||||||
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
|
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
|
||||||
override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg)
|
override fun norm(arg: Quaternion): Double = sqrt(
|
||||||
|
arg.w.pow(2) +
|
||||||
|
arg.x.pow(2) +
|
||||||
|
arg.y.pow(2) +
|
||||||
|
arg.z.pow(2)
|
||||||
|
)
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): Quaternion? = when (value) {
|
override fun bindSymbolOrNull(value: String): Quaternion? = when (value) {
|
||||||
"i" -> i
|
"i" -> i
|
||||||
@ -179,7 +265,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
else -> null
|
else -> null
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun number(value: Number): Quaternion = value.toQuaternion()
|
override fun number(value: Number): Quaternion = Quaternion(value)
|
||||||
|
|
||||||
override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0
|
override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0
|
||||||
override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0
|
override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0
|
||||||
@ -189,76 +275,6 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0
|
override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents `double`-based quaternion.
|
|
||||||
*
|
|
||||||
* @property w The first component.
|
|
||||||
* @property x The second component.
|
|
||||||
* @property y The third component.
|
|
||||||
* @property z The fourth component.
|
|
||||||
*/
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public data class Quaternion(
|
|
||||||
val w: Double, val x: Double, val y: Double, val z: Double,
|
|
||||||
) {
|
|
||||||
public constructor(w: Number, x: Number, y: Number, z: Number) : this(
|
|
||||||
w.toDouble(),
|
|
||||||
x.toDouble(),
|
|
||||||
y.toDouble(),
|
|
||||||
z.toDouble(),
|
|
||||||
)
|
|
||||||
|
|
||||||
public constructor(w: Number, x: Number, y: Number) : this(w.toDouble(), x.toDouble(), y.toDouble(), 0.0)
|
|
||||||
public constructor(w: Number, x: Number) : this(w.toDouble(), x.toDouble(), 0.0, 0.0)
|
|
||||||
public constructor(w: Number) : this(w.toDouble(), 0.0, 0.0, 0.0)
|
|
||||||
public constructor(wx: Complex, yz: Complex) : this(wx.re, wx.im, yz.re, yz.im)
|
|
||||||
public constructor(wx: Complex) : this(wx.re, wx.im, 0, 0)
|
|
||||||
|
|
||||||
init {
|
|
||||||
require(!w.isNaN()) { "w-component of quaternion is not-a-number" }
|
|
||||||
require(!x.isNaN()) { "x-component of quaternion is not-a-number" }
|
|
||||||
require(!y.isNaN()) { "x-component of quaternion is not-a-number" }
|
|
||||||
require(!z.isNaN()) { "x-component of quaternion is not-a-number" }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a string representation of this quaternion.
|
|
||||||
*/
|
|
||||||
override fun toString(): String = "($w + $x * i + $y * j + $z * k)"
|
|
||||||
|
|
||||||
public companion object : MemorySpec<Quaternion> {
|
|
||||||
override val objectSize: Int
|
|
||||||
get() = 32
|
|
||||||
|
|
||||||
override fun MemoryReader.read(offset: Int): Quaternion =
|
|
||||||
Quaternion(readDouble(offset), readDouble(offset + 8), readDouble(offset + 16), readDouble(offset + 24))
|
|
||||||
|
|
||||||
override fun MemoryWriter.write(offset: Int, value: Quaternion) {
|
|
||||||
writeDouble(offset, value.w)
|
|
||||||
writeDouble(offset + 8, value.x)
|
|
||||||
writeDouble(offset + 16, value.y)
|
|
||||||
writeDouble(offset + 24, value.z)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a quaternion with real part equal to this real.
|
|
||||||
*
|
|
||||||
* @receiver the real part.
|
|
||||||
* @return a new quaternion.
|
|
||||||
*/
|
|
||||||
public fun Number.toQuaternion(): Quaternion = Quaternion(this)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a quaternion with `w`-component equal to `re`-component of given complex and `x`-component equal to
|
|
||||||
* `im`-component of given complex.
|
|
||||||
*
|
|
||||||
* @receiver the complex number.
|
|
||||||
* @return a new quaternion.
|
|
||||||
*/
|
|
||||||
public fun Complex.toQuaternion(): Quaternion = Quaternion(this)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the
|
* Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the
|
||||||
* specified [init] function.
|
* specified [init] function.
|
||||||
|
@ -6,10 +6,23 @@
|
|||||||
package space.kscience.kmath.complex
|
package space.kscience.kmath.complex
|
||||||
|
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import space.kscience.kmath.testutils.assertBufferEquals
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class QuaternionFieldTest {
|
internal class QuaternionTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testNorm() {
|
||||||
|
assertEquals(2.0, QuaternionField.norm(Quaternion(1.0, 1.0, 1.0, 1.0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testInverse() = QuaternionField {
|
||||||
|
val q = Quaternion(1.0, 2.0, -3.0, 4.0)
|
||||||
|
assertBufferEquals(one, q * q.reciprocal, 1e-4)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAddition() {
|
fun testAddition() {
|
||||||
assertEquals(Quaternion(42, 42), QuaternionField { Quaternion(16, 16) + Quaternion(26, 26) })
|
assertEquals(Quaternion(42, 42), QuaternionField { Quaternion(16, 16) + Quaternion(26, 26) })
|
@ -15,17 +15,17 @@ performance calculations to code generation.
|
|||||||
|
|
||||||
## Artifact:
|
## Artifact:
|
||||||
|
|
||||||
The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-17`.
|
The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0`.
|
||||||
|
|
||||||
**Gradle:**
|
**Gradle Groovy:**
|
||||||
```gradle
|
```groovy
|
||||||
repositories {
|
repositories {
|
||||||
maven { url 'https://repo.kotlin.link' }
|
maven { url 'https://repo.kotlin.link' }
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'space.kscience:kmath-core:0.3.0-dev-17'
|
implementation 'space.kscience:kmath-core:0.3.0'
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
**Gradle Kotlin DSL:**
|
**Gradle Kotlin DSL:**
|
||||||
@ -36,6 +36,6 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("space.kscience:kmath-core:0.3.0-dev-17")
|
implementation("space.kscience:kmath-core:0.3.0")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,6 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
id("ru.mipt.npm.gradle.mpp")
|
||||||
id("ru.mipt.npm.gradle.common")
|
|
||||||
id("ru.mipt.npm.gradle.native")
|
id("ru.mipt.npm.gradle.native")
|
||||||
// id("com.xcporter.metaview") version "0.0.5"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
|
@ -0,0 +1,59 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.domains
|
||||||
|
|
||||||
|
import space.kscience.kmath.linear.Point
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public abstract class Domain1D<T : Comparable<T>>(public val range: ClosedRange<T>) : Domain<T> {
|
||||||
|
override val dimension: Int get() = 1
|
||||||
|
|
||||||
|
public operator fun contains(value: T): Boolean = range.contains(value)
|
||||||
|
|
||||||
|
override operator fun contains(point: Point<T>): Boolean {
|
||||||
|
require(point.size == 0)
|
||||||
|
return contains(point[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DoubleDomain1D(
|
||||||
|
@Suppress("CanBeParameter") public val doubleRange: ClosedFloatingPointRange<Double>,
|
||||||
|
) : Domain1D<Double>(doubleRange), DoubleDomain {
|
||||||
|
override fun getLowerBound(num: Int): Double {
|
||||||
|
require(num == 0)
|
||||||
|
return range.start
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getUpperBound(num: Int): Double {
|
||||||
|
require(num == 0)
|
||||||
|
return range.endInclusive
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun volume(): Double = range.endInclusive - range.start
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
|
other as DoubleDomain1D
|
||||||
|
|
||||||
|
if (doubleRange != other.doubleRange) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int = doubleRange.hashCode()
|
||||||
|
|
||||||
|
override fun toString(): String = doubleRange.toString()
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public val Domain1D<Double>.center: Double
|
||||||
|
get() = (range.endInclusive + range.start) / 2
|
@ -7,18 +7,28 @@ package space.kscience.kmath.domains
|
|||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.indices
|
import space.kscience.kmath.structures.indices
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
* A hyper-square (or hyper-cube) real-space domain. It is formed by a [Buffer] of [lower] boundaries
|
||||||
* HyperSquareDomain class.
|
* and a [Buffer] of upper boundaries. Upper should be greater or equals than lower.
|
||||||
*
|
|
||||||
* @author Alexander Nozik
|
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public class HyperSquareDomain(private val lower: Buffer<Double>, private val upper: Buffer<Double>) : DoubleDomain {
|
public class HyperSquareDomain(public val lower: Buffer<Double>, public val upper: Buffer<Double>) : DoubleDomain {
|
||||||
|
init {
|
||||||
|
require(lower.size == upper.size) {
|
||||||
|
"Domain borders size mismatch. Lower borders size is ${lower.size}, but upper borders size is ${upper.size}."
|
||||||
|
}
|
||||||
|
require(lower.indices.all { lower[it] <= upper[it] }) {
|
||||||
|
"Domain borders order mismatch. Lower borders must be less or equals than upper borders."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override val dimension: Int get() = lower.size
|
override val dimension: Int get() = lower.size
|
||||||
|
|
||||||
|
public val center: DoubleBuffer get() = DoubleBuffer(dimension) { (lower[it] + upper[it]) / 2.0 }
|
||||||
|
|
||||||
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
|
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
|
||||||
point[i] in lower[i]..upper[i]
|
point[i] in lower[i]..upper[i]
|
||||||
}
|
}
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.domains
|
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Point
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class UnivariateDomain(public val range: ClosedFloatingPointRange<Double>) : DoubleDomain {
|
|
||||||
override val dimension: Int get() = 1
|
|
||||||
|
|
||||||
public operator fun contains(d: Double): Boolean = range.contains(d)
|
|
||||||
|
|
||||||
override operator fun contains(point: Point<Double>): Boolean {
|
|
||||||
require(point.size == 0)
|
|
||||||
return contains(point[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun getLowerBound(num: Int): Double {
|
|
||||||
require(num == 0)
|
|
||||||
return range.start
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun getUpperBound(num: Int): Double {
|
|
||||||
require(num == 0)
|
|
||||||
return range.endInclusive
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun volume(): Double = range.endInclusive - range.start
|
|
||||||
}
|
|
@ -0,0 +1,458 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.*
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
import space.kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.max
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class representing both the value and the differentials of a function.
|
||||||
|
*
|
||||||
|
* This class is the workhorse of the differentiation package.
|
||||||
|
*
|
||||||
|
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper
|
||||||
|
* [Doubly Recursive Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
|
||||||
|
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
|
||||||
|
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
|
||||||
|
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
|
||||||
|
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
|
||||||
|
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
|
||||||
|
*
|
||||||
|
* Derived from
|
||||||
|
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface DS<T, A : Ring<T>> {
|
||||||
|
public val derivativeAlgebra: DSAlgebra<T, A>
|
||||||
|
public val data: Buffer<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a partial derivative.
|
||||||
|
*
|
||||||
|
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
|
||||||
|
* @return partial derivative.
|
||||||
|
* @see value
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
private fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
|
||||||
|
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, A : Ring<T>> DS<T, A>.derivative(symbols: List<Symbol>): T {
|
||||||
|
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||||
|
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, A : Ring<T>> DS<T, A>.derivative(vararg symbols: Symbol): T {
|
||||||
|
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||||
|
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The value part of the derivative structure.
|
||||||
|
*
|
||||||
|
* @see getPartialDerivative
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public val <T, A : Ring<T>> DS<T, A>.value: T get() = data[0]
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the compiler for number of free parameters and order.
|
||||||
|
*
|
||||||
|
* @return cached rules set.
|
||||||
|
*/
|
||||||
|
@PublishedApi
|
||||||
|
internal val compiler: DSCompiler<T, A> by lazy {
|
||||||
|
// get the cached compilers
|
||||||
|
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
||||||
|
|
||||||
|
// we need to create more compilers
|
||||||
|
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
||||||
|
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
||||||
|
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
||||||
|
|
||||||
|
if (cache != null) {
|
||||||
|
// preserve the already created compilers
|
||||||
|
for (i in cache.indices) {
|
||||||
|
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the array in increasing diagonal order
|
||||||
|
for (diag in 0..numberOfVariables + order) {
|
||||||
|
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
||||||
|
val p: Int = diag - o
|
||||||
|
if (newCache[p][o] == null) {
|
||||||
|
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
||||||
|
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
||||||
|
|
||||||
|
newCache[p][o] = DSCompiler(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
p,
|
||||||
|
o,
|
||||||
|
valueCompiler,
|
||||||
|
derivativeCompiler,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return@lazy newCache[numberOfVariables][order]!!
|
||||||
|
}
|
||||||
|
|
||||||
|
private val variables: Map<Symbol, DSSymbol> by lazy {
|
||||||
|
bindings.entries.mapIndexed { index, (key, value) ->
|
||||||
|
key to DSSymbol(
|
||||||
|
index,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
}.toMap()
|
||||||
|
}
|
||||||
|
override val symbols: List<Symbol> = bindings.map { it.key }
|
||||||
|
|
||||||
|
public val numberOfVariables: Int get() = symbols.size
|
||||||
|
|
||||||
|
|
||||||
|
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
||||||
|
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||||
|
buffer[0] = value
|
||||||
|
if (compiler.order > 0) {
|
||||||
|
// the derivative of the variable with respect to itself is 1.
|
||||||
|
|
||||||
|
val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply {
|
||||||
|
set(index, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
buffer[indexOfDerivative] = algebra.one
|
||||||
|
}
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
private inner class DSImpl(
|
||||||
|
override val data: Buffer<T>,
|
||||||
|
) : DS<T, A> {
|
||||||
|
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun DS(data: Buffer<T>): DS<T, A> = DSImpl(data)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance representing a variable.
|
||||||
|
*
|
||||||
|
* Instances built using this constructor are considered to be the free variables with respect to which
|
||||||
|
* differentials are computed. As such, their differential with respect to themselves is +1.
|
||||||
|
*/
|
||||||
|
public fun variable(
|
||||||
|
index: Int,
|
||||||
|
value: T,
|
||||||
|
): DS<T, A> {
|
||||||
|
require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" }
|
||||||
|
return DS(bufferForVariable(index, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance from all its derivatives.
|
||||||
|
*
|
||||||
|
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
|
||||||
|
*/
|
||||||
|
public fun ofDerivatives(
|
||||||
|
vararg derivatives: T,
|
||||||
|
): DS<T, A> {
|
||||||
|
require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" }
|
||||||
|
val data = derivatives.asBuffer()
|
||||||
|
|
||||||
|
return DS(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class implementing both [DS] and [Symbol].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public inner class DSSymbol internal constructor(
|
||||||
|
index: Int,
|
||||||
|
symbol: Symbol,
|
||||||
|
value: T,
|
||||||
|
) : Symbol by symbol, DS<T, A> {
|
||||||
|
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||||
|
override val data: Buffer<T> = bufferForVariable(index, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun const(value: T): DS<T, A> {
|
||||||
|
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||||
|
buffer[0] = value
|
||||||
|
|
||||||
|
return DS(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun bindSymbolOrNull(value: String): DSSymbol? = variables[StringSymbol(value)]
|
||||||
|
|
||||||
|
override fun bindSymbol(value: String): DSSymbol =
|
||||||
|
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
|
||||||
|
|
||||||
|
public fun bindSymbolOrNull(symbol: Symbol): DSSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
|
public fun bindSymbol(symbol: Symbol): DSSymbol =
|
||||||
|
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
|
||||||
|
|
||||||
|
public fun DS<T, A>.derivative(symbols: List<Symbol>): T {
|
||||||
|
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||||
|
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DS<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A ring over [DS].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public open class DSRing<T, A>(
|
||||||
|
algebra: A,
|
||||||
|
bufferFactory: MutableBufferFactory<T>,
|
||||||
|
order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : DSAlgebra<T, A>(algebra, bufferFactory, order, bindings),
|
||||||
|
Ring<DS<T, A>>, ScaleOperations<DS<T, A>>,
|
||||||
|
NumericAlgebra<DS<T, A>>,
|
||||||
|
NumbersAddOps<DS<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||||
|
|
||||||
|
override fun bindSymbolOrNull(value: String): DSSymbol? =
|
||||||
|
super<DSAlgebra>.bindSymbolOrNull(value)
|
||||||
|
|
||||||
|
override fun DS<T, A>.unaryMinus(): DS<T, A> = mapData { -it }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a copy of given [Buffer] and modify it according to [block]
|
||||||
|
*/
|
||||||
|
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
|
||||||
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
val newData = bufferFactory(compiler.size) { data[it] }
|
||||||
|
algebra.block(newData)
|
||||||
|
return DS(newData)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
|
||||||
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
val newData: Buffer<T> = data.map(bufferFactory) {
|
||||||
|
algebra.block(it)
|
||||||
|
}
|
||||||
|
return DS(newData)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
|
||||||
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
val newData: Buffer<T> = data.mapIndexed(bufferFactory, block)
|
||||||
|
return DS(newData)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: DS<T, A> by lazy {
|
||||||
|
const(algebra.zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val one: DS<T, A> by lazy {
|
||||||
|
const(algebra.one)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||||
|
|
||||||
|
override fun add(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||||
|
require(right.derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
compiler.add(left.data, 0, right.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun scale(a: DS<T, A>, value: Double): DS<T, A> = a.mapData {
|
||||||
|
it.times(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(
|
||||||
|
left: DS<T, A>,
|
||||||
|
right: DS<T, A>,
|
||||||
|
): DS<T, A> = left.transformDataBuffer { result ->
|
||||||
|
compiler.multiply(left.data, 0, right.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// override fun DS<T, A>.minus(arg: DS): DS<T, A> = transformDataBuffer { result ->
|
||||||
|
// subtract(data, 0, arg.data, 0, result, 0)
|
||||||
|
// }
|
||||||
|
|
||||||
|
override operator fun DS<T, A>.plus(other: Number): DS<T, A> = transformDataBuffer {
|
||||||
|
it[0] += number(other)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// override operator fun DS<T, A>.minus(other: Number): DS<T, A> =
|
||||||
|
// this + (-other.toDouble())
|
||||||
|
|
||||||
|
override operator fun Number.plus(other: DS<T, A>): DS<T, A> = other + this
|
||||||
|
override operator fun Number.minus(other: DS<T, A>): DS<T, A> = other - this
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DerivativeStructureRingExpression<T, A>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val function: DSRing<T, A>.() -> DS<T, A>,
|
||||||
|
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||||
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
|
DSRing(algebra, bufferFactory, 0, arguments).function().value
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
|
with(
|
||||||
|
DSRing(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
symbols.size,
|
||||||
|
arguments
|
||||||
|
)
|
||||||
|
) { function().derivative(symbols) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field over commons-math [DerivativeStructure].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DSField<T, A : ExtendedField<T>>(
|
||||||
|
algebra: A,
|
||||||
|
bufferFactory: MutableBufferFactory<T>,
|
||||||
|
order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : DSRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DS<T, A>> {
|
||||||
|
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||||
|
|
||||||
|
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||||
|
compiler.divide(left.data, 0, right.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.sin(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.cos(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun tan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.tan(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun asin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.asin(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun acos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.acos(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun atan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.atan(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.sinh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.cosh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun tanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.tanh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun asinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.asinh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun acosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.acosh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun atanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.atanh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: DS<T, A>, pow: Number): DS<T, A> = when (pow) {
|
||||||
|
is Int -> arg.transformDataBuffer { result ->
|
||||||
|
compiler.pow(arg.data, 0, pow, result, 0)
|
||||||
|
}
|
||||||
|
else -> arg.transformDataBuffer { result ->
|
||||||
|
compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sqrt(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.sqrt(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun power(arg: DS<T, A>, pow: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.pow(arg.data, 0, pow.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun exp(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.exp(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun ln(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.ln(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DSFieldExpression<T, A : ExtendedField<T>>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val function: DSField<T, A>.() -> DS<T, A>,
|
||||||
|
) : DifferentiableExpression<T> {
|
||||||
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
|
DSField(algebra, bufferFactory, 0, arguments).function().value
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
|
DSField(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
symbols.size,
|
||||||
|
arguments,
|
||||||
|
).run { function().derivative(symbols) }
|
||||||
|
}
|
||||||
|
}
|
@ -5,11 +5,11 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.expressions
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
import space.kscience.kmath.structures.MutableBufferFactory
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
import kotlin.math.max
|
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
|
||||||
internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) {
|
internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) {
|
||||||
@ -52,20 +52,20 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
|
|||||||
*
|
*
|
||||||
* @property freeParameters Number of free parameters.
|
* @property freeParameters Number of free parameters.
|
||||||
* @property order Derivation order.
|
* @property order Derivation order.
|
||||||
* @see DerivativeStructure
|
* @see DS
|
||||||
*/
|
*/
|
||||||
internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
public class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||||
val algebra: A,
|
public val algebra: A,
|
||||||
val bufferFactory: MutableBufferFactory<T>,
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
val freeParameters: Int,
|
public val freeParameters: Int,
|
||||||
val order: Int,
|
public val order: Int,
|
||||||
valueCompiler: DSCompiler<T, A>?,
|
valueCompiler: DSCompiler<T, A>?,
|
||||||
derivativeCompiler: DSCompiler<T, A>?,
|
derivativeCompiler: DSCompiler<T, A>?,
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Number of partial derivatives (including the single 0 order derivative element).
|
* Number of partial derivatives (including the single 0 order derivative element).
|
||||||
*/
|
*/
|
||||||
val sizes: Array<IntArray> by lazy {
|
public val sizes: Array<IntArray> by lazy {
|
||||||
compileSizes(
|
compileSizes(
|
||||||
freeParameters,
|
freeParameters,
|
||||||
order,
|
order,
|
||||||
@ -76,7 +76,7 @@ internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection array for partial derivatives.
|
* Indirection array for partial derivatives.
|
||||||
*/
|
*/
|
||||||
val derivativesIndirection: Array<IntArray> by lazy {
|
internal val derivativesIndirection: Array<IntArray> by lazy {
|
||||||
compileDerivativesIndirection(
|
compileDerivativesIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler,
|
valueCompiler, derivativeCompiler,
|
||||||
@ -86,7 +86,7 @@ internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection array of the lower derivative elements.
|
* Indirection array of the lower derivative elements.
|
||||||
*/
|
*/
|
||||||
val lowerIndirection: IntArray by lazy {
|
internal val lowerIndirection: IntArray by lazy {
|
||||||
compileLowerIndirection(
|
compileLowerIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler,
|
valueCompiler, derivativeCompiler,
|
||||||
@ -96,7 +96,7 @@ internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection arrays for multiplication.
|
* Indirection arrays for multiplication.
|
||||||
*/
|
*/
|
||||||
val multIndirection: Array<Array<IntArray>> by lazy {
|
internal val multIndirection: Array<Array<IntArray>> by lazy {
|
||||||
compileMultiplicationIndirection(
|
compileMultiplicationIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler, lowerIndirection,
|
valueCompiler, derivativeCompiler, lowerIndirection,
|
||||||
@ -106,7 +106,7 @@ internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection arrays for function composition.
|
* Indirection arrays for function composition.
|
||||||
*/
|
*/
|
||||||
val compositionIndirection: Array<Array<IntArray>> by lazy {
|
internal val compositionIndirection: Array<Array<IntArray>> by lazy {
|
||||||
compileCompositionIndirection(
|
compileCompositionIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler,
|
valueCompiler, derivativeCompiler,
|
||||||
@ -120,8 +120,7 @@ internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
* This number includes the single 0 order derivative element, which is
|
* This number includes the single 0 order derivative element, which is
|
||||||
* guaranteed to be stored in the first element of the array.
|
* guaranteed to be stored in the first element of the array.
|
||||||
*/
|
*/
|
||||||
val size: Int
|
public val size: Int get() = sizes[freeParameters][order]
|
||||||
get() = sizes[freeParameters][order]
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the index of a partial derivative in the array.
|
* Get the index of a partial derivative in the array.
|
||||||
@ -148,7 +147,7 @@ internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
* @return index of the partial derivative.
|
* @return index of the partial derivative.
|
||||||
* @see getPartialDerivativeOrders
|
* @see getPartialDerivativeOrders
|
||||||
*/
|
*/
|
||||||
fun getPartialDerivativeIndex(vararg orders: Int): Int {
|
public fun getPartialDerivativeIndex(vararg orders: Int): Int {
|
||||||
// safety check
|
// safety check
|
||||||
require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" }
|
require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" }
|
||||||
return getPartialDerivativeIndex(freeParameters, order, sizes, *orders)
|
return getPartialDerivativeIndex(freeParameters, order, sizes, *orders)
|
||||||
@ -163,7 +162,7 @@ internal class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
* @return orders derivation orders with respect to each parameter
|
* @return orders derivation orders with respect to each parameter
|
||||||
* @see getPartialDerivativeIndex
|
* @see getPartialDerivativeIndex
|
||||||
*/
|
*/
|
||||||
fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
|
public fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -178,7 +177,7 @@ internal fun <T, A> DSCompiler<T, A>.ln(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : ExponentialOperations<T> = algebra {
|
) where A : Field<T>, A : ExponentialOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -211,7 +210,7 @@ internal fun <T, A> DSCompiler<T, A>.pow(
|
|||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
n: Int,
|
n: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : PowerOperations<T> = algebra {
|
) where A : Field<T>, A : PowerOperations<T> = algebra {
|
||||||
if (n == 0) {
|
if (n == 0) {
|
||||||
// special case, x^0 = 1 for all x
|
// special case, x^0 = 1 for all x
|
||||||
@ -267,7 +266,7 @@ internal fun <T, A> DSCompiler<T, A>.exp(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Ring<T>, A : ScaleOperations<T>, A : ExponentialOperations<T> = algebra {
|
) where A : Ring<T>, A : ScaleOperations<T>, A : ExponentialOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -290,7 +289,7 @@ internal fun <T, A> DSCompiler<T, A>.sqrt(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : PowerOperations<T> = algebra {
|
) where A : Field<T>, A : PowerOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
// [x^(1/n), (1/n)x^((1/n)-1), (1-n)/n^2x^((1/n)-2), ... ]
|
// [x^(1/n), (1/n)x^((1/n)-1), (1-n)/n^2x^((1/n)-2), ... ]
|
||||||
@ -351,7 +350,7 @@ internal fun <T, A> DSCompiler<T, A>.pow(
|
|||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
p: Double,
|
p: Double,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Ring<T>, A : NumericAlgebra<T>, A : PowerOperations<T>, A : ScaleOperations<T> = algebra {
|
) where A : Ring<T>, A : NumericAlgebra<T>, A : PowerOperations<T>, A : ScaleOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
// [x^p, px^(p-1), p(p-1)x^(p-2), ... ]
|
// [x^p, px^(p-1), p(p-1)x^(p-2), ... ]
|
||||||
@ -387,7 +386,7 @@ internal fun <T, A> DSCompiler<T, A>.tan(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Ring<T>, A : TrigonometricOperations<T>, A : ScaleOperations<T> = algebra {
|
) where A : Ring<T>, A : TrigonometricOperations<T>, A : ScaleOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -469,7 +468,7 @@ internal fun <T, A> DSCompiler<T, A>.sin(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Ring<T>, A : ScaleOperations<T>, A : TrigonometricOperations<T> = algebra {
|
) where A : Ring<T>, A : ScaleOperations<T>, A : TrigonometricOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -497,7 +496,7 @@ internal fun <T, A> DSCompiler<T, A>.acos(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : TrigonometricOperations<T>, A : PowerOperations<T> = algebra {
|
) where A : Field<T>, A : TrigonometricOperations<T>, A : PowerOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -559,7 +558,7 @@ internal fun <T, A> DSCompiler<T, A>.asin(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : TrigonometricOperations<T>, A : PowerOperations<T> = algebra {
|
) where A : Field<T>, A : TrigonometricOperations<T>, A : PowerOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -618,7 +617,7 @@ internal fun <T, A> DSCompiler<T, A>.atan(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : TrigonometricOperations<T> = algebra {
|
) where A : Field<T>, A : TrigonometricOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -678,7 +677,7 @@ internal fun <T, A> DSCompiler<T, A>.cosh(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Ring<T>, A : ScaleOperations<T>, A : ExponentialOperations<T> = algebra {
|
) where A : Ring<T>, A : ScaleOperations<T>, A : ExponentialOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -708,7 +707,7 @@ internal fun <T, A> DSCompiler<T, A>.tanh(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : ExponentialOperations<T> = algebra {
|
) where A : Field<T>, A : ExponentialOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -765,7 +764,7 @@ internal fun <T, A> DSCompiler<T, A>.acosh(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : ExponentialOperations<T>, A : PowerOperations<T> = algebra {
|
) where A : Field<T>, A : ExponentialOperations<T>, A : PowerOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -857,7 +856,7 @@ internal fun <T, A> DSCompiler<T, A>.sinh(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : ExponentialOperations<T> = algebra {
|
) where A : Field<T>, A : ExponentialOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -964,7 +963,7 @@ internal fun <T, A> DSCompiler<T, A>.asinh(
|
|||||||
operand: Buffer<T>,
|
operand: Buffer<T>,
|
||||||
operandOffset: Int,
|
operandOffset: Int,
|
||||||
result: MutableBuffer<T>,
|
result: MutableBuffer<T>,
|
||||||
resultOffset: Int
|
resultOffset: Int,
|
||||||
) where A : Field<T>, A : ExponentialOperations<T>, A : PowerOperations<T> = algebra {
|
) where A : Field<T>, A : ExponentialOperations<T>, A : PowerOperations<T> = algebra {
|
||||||
// create the function value and derivatives
|
// create the function value and derivatives
|
||||||
val function = bufferFactory(1 + order) { zero }
|
val function = bufferFactory(1 + order) { zero }
|
||||||
@ -1109,59 +1108,6 @@ internal fun <T, A> DSCompiler<T, A>.atanh(
|
|||||||
compose(operand, operandOffset, function, result, resultOffset)
|
compose(operand, operandOffset, function, result, resultOffset)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the compiler for number of free parameters and order.
|
|
||||||
*
|
|
||||||
* @param parameters number of free parameters.
|
|
||||||
* @param order derivation order.
|
|
||||||
* @return cached rules set.
|
|
||||||
*/
|
|
||||||
internal fun <T, A : Algebra<T>> getCompiler(
|
|
||||||
algebra: A,
|
|
||||||
bufferFactory: MutableBufferFactory<T>,
|
|
||||||
parameters: Int,
|
|
||||||
order: Int
|
|
||||||
): DSCompiler<T, A> {
|
|
||||||
// get the cached compilers
|
|
||||||
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
|
||||||
|
|
||||||
// we need to create more compilers
|
|
||||||
val maxParameters: Int = max(parameters, cache?.size ?: 0)
|
|
||||||
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
|
||||||
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
|
||||||
|
|
||||||
if (cache != null) {
|
|
||||||
// preserve the already created compilers
|
|
||||||
for (i in cache.indices) {
|
|
||||||
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create the array in increasing diagonal order
|
|
||||||
|
|
||||||
// create the array in increasing diagonal order
|
|
||||||
for (diag in 0..parameters + order) {
|
|
||||||
for (o in max(0, diag - parameters)..min(order, diag)) {
|
|
||||||
val p: Int = diag - o
|
|
||||||
if (newCache[p][o] == null) {
|
|
||||||
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
|
||||||
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
|
||||||
|
|
||||||
newCache[p][o] = DSCompiler(
|
|
||||||
algebra,
|
|
||||||
bufferFactory,
|
|
||||||
p,
|
|
||||||
o,
|
|
||||||
valueCompiler,
|
|
||||||
derivativeCompiler,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return newCache[parameters][order]!!
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile the sizes array.
|
* Compile the sizes array.
|
||||||
*
|
*
|
||||||
|
@ -1,186 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.expressions
|
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.NumericAlgebra
|
|
||||||
import space.kscience.kmath.operations.Ring
|
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Class representing both the value and the differentials of a function.
|
|
||||||
*
|
|
||||||
* This class is the workhorse of the differentiation package.
|
|
||||||
*
|
|
||||||
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper [Doubly Recursive
|
|
||||||
* Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
|
|
||||||
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
|
|
||||||
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
|
|
||||||
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
|
|
||||||
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
|
|
||||||
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
|
|
||||||
*
|
|
||||||
* Derived from
|
|
||||||
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public open class DerivativeStructure<T, A> internal constructor(
|
|
||||||
internal val derivativeAlgebra: DerivativeStructureRing<T, A>,
|
|
||||||
internal val compiler: DSCompiler<T, A>,
|
|
||||||
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
|
||||||
/**
|
|
||||||
* Combined array holding all values.
|
|
||||||
*/
|
|
||||||
internal var data: MutableBuffer<T> =
|
|
||||||
derivativeAlgebra.bufferFactory(compiler.size) { derivativeAlgebra.algebra.zero }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build an instance with all values and derivatives set to 0.
|
|
||||||
*
|
|
||||||
* @param parameters number of free parameters.
|
|
||||||
* @param order derivation order.
|
|
||||||
*/
|
|
||||||
public constructor (
|
|
||||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
|
||||||
parameters: Int,
|
|
||||||
order: Int,
|
|
||||||
) : this(
|
|
||||||
derivativeAlgebra,
|
|
||||||
getCompiler<T, A>(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, parameters, order),
|
|
||||||
)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build an instance representing a constant value.
|
|
||||||
*
|
|
||||||
* @param parameters number of free parameters.
|
|
||||||
* @param order derivation order.
|
|
||||||
* @param value value of the constant.
|
|
||||||
* @see DerivativeStructure
|
|
||||||
*/
|
|
||||||
public constructor (
|
|
||||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
|
||||||
parameters: Int,
|
|
||||||
order: Int,
|
|
||||||
value: T,
|
|
||||||
) : this(
|
|
||||||
derivativeAlgebra,
|
|
||||||
parameters,
|
|
||||||
order,
|
|
||||||
) {
|
|
||||||
data[0] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build an instance representing a variable.
|
|
||||||
*
|
|
||||||
* Instances built using this constructor are considered to be the free variables with respect to which
|
|
||||||
* differentials are computed. As such, their differential with respect to themselves is +1.
|
|
||||||
*
|
|
||||||
* @param parameters number of free parameters.
|
|
||||||
* @param order derivation order.
|
|
||||||
* @param index index of the variable (from 0 to `parameters - 1`).
|
|
||||||
* @param value value of the variable.
|
|
||||||
*/
|
|
||||||
public constructor (
|
|
||||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
|
||||||
parameters: Int,
|
|
||||||
order: Int,
|
|
||||||
index: Int,
|
|
||||||
value: T,
|
|
||||||
) : this(derivativeAlgebra, parameters, order, value) {
|
|
||||||
require(index < parameters) { "number is too large: $index >= $parameters" }
|
|
||||||
|
|
||||||
if (order > 0) {
|
|
||||||
// the derivative of the variable with respect to itself is 1.
|
|
||||||
data[getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, index, order).size] =
|
|
||||||
derivativeAlgebra.algebra.one
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build an instance from all its derivatives.
|
|
||||||
*
|
|
||||||
* @param parameters number of free parameters.
|
|
||||||
* @param order derivation order.
|
|
||||||
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
|
|
||||||
*/
|
|
||||||
public constructor (
|
|
||||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
|
||||||
parameters: Int,
|
|
||||||
order: Int,
|
|
||||||
vararg derivatives: T,
|
|
||||||
) : this(
|
|
||||||
derivativeAlgebra,
|
|
||||||
parameters,
|
|
||||||
order,
|
|
||||||
) {
|
|
||||||
require(derivatives.size == data.size) { "dimension mismatch: ${derivatives.size} and ${data.size}" }
|
|
||||||
data = derivativeAlgebra.bufferFactory(data.size) { derivatives[it] }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Copy constructor.
|
|
||||||
*
|
|
||||||
* @param ds instance to copy.
|
|
||||||
*/
|
|
||||||
internal constructor(ds: DerivativeStructure<T, A>) : this(ds.derivativeAlgebra, ds.compiler) {
|
|
||||||
this.data = ds.data.copy()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The number of free parameters.
|
|
||||||
*/
|
|
||||||
public val freeParameters: Int
|
|
||||||
get() = compiler.freeParameters
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The derivation order.
|
|
||||||
*/
|
|
||||||
public val order: Int
|
|
||||||
get() = compiler.order
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The value part of the derivative structure.
|
|
||||||
*
|
|
||||||
* @see getPartialDerivative
|
|
||||||
*/
|
|
||||||
public val value: T
|
|
||||||
get() = data[0]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get a partial derivative.
|
|
||||||
*
|
|
||||||
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
|
|
||||||
* @return partial derivative.
|
|
||||||
* @see value
|
|
||||||
*/
|
|
||||||
public fun getPartialDerivative(vararg orders: Int): T = data[compiler.getPartialDerivativeIndex(*orders)]
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Test for the equality of two derivative structures.
|
|
||||||
*
|
|
||||||
* Derivative structures are considered equal if they have the same number
|
|
||||||
* of free parameters, the same derivation order, and the same derivatives.
|
|
||||||
*
|
|
||||||
* @return `true` if two derivative structures are equal.
|
|
||||||
*/
|
|
||||||
public override fun equals(other: Any?): Boolean {
|
|
||||||
if (this === other) return true
|
|
||||||
|
|
||||||
if (other is DerivativeStructure<*, *>) {
|
|
||||||
return ((freeParameters == other.freeParameters) &&
|
|
||||||
(order == other.order) &&
|
|
||||||
data == other.data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
public override fun hashCode(): Int =
|
|
||||||
227 + 229 * freeParameters + 233 * order + 239 * data.hashCode()
|
|
||||||
}
|
|
@ -1,332 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.expressions
|
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.*
|
|
||||||
import space.kscience.kmath.structures.MutableBufferFactory
|
|
||||||
import space.kscience.kmath.structures.indices
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A class implementing both [DerivativeStructure] and [Symbol].
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureSymbol<T, A>(
|
|
||||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
|
||||||
size: Int,
|
|
||||||
order: Int,
|
|
||||||
index: Int,
|
|
||||||
symbol: Symbol,
|
|
||||||
value: T,
|
|
||||||
) : Symbol by symbol, DerivativeStructure<T, A>(
|
|
||||||
derivativeAlgebra,
|
|
||||||
size,
|
|
||||||
order,
|
|
||||||
index,
|
|
||||||
value
|
|
||||||
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
|
||||||
override fun toString(): String = symbol.toString()
|
|
||||||
override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol
|
|
||||||
override fun hashCode(): Int = symbol.hashCode()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A ring over [DerivativeStructure].
|
|
||||||
*
|
|
||||||
* @property order The derivation order.
|
|
||||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public open class DerivativeStructureRing<T, A>(
|
|
||||||
public val algebra: A,
|
|
||||||
public val bufferFactory: MutableBufferFactory<T>,
|
|
||||||
public val order: Int,
|
|
||||||
bindings: Map<Symbol, T>,
|
|
||||||
) : Ring<DerivativeStructure<T, A>>, ScaleOperations<DerivativeStructure<T, A>>,
|
|
||||||
NumericAlgebra<DerivativeStructure<T, A>>,
|
|
||||||
ExpressionAlgebra<T, DerivativeStructure<T, A>>,
|
|
||||||
NumbersAddOps<DerivativeStructure<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
|
||||||
public val numberOfVariables: Int = bindings.size
|
|
||||||
|
|
||||||
override val zero: DerivativeStructure<T, A> by lazy {
|
|
||||||
DerivativeStructure(
|
|
||||||
this,
|
|
||||||
numberOfVariables,
|
|
||||||
order,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override val one: DerivativeStructure<T, A> by lazy {
|
|
||||||
DerivativeStructure(
|
|
||||||
this,
|
|
||||||
numberOfVariables,
|
|
||||||
order,
|
|
||||||
algebra.one,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
|
||||||
|
|
||||||
private val variables: Map<Symbol, DerivativeStructureSymbol<T, A>> =
|
|
||||||
bindings.entries.mapIndexed { index, (key, value) ->
|
|
||||||
key to DerivativeStructureSymbol(
|
|
||||||
this,
|
|
||||||
numberOfVariables,
|
|
||||||
order,
|
|
||||||
index,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
}.toMap()
|
|
||||||
|
|
||||||
public override fun const(value: T): DerivativeStructure<T, A> =
|
|
||||||
DerivativeStructure(this, numberOfVariables, order, value)
|
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol<T, A>? = variables[StringSymbol(value)]
|
|
||||||
|
|
||||||
override fun bindSymbol(value: String): DerivativeStructureSymbol<T, A> =
|
|
||||||
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
|
|
||||||
|
|
||||||
public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol<T, A>? = variables[symbol.identity]
|
|
||||||
|
|
||||||
public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol<T, A> =
|
|
||||||
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
|
|
||||||
|
|
||||||
public fun DerivativeStructure<T, A>.derivative(symbols: List<Symbol>): T {
|
|
||||||
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
|
||||||
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
|
|
||||||
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DerivativeStructure<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
|
|
||||||
|
|
||||||
override fun DerivativeStructure<T, A>.unaryMinus(): DerivativeStructure<T, A> {
|
|
||||||
val ds = DerivativeStructure(this@DerivativeStructureRing, compiler)
|
|
||||||
for (i in ds.data.indices) {
|
|
||||||
ds.data[i] = algebra { -data[i] }
|
|
||||||
}
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun add(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
left.compiler.checkCompatibility(right.compiler)
|
|
||||||
val ds = DerivativeStructure(left)
|
|
||||||
left.compiler.add(left.data, 0, right.data, 0, ds.data, 0)
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun scale(a: DerivativeStructure<T, A>, value: Double): DerivativeStructure<T, A> {
|
|
||||||
val ds = DerivativeStructure(a)
|
|
||||||
for (i in ds.data.indices) {
|
|
||||||
ds.data[i] = algebra { ds.data[i].times(value) }
|
|
||||||
}
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(
|
|
||||||
left: DerivativeStructure<T, A>,
|
|
||||||
right: DerivativeStructure<T, A>
|
|
||||||
): DerivativeStructure<T, A> {
|
|
||||||
left.compiler.checkCompatibility(right.compiler)
|
|
||||||
val result = DerivativeStructure(this, left.compiler)
|
|
||||||
left.compiler.multiply(left.data, 0, right.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DerivativeStructure<T, A>.minus(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
compiler.checkCompatibility(arg.compiler)
|
|
||||||
val ds = DerivativeStructure(this)
|
|
||||||
compiler.subtract(data, 0, arg.data, 0, ds.data, 0)
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun DerivativeStructure<T, A>.plus(other: Number): DerivativeStructure<T, A> {
|
|
||||||
val ds = DerivativeStructure(this)
|
|
||||||
ds.data[0] = algebra { ds.data[0] + number(other) }
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun DerivativeStructure<T, A>.minus(other: Number): DerivativeStructure<T, A> =
|
|
||||||
this + -other.toDouble()
|
|
||||||
|
|
||||||
override operator fun Number.plus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other + this
|
|
||||||
override operator fun Number.minus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other - this
|
|
||||||
}
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureRingExpression<T, A>(
|
|
||||||
public val algebra: A,
|
|
||||||
public val bufferFactory: MutableBufferFactory<T>,
|
|
||||||
public val function: DerivativeStructureRing<T, A>.() -> DerivativeStructure<T, A>,
|
|
||||||
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
|
||||||
DerivativeStructureRing(algebra, bufferFactory, 0, arguments).function().value
|
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
|
||||||
with(
|
|
||||||
DerivativeStructureRing(
|
|
||||||
algebra,
|
|
||||||
bufferFactory,
|
|
||||||
symbols.size,
|
|
||||||
arguments
|
|
||||||
)
|
|
||||||
) { function().derivative(symbols) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A field over commons-math [DerivativeStructure].
|
|
||||||
*
|
|
||||||
* @property order The derivation order.
|
|
||||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureField<T, A : ExtendedField<T>>(
|
|
||||||
algebra: A,
|
|
||||||
bufferFactory: MutableBufferFactory<T>,
|
|
||||||
order: Int,
|
|
||||||
bindings: Map<Symbol, T>,
|
|
||||||
) : DerivativeStructureRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DerivativeStructure<T, A>> {
|
|
||||||
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
|
||||||
|
|
||||||
override fun divide(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
left.compiler.checkCompatibility(right.compiler)
|
|
||||||
val result = DerivativeStructure(this, left.compiler)
|
|
||||||
left.compiler.divide(left.data, 0, right.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.sin(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun cos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.cos(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun tan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.tan(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun asin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.asin(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun acos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.acos(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun atan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.atan(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.sinh(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun cosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.cosh(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun tanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.tanh(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun asinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.asinh(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun acosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.acosh(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun atanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.atanh(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun power(arg: DerivativeStructure<T, A>, pow: Number): DerivativeStructure<T, A> = when (pow) {
|
|
||||||
is Int -> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.pow(arg.data, 0, pow, result.data, 0)
|
|
||||||
result
|
|
||||||
}
|
|
||||||
else -> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.pow(arg.data, 0, pow.toDouble(), result.data, 0)
|
|
||||||
result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sqrt(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.sqrt(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun power(arg: DerivativeStructure<T, A>, pow: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
arg.compiler.checkCompatibility(pow.compiler)
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.pow(arg.data, 0, pow.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun exp(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.exp(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun ln(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
val result = DerivativeStructure(this, arg.compiler)
|
|
||||||
arg.compiler.ln(arg.data, 0, result.data, 0)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureFieldExpression<T, A : ExtendedField<T>>(
|
|
||||||
public val algebra: A,
|
|
||||||
public val bufferFactory: MutableBufferFactory<T>,
|
|
||||||
public val function: DerivativeStructureField<T, A>.() -> DerivativeStructure<T, A>,
|
|
||||||
) : DifferentiableExpression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
|
||||||
DerivativeStructureField(algebra, bufferFactory, 0, arguments).function().value
|
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
|
||||||
with(
|
|
||||||
DerivativeStructureField(
|
|
||||||
algebra,
|
|
||||||
bufferFactory,
|
|
||||||
symbols.size,
|
|
||||||
arguments,
|
|
||||||
)
|
|
||||||
) { function().derivative(symbols) }
|
|
||||||
}
|
|
||||||
}
|
|
@ -34,12 +34,12 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
|
|||||||
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
{ left, right ->
|
{ left, right ->
|
||||||
Expression { arguments ->
|
Expression { arguments ->
|
||||||
algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments))
|
algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
||||||
Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) }
|
Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -164,8 +164,6 @@ public open class FunctionalExpressionExtendedField<T, out A : ExtendedField<T>>
|
|||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
||||||
super<FunctionalExpressionField>.binaryOperationFunction(operation)
|
super<FunctionalExpressionField>.binaryOperationFunction(operation)
|
||||||
|
|
||||||
override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <T, A : Group<T>> A.expressionInGroup(
|
public inline fun <T, A : Group<T>> A.expressionInGroup(
|
||||||
|
@ -24,7 +24,7 @@ public sealed interface MST {
|
|||||||
public data class Numeric(val value: Number) : MST
|
public data class Numeric(val value: Number) : MST
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A node containing an unary operation.
|
* A node containing a unary operation.
|
||||||
*
|
*
|
||||||
* @property operation the identifier of operation.
|
* @property operation the identifier of operation.
|
||||||
* @property value the argument of this operation.
|
* @property value the argument of this operation.
|
||||||
@ -34,7 +34,7 @@ public sealed interface MST {
|
|||||||
/**
|
/**
|
||||||
* A node containing binary operation.
|
* A node containing binary operation.
|
||||||
*
|
*
|
||||||
* @property operation the identifier operation.
|
* @property operation the identifier of operation.
|
||||||
* @property left the left operand.
|
* @property left the left operand.
|
||||||
* @property right the right operand.
|
* @property right the right operand.
|
||||||
*/
|
*/
|
||||||
|
@ -343,10 +343,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: Au
|
|||||||
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
||||||
context: F,
|
context: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>,
|
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
|
||||||
SimpleAutoDiffField<T, F>(context, bindings) {
|
|
||||||
|
|
||||||
override fun bindSymbol(value: String): AutoDiffValue<T> = super<SimpleAutoDiffField>.bindSymbol(value)
|
|
||||||
|
|
||||||
override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
|
override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
|
||||||
|
|
||||||
|
@ -188,7 +188,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
|
|||||||
*/
|
*/
|
||||||
public fun <T : Any, A : Ring<T>> buffered(
|
public fun <T : Any, A : Ring<T>> buffered(
|
||||||
algebra: A,
|
algebra: A,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
|
||||||
): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
|
): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
|
||||||
|
|
||||||
@Deprecated("use DoubleField.linearSpace")
|
@Deprecated("use DoubleField.linearSpace")
|
||||||
|
@ -27,5 +27,5 @@ public annotation class UnstableKMathAPI
|
|||||||
RequiresOptIn.Level.WARNING,
|
RequiresOptIn.Level.WARNING,
|
||||||
)
|
)
|
||||||
public annotation class PerformancePitfall(
|
public annotation class PerformancePitfall(
|
||||||
val message: String = "Potential performance problem"
|
val message: String = "Potential performance problem",
|
||||||
)
|
)
|
||||||
|
@ -3,43 +3,71 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
package space.kscience.kmath.misc
|
package space.kscience.kmath.misc
|
||||||
|
|
||||||
import kotlin.comparisons.*
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.VirtualBuffer
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a new list filled with buffer indices. Indice order is defined by sorting associated buffer value.
|
* Return a new array filled with buffer indices. Indices order is defined by sorting associated buffer value.
|
||||||
* This feature allows to sort buffer values without reordering its content.
|
* This feature allows sorting buffer values without reordering its content.
|
||||||
*
|
*
|
||||||
* @return List of buffer indices, sorted by associated value.
|
* @return Buffer indices, sorted by associated value.
|
||||||
*/
|
*/
|
||||||
@PerformancePitfall
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V: Comparable<V>> Buffer<V>.permSort() : IntArray = _permSortWith(compareBy<Int> { get(it) })
|
public fun <V : Comparable<V>> Buffer<V>.indicesSorted(): IntArray = permSortIndicesWith(compareBy { get(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a zero-copy virtual buffer that contains the same elements but in ascending order
|
||||||
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> {
|
||||||
|
val permutations = indicesSorted()
|
||||||
|
return VirtualBuffer(size) { this[permutations[it]] }
|
||||||
|
}
|
||||||
|
|
||||||
@PerformancePitfall
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V: Comparable<V>> Buffer<V>.permSortDescending() : IntArray = _permSortWith(compareByDescending<Int> { get(it) })
|
public fun <V : Comparable<V>> Buffer<V>.indicesSortedDescending(): IntArray =
|
||||||
|
permSortIndicesWith(compareByDescending { get(it) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a zero-copy virtual buffer that contains the same elements but in descending order
|
||||||
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> {
|
||||||
|
val permutations = indicesSortedDescending()
|
||||||
|
return VirtualBuffer(size) { this[permutations[it]] }
|
||||||
|
}
|
||||||
|
|
||||||
@PerformancePitfall
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V, C: Comparable<C>> Buffer<V>.permSortBy(selector: (V) -> C) : IntArray = _permSortWith(compareBy<Int> { selector(get(it)) })
|
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray =
|
||||||
|
permSortIndicesWith(compareBy { selector(get(it)) })
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> {
|
||||||
|
val permutations = indicesSortedBy(selector)
|
||||||
|
return VirtualBuffer(size) { this[permutations[it]] }
|
||||||
|
}
|
||||||
|
|
||||||
@PerformancePitfall
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V, C: Comparable<C>> Buffer<V>.permSortByDescending(selector: (V) -> C) : IntArray = _permSortWith(compareByDescending<Int> { selector(get(it)) })
|
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray =
|
||||||
|
permSortIndicesWith(compareByDescending { selector(get(it)) })
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> {
|
||||||
|
val permutations = indicesSortedByDescending(selector)
|
||||||
|
return VirtualBuffer(size) { this[permutations[it]] }
|
||||||
|
}
|
||||||
|
|
||||||
@PerformancePitfall
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V> Buffer<V>.permSortWith(comparator : Comparator<V>) : IntArray = _permSortWith { i1, i2 -> comparator.compare(get(i1), get(i2)) }
|
public fun <V> Buffer<V>.indicesSortedWith(comparator: Comparator<V>): IntArray =
|
||||||
|
permSortIndicesWith { i1, i2 -> comparator.compare(get(i1), get(i2)) }
|
||||||
|
|
||||||
@PerformancePitfall
|
private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray {
|
||||||
@UnstableKMathAPI
|
if (size < 2) return IntArray(size) { 0 }
|
||||||
private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray {
|
|
||||||
if (size < 2) return IntArray(size)
|
|
||||||
|
|
||||||
/* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indice
|
/* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indices
|
||||||
* arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire
|
* arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire
|
||||||
* cached array, then fill remaining indices manually. Not done for now, because:
|
* cached array, then fill remaining indices manually. Not done for now, because:
|
||||||
* 1. doing it right would require some statistics about common used buffer sizes.
|
* 1. doing it right would require some statistics about common used buffer sizes.
|
||||||
@ -53,3 +81,12 @@ private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray
|
|||||||
*/
|
*/
|
||||||
return packedIndices.sortedWith(comparator).toIntArray()
|
return packedIndices.sortedWith(comparator).toIntArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks that the [Buffer] is sorted (ascending) and throws [IllegalArgumentException] if it is not.
|
||||||
|
*/
|
||||||
|
public fun <T : Comparable<T>> Buffer<T>.requireSorted() {
|
||||||
|
for (i in 0..(size - 2)) {
|
||||||
|
require(get(i + 1) >= get(i)) { "The buffer is not sorted at index $i" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -10,25 +10,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
|
|||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
|
||||||
* An exception is thrown when the expected and actual shape of NDArray differ.
|
|
||||||
*
|
|
||||||
* @property expected the expected shape.
|
|
||||||
* @property actual the actual shape.
|
|
||||||
*/
|
|
||||||
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
|
||||||
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
|
||||||
|
|
||||||
public typealias Shape = IntArray
|
|
||||||
|
|
||||||
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
|
|
||||||
|
|
||||||
public interface WithShape {
|
|
||||||
public val shape: Shape
|
|
||||||
|
|
||||||
public val indices: ShapeIndexer get() = DefaultStrides(shape)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The base interface for all ND-algebra implementations.
|
* The base interface for all ND-algebra implementations.
|
||||||
*
|
*
|
||||||
@ -194,7 +175,7 @@ public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, Gro
|
|||||||
override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
||||||
zip(left, right) { aValue, bValue -> multiply(aValue, bValue) }
|
zip(left, right) { aValue, bValue -> multiply(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions with context receivers
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Multiplies an ND structure by an element of it.
|
* Multiplies an ND structure by an element of it.
|
||||||
|
@ -47,7 +47,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
|||||||
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke
|
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = ::Strides
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,17 +32,22 @@ public open class BufferND<out T>(
|
|||||||
/**
|
/**
|
||||||
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferND]
|
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferND]
|
||||||
*/
|
*/
|
||||||
public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
public inline fun <T, R : Any> StructureND<T>.mapToBuffer(
|
||||||
factory: BufferFactory<R> = Buffer.Companion::auto,
|
factory: BufferFactory<R>,
|
||||||
crossinline transform: (T) -> R,
|
crossinline transform: (T) -> R,
|
||||||
): BufferND<R> {
|
): BufferND<R> = if (this is BufferND<T>)
|
||||||
return if (this is BufferND<T>)
|
|
||||||
BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) })
|
BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) })
|
||||||
else {
|
else {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
/**
|
||||||
|
* Transform structure to a new structure using inferred [BufferFactory]
|
||||||
|
*/
|
||||||
|
public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
||||||
|
crossinline transform: (T) -> R,
|
||||||
|
): BufferND<R> = mapToBuffer(Buffer.Companion::auto, transform)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [MutableStructureND] over [MutableBuffer].
|
* Represents [MutableStructureND] over [MutableBuffer].
|
||||||
@ -64,7 +69,7 @@ public class MutableBufferND<T>(
|
|||||||
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
|
* Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND]
|
||||||
*/
|
*/
|
||||||
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
||||||
factory: MutableBufferFactory<R> = MutableBuffer.Companion::auto,
|
factory: MutableBufferFactory<R> = MutableBufferFactory(MutableBuffer.Companion::auto),
|
||||||
crossinline transform: (T) -> R,
|
crossinline transform: (T) -> R,
|
||||||
): MutableBufferND<R> {
|
): MutableBufferND<R> {
|
||||||
return if (this is MutableBufferND<T>)
|
return if (this is MutableBufferND<T>)
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An exception is thrown when the expected and actual shape of NDArray differ.
|
||||||
|
*
|
||||||
|
* @property expected the expected shape.
|
||||||
|
* @property actual the actual shape.
|
||||||
|
*/
|
||||||
|
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
||||||
|
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
||||||
|
|
||||||
|
public class IndexOutOfShapeException(public val shape: Shape, public val index: IntArray) :
|
||||||
|
RuntimeException("Index ${index.contentToString()} is out of shape ${shape.contentToString()}")
|
||||||
|
|
||||||
|
public typealias Shape = IntArray
|
||||||
|
|
||||||
|
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
|
||||||
|
|
||||||
|
public interface WithShape {
|
||||||
|
public val shape: Shape
|
||||||
|
|
||||||
|
public val indices: ShapeIndexer get() = DefaultStrides(shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun requireIndexInShape(index: IntArray, shape: Shape) {
|
||||||
|
if (index.size != shape.size) throw IndexOutOfShapeException(index, shape)
|
||||||
|
shape.forEachIndexed { axis, axisShape ->
|
||||||
|
if (index[axis] !in 0 until axisShape) throw IndexOutOfShapeException(index, shape)
|
||||||
|
}
|
||||||
|
}
|
@ -66,7 +66,7 @@ public abstract class Strides: ShapeIndexer {
|
|||||||
/**
|
/**
|
||||||
* Simple implementation of [Strides].
|
* Simple implementation of [Strides].
|
||||||
*/
|
*/
|
||||||
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() {
|
public class DefaultStrides(override val shape: IntArray) : Strides() {
|
||||||
override val linearSize: Int get() = strides[shape.size]
|
override val linearSize: Int get() = strides[shape.size]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -112,6 +112,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
/**
|
/**
|
||||||
* Cached builder for default strides
|
* Cached builder for default strides
|
||||||
*/
|
*/
|
||||||
|
@Deprecated("Replace by Strides(shape)")
|
||||||
public operator fun invoke(shape: IntArray): Strides =
|
public operator fun invoke(shape: IntArray): Strides =
|
||||||
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||||
}
|
}
|
||||||
@ -119,3 +120,8 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
|
|
||||||
@ThreadLocal
|
@ThreadLocal
|
||||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cached builder for default strides
|
||||||
|
*/
|
||||||
|
public fun Strides(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
@ -138,6 +138,10 @@ private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>)
|
|||||||
override fun equals(other: Any?): Boolean = false
|
override fun equals(other: Any?): Boolean = false
|
||||||
|
|
||||||
override fun hashCode(): Int = 0
|
override fun hashCode(): Int = 0
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
return StructureND.toString(structure)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -101,7 +101,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
|
|||||||
val bufferRepr: String = when (structure.shape.size) {
|
val bufferRepr: String = when (structure.shape.size) {
|
||||||
1 -> (0 until structure.shape[0]).map { structure[it] }
|
1 -> (0 until structure.shape[0]).map { structure[it] }
|
||||||
.joinToString(prefix = "[", postfix = "]", separator = ", ")
|
.joinToString(prefix = "[", postfix = "]", separator = ", ")
|
||||||
2 -> (0 until structure.shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i ->
|
2 -> (0 until structure.shape[0]).joinToString(prefix = "[\n", postfix = "\n]", separator = ",\n") { i ->
|
||||||
(0 until structure.shape[1]).joinToString(prefix = " [", postfix = "]", separator = ", ") { j ->
|
(0 until structure.shape[1]).joinToString(prefix = " [", postfix = "]", separator = ", ") { j ->
|
||||||
structure[i, j].toString()
|
structure[i, j].toString()
|
||||||
}
|
}
|
||||||
@ -120,7 +120,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
|
|||||||
*/
|
*/
|
||||||
public fun <T> buffered(
|
public fun <T> buffered(
|
||||||
strides: Strides,
|
strides: Strides,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
|
||||||
initializer: (IntArray) -> T,
|
initializer: (IntArray) -> T,
|
||||||
): BufferND<T> = BufferND(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
): BufferND<T> = BufferND(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
|
|||||||
|
|
||||||
public fun <T> buffered(
|
public fun <T> buffered(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = BufferFactory(Buffer.Companion::boxing),
|
||||||
initializer: (IntArray) -> T,
|
initializer: (IntArray) -> T,
|
||||||
): BufferND<T> = buffered(DefaultStrides(shape), bufferFactory, initializer)
|
): BufferND<T> = buffered(DefaultStrides(shape), bufferFactory, initializer)
|
||||||
|
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
|
||||||
|
public open class VirtualStructureND<T>(
|
||||||
|
override val shape: Shape,
|
||||||
|
public val producer: (IntArray) -> T,
|
||||||
|
) : StructureND<T> {
|
||||||
|
override fun get(index: IntArray): T {
|
||||||
|
requireIndexInShape(index, shape)
|
||||||
|
return producer(index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class VirtualDoubleStructureND(
|
||||||
|
shape: Shape,
|
||||||
|
producer: (IntArray) -> Double,
|
||||||
|
) : VirtualStructureND<Double>(shape, producer)
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class VirtualIntStructureND(
|
||||||
|
shape: Shape,
|
||||||
|
producer: (IntArray) -> Int,
|
||||||
|
) : VirtualStructureND<Int>(shape, producer)
|
@ -0,0 +1,32 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
|
||||||
|
require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" }
|
||||||
|
return VirtualStructureND(shape) { index ->
|
||||||
|
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
||||||
|
if (indexAxis == axis) {
|
||||||
|
(index[indexAxis] + step).mod(shape[indexAxis])
|
||||||
|
} else {
|
||||||
|
index[indexAxis]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
get(newIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T> StructureND<T>.roll(pair: Pair<Int, Int>, vararg others: Pair<Int, Int>): StructureND<T> {
|
||||||
|
val axisMap: Map<Int, Int> = mapOf(pair, *others)
|
||||||
|
require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" }
|
||||||
|
return VirtualStructureND(shape) { index ->
|
||||||
|
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
||||||
|
val offset = axisMap[indexAxis] ?: 0
|
||||||
|
(index[indexAxis] + offset).mod(shape[indexAxis])
|
||||||
|
}
|
||||||
|
get(newIndex)
|
||||||
|
}
|
||||||
|
}
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
@ -53,7 +52,7 @@ public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
|
|||||||
*/
|
*/
|
||||||
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
|
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
|
||||||
buffer: Buffer<T>,
|
buffer: Buffer<T>,
|
||||||
crossinline block: A.(T) -> T
|
crossinline block: A.(T) -> T,
|
||||||
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
|
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -61,7 +60,7 @@ private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
|
|||||||
*/
|
*/
|
||||||
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
|
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
|
||||||
buffer: Buffer<T>,
|
buffer: Buffer<T>,
|
||||||
crossinline block: A.(index: Int, arg: T) -> T
|
crossinline block: A.(index: Int, arg: T) -> T,
|
||||||
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
|
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -70,7 +69,7 @@ private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
|
|||||||
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
|
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
|
||||||
l: Buffer<T>,
|
l: Buffer<T>,
|
||||||
r: Buffer<T>,
|
r: Buffer<T>,
|
||||||
crossinline block: A.(l: T, r: T) -> T
|
crossinline block: A.(l: T, r: T) -> T,
|
||||||
): Buffer<T> {
|
): Buffer<T> {
|
||||||
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
|
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
|
||||||
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
|
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
|
||||||
@ -152,10 +151,11 @@ public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
|
|||||||
public open class BufferFieldOps<T, A : Field<T>>(
|
public open class BufferFieldOps<T, A : Field<T>>(
|
||||||
elementAlgebra: A,
|
elementAlgebra: A,
|
||||||
bufferFactory: BufferFactory<T>,
|
bufferFactory: BufferFactory<T>,
|
||||||
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
|
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>,
|
||||||
|
ScaleOperations<Buffer<T>> {
|
||||||
|
|
||||||
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
|
// override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
|
||||||
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
|
// override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
|
||||||
override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
|
override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
|
||||||
|
|
||||||
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
|
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
|
||||||
@ -168,7 +168,7 @@ public open class BufferFieldOps<T, A : Field<T>>(
|
|||||||
public class BufferField<T, A : Field<T>>(
|
public class BufferField<T, A : Field<T>>(
|
||||||
elementAlgebra: A,
|
elementAlgebra: A,
|
||||||
bufferFactory: BufferFactory<T>,
|
bufferFactory: BufferFactory<T>,
|
||||||
override val size: Int
|
override val size: Int,
|
||||||
) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
|
) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
|
||||||
|
|
||||||
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
|
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
|
||||||
|
@ -6,12 +6,10 @@
|
|||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
|
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -21,7 +19,7 @@ public abstract class DoubleBufferOps : BufferAlgebra<Double, DoubleField>, Exte
|
|||||||
Norm<Buffer<Double>, Double> {
|
Norm<Buffer<Double>, Double> {
|
||||||
|
|
||||||
override val elementAlgebra: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer
|
override val bufferFactory: BufferFactory<Double> get() = BufferFactory(::DoubleBuffer)
|
||||||
|
|
||||||
override fun Buffer<Double>.map(block: DoubleField.(Double) -> Double): DoubleBuffer =
|
override fun Buffer<Double>.map(block: DoubleField.(Double) -> Double): DoubleBuffer =
|
||||||
mapInline { DoubleField.block(it) }
|
mapInline { DoubleField.block(it) }
|
||||||
|
@ -61,31 +61,39 @@ public inline fun <reified T> Buffer<T>.toTypedArray(): Array<T> = Array(size, :
|
|||||||
/**
|
/**
|
||||||
* Create a new buffer from this one with the given mapping function and using [Buffer.Companion.auto] buffer factory.
|
* Create a new buffer from this one with the given mapping function and using [Buffer.Companion.auto] buffer factory.
|
||||||
*/
|
*/
|
||||||
public inline fun <T : Any, reified R : Any> Buffer<T>.map(block: (T) -> R): Buffer<R> =
|
public inline fun <T, reified R : Any> Buffer<T>.map(block: (T) -> R): Buffer<R> =
|
||||||
Buffer.auto(size) { block(get(it)) }
|
Buffer.auto(size) { block(get(it)) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new buffer from this one with the given mapping function.
|
* Create a new buffer from this one with the given mapping function.
|
||||||
* Provided [bufferFactory] is used to construct the new buffer.
|
* Provided [bufferFactory] is used to construct the new buffer.
|
||||||
*/
|
*/
|
||||||
public inline fun <T : Any, R : Any> Buffer<T>.map(
|
public inline fun <T, R> Buffer<T>.map(
|
||||||
bufferFactory: BufferFactory<R>,
|
bufferFactory: BufferFactory<R>,
|
||||||
crossinline block: (T) -> R,
|
crossinline block: (T) -> R,
|
||||||
): Buffer<R> = bufferFactory(size) { block(get(it)) }
|
): Buffer<R> = bufferFactory(size) { block(get(it)) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new buffer from this one with the given indexed mapping function.
|
* Create a new buffer from this one with the given mapping (indexed) function.
|
||||||
* Provided [BufferFactory] is used to construct the new buffer.
|
* Provided [bufferFactory] is used to construct the new buffer.
|
||||||
*/
|
*/
|
||||||
public inline fun <T : Any, reified R : Any> Buffer<T>.mapIndexed(
|
public inline fun <T, R> Buffer<T>.mapIndexed(
|
||||||
bufferFactory: BufferFactory<R> = Buffer.Companion::auto,
|
bufferFactory: BufferFactory<R>,
|
||||||
crossinline block: (index: Int, value: T) -> R,
|
crossinline block: (index: Int, value: T) -> R,
|
||||||
): Buffer<R> = bufferFactory(size) { block(it, get(it)) }
|
): Buffer<R> = bufferFactory(size) { block(it, get(it)) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new buffer from this one with the given indexed mapping function.
|
||||||
|
* Provided [BufferFactory] is used to construct the new buffer.
|
||||||
|
*/
|
||||||
|
public inline fun <T, reified R : Any> Buffer<T>.mapIndexed(
|
||||||
|
crossinline block: (index: Int, value: T) -> R,
|
||||||
|
): Buffer<R> = BufferFactory<R>(Buffer.Companion::auto).invoke(size) { block(it, get(it)) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fold given buffer according to [operation]
|
* Fold given buffer according to [operation]
|
||||||
*/
|
*/
|
||||||
public inline fun <T : Any, R> Buffer<T>.fold(initial: R, operation: (acc: R, T) -> R): R {
|
public inline fun <T, R> Buffer<T>.fold(initial: R, operation: (acc: R, T) -> R): R {
|
||||||
var accumulator = initial
|
var accumulator = initial
|
||||||
for (index in this.indices) accumulator = operation(accumulator, get(index))
|
for (index in this.indices) accumulator = operation(accumulator, get(index))
|
||||||
return accumulator
|
return accumulator
|
||||||
@ -95,9 +103,9 @@ public inline fun <T : Any, R> Buffer<T>.fold(initial: R, operation: (acc: R, T)
|
|||||||
* Zip two buffers using given [transform].
|
* Zip two buffers using given [transform].
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public inline fun <T1 : Any, T2 : Any, reified R : Any> Buffer<T1>.zip(
|
public inline fun <T1, T2 : Any, reified R : Any> Buffer<T1>.zip(
|
||||||
other: Buffer<T2>,
|
other: Buffer<T2>,
|
||||||
bufferFactory: BufferFactory<R> = Buffer.Companion::auto,
|
bufferFactory: BufferFactory<R> = BufferFactory(Buffer.Companion::auto),
|
||||||
crossinline transform: (T1, T2) -> R,
|
crossinline transform: (T1, T2) -> R,
|
||||||
): Buffer<R> {
|
): Buffer<R> {
|
||||||
require(size == other.size) { "Buffer size mismatch in zip: expected $size but found ${other.size}" }
|
require(size == other.size) { "Buffer size mismatch in zip: expected $size but found ${other.size}" }
|
||||||
|
@ -14,14 +14,18 @@ import kotlin.reflect.KClass
|
|||||||
*
|
*
|
||||||
* @param T the type of buffer.
|
* @param T the type of buffer.
|
||||||
*/
|
*/
|
||||||
public typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
public fun interface BufferFactory<T> {
|
||||||
|
public operator fun invoke(size: Int, builder: (Int) -> T): Buffer<T>
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Function that produces [MutableBuffer] from its size and function that supplies values.
|
* Function that produces [MutableBuffer] from its size and function that supplies values.
|
||||||
*
|
*
|
||||||
* @param T the type of buffer.
|
* @param T the type of buffer.
|
||||||
*/
|
*/
|
||||||
public typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
public fun interface MutableBufferFactory<T>: BufferFactory<T>{
|
||||||
|
override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer<T>
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A generic read-only random-access structure for both primitives and objects.
|
* A generic read-only random-access structure for both primitives and objects.
|
||||||
@ -105,6 +109,16 @@ public interface Buffer<out T> {
|
|||||||
*/
|
*/
|
||||||
public val Buffer<*>.indices: IntRange get() = 0 until size
|
public val Buffer<*>.indices: IntRange get() = 0 until size
|
||||||
|
|
||||||
|
public fun <T> Buffer<T>.first(): T {
|
||||||
|
require(size > 0) { "Can't get the first element of empty buffer" }
|
||||||
|
return get(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T> Buffer<T>.last(): T {
|
||||||
|
require(size > 0) { "Can't get the last element of empty buffer" }
|
||||||
|
return get(size - 1)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Immutable wrapper for [MutableBuffer].
|
* Immutable wrapper for [MutableBuffer].
|
||||||
*
|
*
|
||||||
|
@ -19,10 +19,10 @@ import kotlin.test.assertFails
|
|||||||
internal inline fun diff(
|
internal inline fun diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<Symbol, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField<Double, DoubleField>.() -> Unit,
|
block: DSField<Double, DoubleField>.() -> Unit,
|
||||||
) {
|
) {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
DerivativeStructureField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
|
DSField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
@ -30,7 +30,7 @@ internal class AutoDiffTest {
|
|||||||
private val y by symbol
|
private val y by symbol
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun derivativeStructureFieldTest() {
|
fun dsAlgebraTest() {
|
||||||
diff(2, x to 1.0, y to 1.0) {
|
diff(2, x to 1.0, y to 1.0) {
|
||||||
val x = bindSymbol(x)//by binding()
|
val x = bindSymbol(x)//by binding()
|
||||||
val y = bindSymbol("y")
|
val y = bindSymbol("y")
|
||||||
@ -44,8 +44,8 @@ internal class AutoDiffTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun autoDifTest() {
|
fun dsExpressionTest() {
|
||||||
val f = DerivativeStructureFieldExpression(DoubleField, ::DoubleBuffer) {
|
val f = DSFieldExpression(DoubleField, ::DoubleBuffer) {
|
||||||
val x by binding
|
val x by binding
|
||||||
val y by binding
|
val y by binding
|
||||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||||
|
@ -6,14 +6,13 @@
|
|||||||
package space.kscience.kmath.misc
|
package space.kscience.kmath.misc
|
||||||
|
|
||||||
import space.kscience.kmath.misc.PermSortTest.Platform.*
|
import space.kscience.kmath.misc.PermSortTest.Platform.*
|
||||||
import kotlin.random.Random
|
|
||||||
import kotlin.test.Test
|
|
||||||
import kotlin.test.assertEquals
|
|
||||||
import kotlin.test.assertTrue
|
|
||||||
|
|
||||||
import space.kscience.kmath.structures.IntBuffer
|
import space.kscience.kmath.structures.IntBuffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertContentEquals
|
import kotlin.test.assertContentEquals
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class PermSortTest {
|
class PermSortTest {
|
||||||
|
|
||||||
@ -29,9 +28,9 @@ class PermSortTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testOnEmptyBuffer() {
|
fun testOnEmptyBuffer() {
|
||||||
val emptyBuffer = IntBuffer(0) {it}
|
val emptyBuffer = IntBuffer(0) {it}
|
||||||
var permutations = emptyBuffer.permSort()
|
var permutations = emptyBuffer.indicesSorted()
|
||||||
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
|
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
|
||||||
permutations = emptyBuffer.permSortDescending()
|
permutations = emptyBuffer.indicesSortedDescending()
|
||||||
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
|
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,25 +46,25 @@ class PermSortTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPermSortBy() {
|
fun testPermSortBy() {
|
||||||
val permutations = platforms.permSortBy { it.name }
|
val permutations = platforms.indicesSortedBy { it.name }
|
||||||
val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM)
|
val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM)
|
||||||
assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name")
|
assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name")
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPermSortByDescending() {
|
fun testPermSortByDescending() {
|
||||||
val permutations = platforms.permSortByDescending { it.name }
|
val permutations = platforms.indicesSortedByDescending { it.name }
|
||||||
val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID)
|
val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID)
|
||||||
assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name")
|
assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name")
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPermSortWith() {
|
fun testPermSortWith() {
|
||||||
var permutations = platforms.permSortWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) }
|
var permutations = platforms.indicesSortedWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) }
|
||||||
val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID)
|
val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID)
|
||||||
assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator")
|
assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator")
|
||||||
|
|
||||||
permutations = platforms.permSortWith(compareByDescending { it.name.length })
|
permutations = platforms.indicesSortedWith(compareByDescending { it.name.length })
|
||||||
assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator")
|
assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,7 +74,7 @@ class PermSortTest {
|
|||||||
println("Test randomization seed: $seed")
|
println("Test randomization seed: $seed")
|
||||||
|
|
||||||
val buffer = Random(seed).buffer(bufferSize)
|
val buffer = Random(seed).buffer(bufferSize)
|
||||||
val indices = buffer.permSort()
|
val indices = buffer.indicesSorted()
|
||||||
|
|
||||||
assertEquals(bufferSize, indices.size)
|
assertEquals(bufferSize, indices.size)
|
||||||
// Ensure no doublon is present in indices
|
// Ensure no doublon is present in indices
|
||||||
@ -87,7 +86,7 @@ class PermSortTest {
|
|||||||
assertTrue(current <= next, "Permutation indices not properly sorted")
|
assertTrue(current <= next, "Permutation indices not properly sorted")
|
||||||
}
|
}
|
||||||
|
|
||||||
val descIndices = buffer.permSortDescending()
|
val descIndices = buffer.indicesSortedDescending()
|
||||||
assertEquals(bufferSize, descIndices.size)
|
assertEquals(bufferSize, descIndices.size)
|
||||||
// Ensure no doublon is present in indices
|
// Ensure no doublon is present in indices
|
||||||
assertEquals(descIndices.toSet().size, descIndices.size)
|
assertEquals(descIndices.toSet().size, descIndices.size)
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class NdOperationsTest {
|
||||||
|
@Test
|
||||||
|
fun roll() {
|
||||||
|
val structure = DoubleField.ndAlgebra.structureND(5, 5) { index ->
|
||||||
|
index.sumOf { it.toDouble() }
|
||||||
|
}
|
||||||
|
|
||||||
|
println(StructureND.toString(structure))
|
||||||
|
|
||||||
|
val rolled = structure.roll(0,-1)
|
||||||
|
|
||||||
|
println(StructureND.toString(rolled))
|
||||||
|
|
||||||
|
assertEquals(4.0, rolled[0, 0])
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
32
kmath-coroutines/README.md
Normal file
32
kmath-coroutines/README.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Module kmath-coroutines
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
## Artifact:
|
||||||
|
|
||||||
|
The Maven coordinates of this project are `space.kscience:kmath-coroutines:0.3.0`.
|
||||||
|
|
||||||
|
**Gradle Groovy:**
|
||||||
|
```groovy
|
||||||
|
repositories {
|
||||||
|
maven { url 'https://repo.kotlin.link' }
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation 'space.kscience:kmath-coroutines:0.3.0'
|
||||||
|
}
|
||||||
|
```
|
||||||
|
**Gradle Kotlin DSL:**
|
||||||
|
```kotlin
|
||||||
|
repositories {
|
||||||
|
maven("https://repo.kotlin.link")
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("space.kscience:kmath-coroutines:0.3.0")
|
||||||
|
}
|
||||||
|
```
|
@ -18,12 +18,12 @@ public class LazyStructureND<out T>(
|
|||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
private val cache: MutableMap<IntArray, Deferred<T>> = HashMap()
|
private val cache: MutableMap<IntArray, Deferred<T>> = HashMap()
|
||||||
|
|
||||||
public fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) {
|
public fun async(index: IntArray): Deferred<T> = cache.getOrPut(index) {
|
||||||
scope.async(context = Dispatchers.Math) { function(index) }
|
scope.async(context = Dispatchers.Math) { function(index) }
|
||||||
}
|
}
|
||||||
|
|
||||||
public suspend fun await(index: IntArray): T = deferred(index).await()
|
public suspend fun await(index: IntArray): T = async(index).await()
|
||||||
override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() }
|
override operator fun get(index: IntArray): T = runBlocking { async(index).await() }
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
@ -33,8 +33,8 @@ public class LazyStructureND<out T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> StructureND<T>.deferred(index: IntArray): Deferred<T> =
|
public fun <T> StructureND<T>.async(index: IntArray): Deferred<T> =
|
||||||
if (this is LazyStructureND<T>) deferred(index) else CompletableDeferred(get(index))
|
if (this is LazyStructureND<T>) this@async.async(index) else CompletableDeferred(get(index))
|
||||||
|
|
||||||
public suspend fun <T> StructureND<T>.await(index: IntArray): T =
|
public suspend fun <T> StructureND<T>.await(index: IntArray): T =
|
||||||
if (this is LazyStructureND<T>) await(index) else get(index)
|
if (this is LazyStructureND<T>) await(index) else get(index)
|
||||||
|
32
kmath-dimensions/README.md
Normal file
32
kmath-dimensions/README.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Module kmath-dimensions
|
||||||
|
|
||||||
|
A proof of concept module for adding type-safe dimensions to structures
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
## Artifact:
|
||||||
|
|
||||||
|
The Maven coordinates of this project are `space.kscience:kmath-dimensions:0.3.0`.
|
||||||
|
|
||||||
|
**Gradle Groovy:**
|
||||||
|
```groovy
|
||||||
|
repositories {
|
||||||
|
maven { url 'https://repo.kotlin.link' }
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation 'space.kscience:kmath-dimensions:0.3.0'
|
||||||
|
}
|
||||||
|
```
|
||||||
|
**Gradle Kotlin DSL:**
|
||||||
|
```kotlin
|
||||||
|
repositories {
|
||||||
|
maven("https://repo.kotlin.link")
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("space.kscience:kmath-dimensions:0.3.0")
|
||||||
|
}
|
||||||
|
```
|
@ -9,17 +9,17 @@ EJML based linear algebra implementation.
|
|||||||
|
|
||||||
## Artifact:
|
## Artifact:
|
||||||
|
|
||||||
The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-17`.
|
The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0`.
|
||||||
|
|
||||||
**Gradle:**
|
**Gradle Groovy:**
|
||||||
```gradle
|
```groovy
|
||||||
repositories {
|
repositories {
|
||||||
maven { url 'https://repo.kotlin.link' }
|
maven { url 'https://repo.kotlin.link' }
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'space.kscience:kmath-ejml:0.3.0-dev-17'
|
implementation 'space.kscience:kmath-ejml:0.3.0'
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
**Gradle Kotlin DSL:**
|
**Gradle Kotlin DSL:**
|
||||||
@ -30,6 +30,6 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("space.kscience:kmath-ejml:0.3.0-dev-17")
|
implementation("space.kscience:kmath-ejml:0.3.0")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -271,7 +271,9 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, DoubleField, DMatrix
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -505,7 +507,9 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, FloatField, FMatrixRM
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -734,7 +738,9 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, DoubleField, DMatrix
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -963,7 +969,9 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace<Float, FloatField, FMatrixSp
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:OptIn(PerformancePitfall::class)
|
||||||
|
|
||||||
package space.kscience.kmath.ejml
|
package space.kscience.kmath.ejml
|
||||||
|
|
||||||
import org.ejml.data.DMatrixRMaj
|
import org.ejml.data.DMatrixRMaj
|
||||||
@ -18,11 +20,11 @@ import kotlin.random.Random
|
|||||||
import kotlin.random.asJavaRandom
|
import kotlin.random.asJavaRandom
|
||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
internal fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
|
||||||
fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
|
|
||||||
assertTrue { StructureND.contentEquals(expected, actual) }
|
assertTrue { StructureND.contentEquals(expected, actual) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
internal class EjmlMatrixTest {
|
internal class EjmlMatrixTest {
|
||||||
private val random = Random(0)
|
private val random = Random(0)
|
||||||
|
|
||||||
|
@ -9,17 +9,17 @@ Specialization of KMath APIs for Double numbers.
|
|||||||
|
|
||||||
## Artifact:
|
## Artifact:
|
||||||
|
|
||||||
The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-17`.
|
The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0`.
|
||||||
|
|
||||||
**Gradle:**
|
**Gradle Groovy:**
|
||||||
```gradle
|
```groovy
|
||||||
repositories {
|
repositories {
|
||||||
maven { url 'https://repo.kotlin.link' }
|
maven { url 'https://repo.kotlin.link' }
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'space.kscience:kmath-for-real:0.3.0-dev-17'
|
implementation 'space.kscience:kmath-for-real:0.3.0'
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
**Gradle Kotlin DSL:**
|
**Gradle Kotlin DSL:**
|
||||||
@ -30,6 +30,6 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("space.kscience:kmath-for-real:0.3.0-dev-17")
|
implementation("space.kscience:kmath-for-real:0.3.0")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -21,19 +21,22 @@ readme {
|
|||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "DoubleVector",
|
id = "DoubleVector",
|
||||||
description = "Numpy-like operations for Buffers/Points",
|
|
||||||
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt"
|
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt"
|
||||||
)
|
){
|
||||||
|
"Numpy-like operations for Buffers/Points"
|
||||||
|
}
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "DoubleMatrix",
|
id = "DoubleMatrix",
|
||||||
description = "Numpy-like operations for 2d real structures",
|
|
||||||
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt"
|
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt"
|
||||||
)
|
){
|
||||||
|
"Numpy-like operations for 2d real structures"
|
||||||
|
}
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "grids",
|
id = "grids",
|
||||||
description = "Uniform grid generators",
|
|
||||||
ref = "src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt"
|
ref = "src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt"
|
||||||
)
|
){
|
||||||
|
"Uniform grid generators"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,17 +11,17 @@ Functions and interpolations.
|
|||||||
|
|
||||||
## Artifact:
|
## Artifact:
|
||||||
|
|
||||||
The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-17`.
|
The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0`.
|
||||||
|
|
||||||
**Gradle:**
|
**Gradle Groovy:**
|
||||||
```gradle
|
```groovy
|
||||||
repositories {
|
repositories {
|
||||||
maven { url 'https://repo.kotlin.link' }
|
maven { url 'https://repo.kotlin.link' }
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'space.kscience:kmath-functions:0.3.0-dev-17'
|
implementation 'space.kscience:kmath-functions:0.3.0'
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
**Gradle Kotlin DSL:**
|
**Gradle Kotlin DSL:**
|
||||||
@ -32,6 +32,6 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation("space.kscience:kmath-functions:0.3.0-dev-17")
|
implementation("space.kscience:kmath-functions:0.3.0")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -7,6 +7,6 @@ package space.kscience.kmath.functions
|
|||||||
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
|
||||||
public typealias UnivariateFunction<T> = (T) -> T
|
public typealias Function1D<T> = (T) -> T
|
||||||
|
|
||||||
public typealias MultivariateFunction<T> = (Buffer<T>) -> T
|
public typealias FunctionND<T> = (Buffer<T>) -> T
|
@ -28,6 +28,8 @@ public fun <T : Comparable<T>> PiecewisePolynomial<T>.integrate(algebra: Field<T
|
|||||||
/**
|
/**
|
||||||
* Compute definite integral of given [PiecewisePolynomial] piece by piece in a given [range]
|
* Compute definite integral of given [PiecewisePolynomial] piece by piece in a given [range]
|
||||||
* Requires [UnivariateIntegrationNodes] or [IntegrationRange] and [IntegrandMaxCalls]
|
* Requires [UnivariateIntegrationNodes] or [IntegrationRange] and [IntegrandMaxCalls]
|
||||||
|
*
|
||||||
|
* TODO use context receiver for algebra
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <T : Comparable<T>> PiecewisePolynomial<T>.integrate(
|
public fun <T : Comparable<T>> PiecewisePolynomial<T>.integrate(
|
||||||
@ -98,6 +100,7 @@ public object DoubleSplineIntegrator : UnivariateIntegrator<Double> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suppress("unused")
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public inline val DoubleField.splineIntegrator: UnivariateIntegrator<Double>
|
public inline val DoubleField.splineIntegrator: UnivariateIntegrator<Double>
|
||||||
get() = DoubleSplineIntegrator
|
get() = DoubleSplineIntegrator
|
@ -9,6 +9,7 @@ package space.kscience.kmath.interpolation
|
|||||||
|
|
||||||
import space.kscience.kmath.data.XYColumnarData
|
import space.kscience.kmath.data.XYColumnarData
|
||||||
import space.kscience.kmath.functions.PiecewisePolynomial
|
import space.kscience.kmath.functions.PiecewisePolynomial
|
||||||
|
import space.kscience.kmath.functions.asFunction
|
||||||
import space.kscience.kmath.functions.value
|
import space.kscience.kmath.functions.value
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
@ -59,3 +60,33 @@ public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
|||||||
val pointSet = XYColumnarData.of(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
|
val pointSet = XYColumnarData.of(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
|
||||||
return interpolatePolynomials(pointSet)
|
return interpolatePolynomials(pointSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
|
||||||
|
x: Buffer<T>,
|
||||||
|
y: Buffer<T>,
|
||||||
|
): (T) -> T? = interpolatePolynomials(x, y).asFunction(algebra)
|
||||||
|
|
||||||
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
|
||||||
|
data: Map<T, T>,
|
||||||
|
): (T) -> T? = interpolatePolynomials(data).asFunction(algebra)
|
||||||
|
|
||||||
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
|
||||||
|
data: List<Pair<T, T>>,
|
||||||
|
): (T) -> T? = interpolatePolynomials(data).asFunction(algebra)
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
|
||||||
|
x: Buffer<T>,
|
||||||
|
y: Buffer<T>,
|
||||||
|
defaultValue: T,
|
||||||
|
): (T) -> T = interpolatePolynomials(x, y).asFunction(algebra, defaultValue)
|
||||||
|
|
||||||
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
|
||||||
|
data: Map<T, T>,
|
||||||
|
defaultValue: T,
|
||||||
|
): (T) -> T = interpolatePolynomials(data).asFunction(algebra, defaultValue)
|
||||||
|
|
||||||
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
|
||||||
|
data: List<Pair<T, T>>,
|
||||||
|
defaultValue: T,
|
||||||
|
): (T) -> T = interpolatePolynomials(data).asFunction(algebra, defaultValue)
|
@ -22,6 +22,7 @@ internal fun <T : Comparable<T>> insureSorted(points: XYColumnarData<*, T, *>) {
|
|||||||
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
|
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
|
||||||
*/
|
*/
|
||||||
public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
override fun interpolatePolynomials(points: XYColumnarData<T, T, T>): PiecewisePolynomial<T> = algebra {
|
override fun interpolatePolynomials(points: XYColumnarData<T, T, T>): PiecewisePolynomial<T> = algebra {
|
||||||
require(points.size > 0) { "Point array should not be empty" }
|
require(points.size > 0) { "Point array should not be empty" }
|
||||||
@ -37,3 +38,6 @@ public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public val <T : Comparable<T>> Field<T>.linearInterpolator: LinearInterpolator<T>
|
||||||
|
get() = LinearInterpolator(this)
|
||||||
|
@ -72,8 +72,12 @@ public class SplineInterpolator<T : Comparable<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public companion object {
|
|
||||||
public val double: SplineInterpolator<Double> = SplineInterpolator(DoubleField, ::DoubleBuffer)
|
public fun <T : Comparable<T>> Field<T>.splineInterpolator(
|
||||||
}
|
bufferFactory: MutableBufferFactory<T>,
|
||||||
}
|
): SplineInterpolator<T> = SplineInterpolator(this, bufferFactory)
|
||||||
|
|
||||||
|
public val DoubleField.splineInterpolator: SplineInterpolator<Double>
|
||||||
|
get() = SplineInterpolator(this, ::DoubleBuffer)
|
@ -5,8 +5,6 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.interpolation
|
package space.kscience.kmath.interpolation
|
||||||
|
|
||||||
import space.kscience.kmath.functions.PiecewisePolynomial
|
|
||||||
import space.kscience.kmath.functions.asFunction
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
@ -21,8 +19,8 @@ internal class LinearInterpolatorTest {
|
|||||||
3.0 to 4.0
|
3.0 to 4.0
|
||||||
)
|
)
|
||||||
|
|
||||||
val polynomial: PiecewisePolynomial<Double> = LinearInterpolator(DoubleField).interpolatePolynomials(data)
|
//val polynomial: PiecewisePolynomial<Double> = DoubleField.linearInterpolator.interpolatePolynomials(data)
|
||||||
val function = polynomial.asFunction(DoubleField)
|
val function = DoubleField.linearInterpolator.interpolate(data)
|
||||||
assertEquals(null, function(-1.0))
|
assertEquals(null, function(-1.0))
|
||||||
assertEquals(0.5, function(0.5))
|
assertEquals(0.5, function(0.5))
|
||||||
assertEquals(2.0, function(1.5))
|
assertEquals(2.0, function(1.5))
|
||||||
|
@ -5,8 +5,6 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.interpolation
|
package space.kscience.kmath.interpolation
|
||||||
|
|
||||||
import space.kscience.kmath.functions.PiecewisePolynomial
|
|
||||||
import space.kscience.kmath.functions.asFunction
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.math.sin
|
import kotlin.math.sin
|
||||||
@ -21,9 +19,10 @@ internal class SplineInterpolatorTest {
|
|||||||
x to sin(x)
|
x to sin(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
val polynomial: PiecewisePolynomial<Double> = SplineInterpolator.double.interpolatePolynomials(data)
|
//val polynomial: PiecewisePolynomial<Double> = DoubleField.splineInterpolator.interpolatePolynomials(data)
|
||||||
|
|
||||||
|
val function = DoubleField.splineInterpolator.interpolate(data, Double.NaN)
|
||||||
|
|
||||||
val function = polynomial.asFunction(DoubleField, Double.NaN)
|
|
||||||
assertEquals(Double.NaN, function(-1.0))
|
assertEquals(Double.NaN, function(-1.0))
|
||||||
assertEquals(sin(0.5), function(0.5), 0.1)
|
assertEquals(sin(0.5), function(0.5), 0.1)
|
||||||
assertEquals(sin(1.5), function(1.5), 0.1)
|
assertEquals(sin(1.5), function(1.5), 0.1)
|
||||||
|
32
kmath-geometry/README.md
Normal file
32
kmath-geometry/README.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Module kmath-geometry
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
## Artifact:
|
||||||
|
|
||||||
|
The Maven coordinates of this project are `space.kscience:kmath-geometry:0.3.0`.
|
||||||
|
|
||||||
|
**Gradle Groovy:**
|
||||||
|
```groovy
|
||||||
|
repositories {
|
||||||
|
maven { url 'https://repo.kotlin.link' }
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation 'space.kscience:kmath-geometry:0.3.0'
|
||||||
|
}
|
||||||
|
```
|
||||||
|
**Gradle Kotlin DSL:**
|
||||||
|
```kotlin
|
||||||
|
repositories {
|
||||||
|
maven("https://repo.kotlin.link")
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("space.kscience:kmath-geometry:0.3.0")
|
||||||
|
}
|
||||||
|
```
|
@ -6,7 +6,7 @@ plugins {
|
|||||||
|
|
||||||
kotlin.sourceSets.commonMain {
|
kotlin.sourceSets.commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(projects.kmath.kmathComplex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,12 +6,10 @@
|
|||||||
package space.kscience.kmath.geometry
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public interface Vector2D : Point<Double>, Vector {
|
public interface Vector2D : Point<Double>, Vector {
|
||||||
public val x: Double
|
public val x: Double
|
||||||
public val y: Double
|
public val y: Double
|
||||||
@ -29,7 +27,6 @@ public interface Vector2D : Point<Double>, Vector {
|
|||||||
public val Vector2D.r: Double
|
public val Vector2D.r: Double
|
||||||
get() = Euclidean2DSpace { norm() }
|
get() = Euclidean2DSpace { norm() }
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
|
||||||
public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)
|
public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)
|
||||||
|
|
||||||
private data class Vector2DImpl(
|
private data class Vector2DImpl(
|
||||||
|
@ -6,12 +6,11 @@
|
|||||||
package space.kscience.kmath.geometry
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public interface Vector3D : Point<Double>, Vector {
|
public interface Vector3D : Point<Double>, Vector {
|
||||||
public val x: Double
|
public val x: Double
|
||||||
public val y: Double
|
public val y: Double
|
||||||
@ -31,6 +30,19 @@ public interface Vector3D : Point<Double>, Vector {
|
|||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
|
public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
|
||||||
|
|
||||||
|
public fun Buffer<Double>.asVector3D(): Vector3D = object : Vector3D {
|
||||||
|
init {
|
||||||
|
require(this@asVector3D.size == 3) { "Buffer of size 3 is required for Vector3D" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override val x: Double get() = this@asVector3D[0]
|
||||||
|
override val y: Double get() = this@asVector3D[1]
|
||||||
|
override val z: Double get() = this@asVector3D[2]
|
||||||
|
|
||||||
|
override fun toString(): String = this@asVector3D.toString()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
public val Vector3D.r: Double get() = Euclidean3DSpace { norm() }
|
public val Vector3D.r: Double get() = Euclidean3DSpace { norm() }
|
||||||
|
|
||||||
private data class Vector3DImpl(
|
private data class Vector3DImpl(
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.geometry
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
|
//TODO move vector to receiver
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Project vector onto a line.
|
* Project vector onto a line.
|
||||||
* @param vector to project
|
* @param vector to project
|
@ -0,0 +1,116 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
|
import space.kscience.kmath.complex.Quaternion
|
||||||
|
import space.kscience.kmath.complex.QuaternionField
|
||||||
|
import space.kscience.kmath.complex.reciprocal
|
||||||
|
import space.kscience.kmath.linear.LinearSpace
|
||||||
|
import space.kscience.kmath.linear.Matrix
|
||||||
|
import space.kscience.kmath.linear.linearSpace
|
||||||
|
import space.kscience.kmath.linear.matrix
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
internal fun Vector3D.toQuaternion(): Quaternion = Quaternion(0.0, x, y, z)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Angle in radians denoted by this quaternion rotation
|
||||||
|
*/
|
||||||
|
public val Quaternion.theta: Double get() = kotlin.math.acos(w) * 2
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An axis of quaternion rotation
|
||||||
|
*/
|
||||||
|
public val Quaternion.vector: Vector3D
|
||||||
|
get() {
|
||||||
|
val sint2 = sqrt(1 - w * w)
|
||||||
|
|
||||||
|
return object : Vector3D {
|
||||||
|
override val x: Double get() = this@vector.x / sint2
|
||||||
|
override val y: Double get() = this@vector.y / sint2
|
||||||
|
override val z: Double get() = this@vector.z / sint2
|
||||||
|
override fun toString(): String = listOf(x, y, z).toString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rotate a vector in a [Euclidean3DSpace]
|
||||||
|
*/
|
||||||
|
public fun Euclidean3DSpace.rotate(vector: Vector3D, q: Quaternion): Vector3D = with(QuaternionField) {
|
||||||
|
val p = vector.toQuaternion()
|
||||||
|
(q * p * q.reciprocal).vector
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a composition of quaternions to create a rotation
|
||||||
|
*/
|
||||||
|
public fun Euclidean3DSpace.rotate(vector: Vector3D, composition: QuaternionField.() -> Quaternion): Vector3D =
|
||||||
|
rotate(vector, QuaternionField.composition())
|
||||||
|
|
||||||
|
public fun Euclidean3DSpace.rotate(vector: Vector3D, matrix: Matrix<Double>): Vector3D {
|
||||||
|
require(matrix.colNum == 3 && matrix.rowNum == 3) { "Square 3x3 rotation matrix is required" }
|
||||||
|
return with(DoubleField.linearSpace) { matrix.dot(vector).asVector3D() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a [Quaternion] to a rotation matrix
|
||||||
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun Quaternion.toRotationMatrix(
|
||||||
|
linearSpace: LinearSpace<Double, *> = DoubleField.linearSpace,
|
||||||
|
): Matrix<Double> {
|
||||||
|
val s = QuaternionField.norm(this).pow(-2)
|
||||||
|
return linearSpace.matrix(3, 3)(
|
||||||
|
1.0 - 2 * s * (y * y + z * z), 2 * s * (x * y - z * w), 2 * s * (x * z + y * w),
|
||||||
|
2 * s * (x * y + z * w), 1.0 - 2 * s * (x * x + z * z), 2 * s * (y * z - x * w),
|
||||||
|
2 * s * (x * z - y * w), 2 * s * (y * z + x * w), 1.0 - 2 * s * (x * x + y * y)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* taken from https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/
|
||||||
|
*/
|
||||||
|
public fun Quaternion.Companion.fromRotationMatrix(matrix: Matrix<Double>): Quaternion {
|
||||||
|
require(matrix.colNum == 3 && matrix.rowNum == 3) { "Rotation matrix should be 3x3 but is ${matrix.rowNum}x${matrix.colNum}" }
|
||||||
|
val trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
|
||||||
|
|
||||||
|
return if (trace > 0) {
|
||||||
|
val s = sqrt(trace + 1.0) * 2 // S=4*qw
|
||||||
|
Quaternion(
|
||||||
|
w = 0.25 * s,
|
||||||
|
x = (matrix[2, 1] - matrix[1, 2]) / s,
|
||||||
|
y = (matrix[0, 2] - matrix[2, 0]) / s,
|
||||||
|
z = (matrix[1, 0] - matrix[0, 1]) / s,
|
||||||
|
)
|
||||||
|
} else if ((matrix[0, 0] > matrix[1, 1]) && (matrix[0, 0] > matrix[2, 2])) {
|
||||||
|
val s = sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 // S=4*qx
|
||||||
|
Quaternion(
|
||||||
|
w = (matrix[2, 1] - matrix[1, 2]) / s,
|
||||||
|
x = 0.25 * s,
|
||||||
|
y = (matrix[0, 1] + matrix[1, 0]) / s,
|
||||||
|
z = (matrix[0, 2] + matrix[2, 0]) / s,
|
||||||
|
)
|
||||||
|
} else if (matrix[1, 1] > matrix[2, 2]) {
|
||||||
|
val s = sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 // S=4*qy
|
||||||
|
Quaternion(
|
||||||
|
w = (matrix[0, 2] - matrix[2, 0]) / s,
|
||||||
|
x = (matrix[0, 1] + matrix[1, 0]) / s,
|
||||||
|
y = 0.25 * s,
|
||||||
|
z = (matrix[1, 2] + matrix[2, 1]) / s,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
val s = sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 // S=4*qz
|
||||||
|
Quaternion(
|
||||||
|
w = (matrix[1, 0] - matrix[0, 1]) / s,
|
||||||
|
x = (matrix[0, 2] + matrix[2, 0]) / s,
|
||||||
|
y = (matrix[1, 2] + matrix[2, 1]) / s,
|
||||||
|
z = 0.25 * s,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,35 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
|
import space.kscience.kmath.complex.Quaternion
|
||||||
|
import space.kscience.kmath.complex.normalized
|
||||||
|
import space.kscience.kmath.testutils.assertBufferEquals
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
class RotationTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun rotations() = with(Euclidean3DSpace) {
|
||||||
|
val vector = Vector3D(1.0, 1.0, 1.0)
|
||||||
|
val q = Quaternion(1.0, 2.0, -3.0, 4.0).normalized()
|
||||||
|
val rotatedByQ = rotate(vector, q)
|
||||||
|
val matrix = q.toRotationMatrix()
|
||||||
|
val rotatedByM = rotate(vector,matrix)
|
||||||
|
|
||||||
|
assertBufferEquals(rotatedByQ, rotatedByM, 1e-4)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun rotationConversion() {
|
||||||
|
|
||||||
|
val q = Quaternion(1.0, 2.0, -3.0, 4.0).normalized()
|
||||||
|
|
||||||
|
val matrix = q.toRotationMatrix()
|
||||||
|
|
||||||
|
assertBufferEquals(q, Quaternion.fromRotationMatrix(matrix))
|
||||||
|
}
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user