Merge remote-tracking branch 'origin/dev' into altavir/diff

This commit is contained in:
Alexander Nozik 2022-07-13 10:13:47 +03:00
commit 56f3c05907
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
147 changed files with 3897 additions and 3718 deletions

View File

@ -13,25 +13,13 @@ jobs:
runs-on: ${{matrix.os}} runs-on: ${{matrix.os}}
timeout-minutes: 40 timeout-minutes: 40
steps: steps:
- name: Checkout the repo - uses: actions/checkout@v3.0.0
uses: actions/checkout@v2 - uses: actions/setup-java@v3.0.0
- name: Set up JDK 11
uses: DeLaGuardo/setup-graalvm@4.0
with: with:
graalvm: 21.2.0 java-version: 11
java: java11 distribution: liberica
arch: amd64
- name: Cache gradle
uses: actions/cache@v2
with:
path: |
~/.gradle/caches
~/.gradle/wrapper
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
restore-keys: |
${{ runner.os }}-gradle-
- name: Cache konan - name: Cache konan
uses: actions/cache@v2 uses: actions/cache@v3.0.1
with: with:
path: ~/.konan path: ~/.konan
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
@ -39,5 +27,6 @@ jobs:
${{ runner.os }}-gradle- ${{ runner.os }}-gradle-
- name: Gradle Wrapper Validation - name: Gradle Wrapper Validation
uses: gradle/wrapper-validation-action@v1.0.4 uses: gradle/wrapper-validation-action@v1.0.4
- name: Build - uses: gradle/gradle-build-action@v2.1.5
run: ./gradlew build --build-cache --no-daemon --stacktrace with:
arguments: build

View File

@ -1,28 +1,31 @@
name: Dokka publication name: Dokka publication
on: on:
push: workflow_dispatch:
branches: [ master ] release:
types: [ created ]
jobs: jobs:
build: build:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
timeout-minutes: 40 timeout-minutes: 40
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3.0.0
- uses: DeLaGuardo/setup-graalvm@4.0 - uses: actions/setup-java@v3.0.0
with: with:
graalvm: 21.2.0 java-version: 11
java: java11 distribution: liberica
arch: amd64 - name: Cache konan
- uses: actions/cache@v2 uses: actions/cache@v3.0.1
with: with:
path: ~/.gradle/caches path: ~/.konan
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
restore-keys: | restore-keys: |
${{ runner.os }}-gradle- ${{ runner.os }}-gradle-
- run: ./gradlew dokkaHtmlMultiModule --build-cache --no-daemon --no-parallel --stacktrace - uses: gradle/gradle-build-action@v2.1.5
- uses: JamesIves/github-pages-deploy-action@4.1.0 with:
arguments: dokkaHtmlMultiModule --no-parallel
- uses: JamesIves/github-pages-deploy-action@v4.3.0
with: with:
branch: gh-pages branch: gh-pages
folder: build/dokka/htmlMultiModule folder: build/dokka/htmlMultiModule

View File

@ -14,42 +14,38 @@ jobs:
os: [ macOS-latest, windows-latest ] os: [ macOS-latest, windows-latest ]
runs-on: ${{matrix.os}} runs-on: ${{matrix.os}}
steps: steps:
- name: Checkout the repo - uses: actions/checkout@v3.0.0
uses: actions/checkout@v2 - uses: actions/setup-java@v3.0.0
- name: Set up JDK 11
uses: DeLaGuardo/setup-graalvm@4.0
with: with:
graalvm: 21.2.0 java-version: 11
java: java11 distribution: liberica
arch: amd64
- name: Cache gradle
uses: actions/cache@v2
with:
path: |
~/.gradle/caches
~/.gradle/wrapper
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
restore-keys: |
${{ runner.os }}-gradle-
- name: Cache konan - name: Cache konan
uses: actions/cache@v2 uses: actions/cache@v3.0.1
with: with:
path: ~/.konan path: ~/.konan
key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }}
restore-keys: | restore-keys: |
${{ runner.os }}-gradle- ${{ runner.os }}-gradle-
- name: Gradle Wrapper Validation - uses: gradle/wrapper-validation-action@v1.0.4
uses: gradle/wrapper-validation-action@v1.0.4
- name: Publish Windows Artifacts - name: Publish Windows Artifacts
if: matrix.os == 'windows-latest' if: matrix.os == 'windows-latest'
shell: cmd uses: gradle/gradle-build-action@v2.1.5
run: > with:
./gradlew release --no-daemon --build-cache -Ppublishing.enabled=true arguments: |
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }} releaseAll
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }} -Ppublishing.enabled=true
-Ppublishing.sonatype=false
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }}
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }}
- name: Publish Mac Artifacts - name: Publish Mac Artifacts
if: matrix.os == 'macOS-latest' if: matrix.os == 'macOS-latest'
run: > uses: gradle/gradle-build-action@v2.1.5
./gradlew release --no-daemon --build-cache -Ppublishing.enabled=true -Ppublishing.platform=macosX64 with:
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }} arguments: |
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }} releaseMacosX64
releaseIosArm64
releaseIosX64
-Ppublishing.enabled=true
-Ppublishing.sonatype=false
-Ppublishing.space.user=${{ secrets.SPACE_APP_ID }}
-Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }}

1
.gitignore vendored
View File

@ -19,3 +19,4 @@ out/
!/.idea/copyright/ !/.idea/copyright/
!/.idea/scopes/ !/.idea/scopes/
/kotlin-js-store/yarn.lock

View File

@ -2,6 +2,21 @@
## [Unreleased] ## [Unreleased]
### Added ### Added
### Changed
- Kotlin 1.7
- `LazyStructure` `deffered` -> `async` to comply with coroutines code style
### Deprecated
### Removed
### Fixed
### Security
## [0.3.0]
### Added
- `ScaleOperations` interface - `ScaleOperations` interface
- `Field` extends `ScaleOperations` - `Field` extends `ScaleOperations`
- Basic integration API - Basic integration API
@ -19,6 +34,12 @@
- Complex power - Complex power
- Separate methods for UInt, Int and Number powers. NaN safety. - Separate methods for UInt, Int and Number powers. NaN safety.
- Tensorflow prototype - Tensorflow prototype
- `ValueAndErrorField`
- MST compilation to WASM: #286
- Jafama integration: #176
- `contentEquals` with tolerance: #364
- Compilation to TeX for MST: #254
### Changed ### Changed
- Exponential operations merged with hyperbolic functions - Exponential operations merged with hyperbolic functions
@ -48,10 +69,15 @@
- Operations -> Ops - Operations -> Ops
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes. - Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
- Tensor algebra takes read-only structures as input and inherits AlgebraND - Tensor algebra takes read-only structures as input and inherits AlgebraND
- `UnivariateDistribution` renamed to `Distribution1D`
- Rework of histograms.
- `UnivariateFunction` -> `Function1D`, `MultivariateFunction` -> `FunctionND`
### Deprecated ### Deprecated
- Specialized `DoubleBufferAlgebra` - Specialized `DoubleBufferAlgebra`
### Removed ### Removed
- Nearest in Domain. To be implemented in geometry package. - Nearest in Domain. To be implemented in geometry package.
- Number multiplication and division in main Algebra chain - Number multiplication and division in main Algebra chain
@ -62,10 +88,12 @@
- Second generic from DifferentiableExpression - Second generic from DifferentiableExpression
- Algebra elements are completely removed. Use algebra contexts instead. - Algebra elements are completely removed. Use algebra contexts instead.
### Fixed ### Fixed
- Ring inherits RingOperations, not GroupOperations - Ring inherits RingOperations, not GroupOperations
- Univariate histogram filling - Univariate histogram filling
### Security ### Security
## [0.2.0] ## [0.2.0]
@ -88,6 +116,7 @@
- New `MatrixFeature` interfaces for matrix decompositions - New `MatrixFeature` interfaces for matrix decompositions
- Basic Quaternion vector support in `kmath-complex`. - Basic Quaternion vector support in `kmath-complex`.
### Changed ### Changed
- Package changed from `scientifik` to `space.kscience` - Package changed from `scientifik` to `space.kscience`
- Gradle version: 6.6 -> 6.8.2 - Gradle version: 6.6 -> 6.8.2
@ -112,7 +141,6 @@
- `symbol` method in `Algebra` renamed to `bindSymbol` to avoid ambiguity - `symbol` method in `Algebra` renamed to `bindSymbol` to avoid ambiguity
- Add `out` projection to `Buffer` generic - Add `out` projection to `Buffer` generic
### Deprecated
### Removed ### Removed
- `kmath-koma` module because it doesn't support Kotlin 1.4. - `kmath-koma` module because it doesn't support Kotlin 1.4.
@ -122,13 +150,11 @@
- `Real` class - `Real` class
- StructureND identity and equals - StructureND identity and equals
### Fixed ### Fixed
- `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140) - `symbol` method in `MstExtendedField` (https://github.com/mipt-npm/kmath/pull/140)
### Security
## [0.1.4] ## [0.1.4]
### Added ### Added
- Functional Expressions API - Functional Expressions API
- Mathematical Syntax Tree, its interpreter and API - Mathematical Syntax Tree, its interpreter and API
@ -146,6 +172,7 @@
- Full hyperbolic functions support and default implementations within `ExtendedField` - Full hyperbolic functions support and default implementations within `ExtendedField`
- Norm support for `Complex` - Norm support for `Complex`
### Changed ### Changed
- `readAsMemory` now has `throws IOException` in JVM signature. - `readAsMemory` now has `throws IOException` in JVM signature.
- Several functions taking functional types were made `inline`. - Several functions taking functional types were made `inline`.
@ -157,9 +184,10 @@
- Gradle version: 6.3 -> 6.6 - Gradle version: 6.3 -> 6.6
- Moved probability distributions to commons-rng and to `kmath-prob` - Moved probability distributions to commons-rng and to `kmath-prob`
### Fixed ### Fixed
- Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106) - Missing copy method in Memory implementation on JS (https://github.com/mipt-npm/kmath/pull/106)
- D3.dim value in `kmath-dimensions` - D3.dim value in `kmath-dimensions`
- Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101) - Multiplication in integer rings in `kmath-core` (https://github.com/mipt-npm/kmath/pull/101)
- Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93) - Commons RNG compatibility (https://github.com/mipt-npm/kmath/issues/93)
- Multiplication of BigInt by scalar - Multiplication of BigInt by scalar

View File

@ -52,21 +52,18 @@ module definitions below. The module stability could have the following levels:
## Modules ## Modules
<hr/>
* ### [benchmarks](benchmarks) ### [benchmarks](benchmarks)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/>
* ### [examples](examples) ### [examples](examples)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-ast](kmath-ast) ### [kmath-ast](kmath-ast)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
@ -77,15 +74,13 @@ module definitions below. The module stability could have the following levels:
> - [mst-js-codegen](kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler > - [mst-js-codegen](kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt) : Dynamic MST to JS compiler
> - [rendering](kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt) : Extendable MST rendering > - [rendering](kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt) : Extendable MST rendering
<hr/>
* ### [kmath-commons](kmath-commons) ### [kmath-commons](kmath-commons)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-complex](kmath-complex) ### [kmath-complex](kmath-complex)
> Complex numbers and quaternions. > Complex numbers and quaternions.
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
@ -94,9 +89,8 @@ module definitions below. The module stability could have the following levels:
> - [complex](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt) : Complex Numbers > - [complex](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt) : Complex Numbers
> - [quaternion](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt) : Quaternions > - [quaternion](kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt) : Quaternions
<hr/>
* ### [kmath-core](kmath-core) ### [kmath-core](kmath-core)
> Core classes, algebra definitions, basic linear algebra > Core classes, algebra definitions, basic linear algebra
> >
> **Maturity**: DEVELOPMENT > **Maturity**: DEVELOPMENT
@ -112,21 +106,18 @@ performance calculations to code generation.
> - [domains](kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains) : Domains > - [domains](kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains) : Domains
> - [autodiff](kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation > - [autodiff](kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
<hr/>
* ### [kmath-coroutines](kmath-coroutines) ### [kmath-coroutines](kmath-coroutines)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-dimensions](kmath-dimensions) ### [kmath-dimensions](kmath-dimensions)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-ejml](kmath-ejml) ### [kmath-ejml](kmath-ejml)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
@ -136,9 +127,8 @@ performance calculations to code generation.
> - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation. > - [ejml-matrix](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt) : Matrix implementation.
> - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations. > - [ejml-linear-space](kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt) : LinearSpace implementations.
<hr/>
* ### [kmath-for-real](kmath-for-real) ### [kmath-for-real](kmath-for-real)
> Extension module that should be used to achieve numpy-like behavior. > Extension module that should be used to achieve numpy-like behavior.
All operations are specialized to work with `Double` numbers without declaring algebraic contexts. All operations are specialized to work with `Double` numbers without declaring algebraic contexts.
One can still use generic algebras though. One can still use generic algebras though.
@ -150,9 +140,8 @@ One can still use generic algebras though.
> - [DoubleMatrix](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt) : Numpy-like operations for 2d real structures > - [DoubleMatrix](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt) : Numpy-like operations for 2d real structures
> - [grids](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt) : Uniform grid generators > - [grids](kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt) : Uniform grid generators
<hr/>
* ### [kmath-functions](kmath-functions) ### [kmath-functions](kmath-functions)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
@ -164,21 +153,18 @@ One can still use generic algebras though.
> - [spline interpolation](kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt) : Cubic spline XY interpolator. > - [spline interpolation](kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt) : Cubic spline XY interpolator.
> - [integration](kmath-functions/#) : Univariate and multivariate quadratures > - [integration](kmath-functions/#) : Univariate and multivariate quadratures
<hr/>
* ### [kmath-geometry](kmath-geometry) ### [kmath-geometry](kmath-geometry)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-histograms](kmath-histograms) ### [kmath-histograms](kmath-histograms)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-jafama](kmath-jafama) ### [kmath-jafama](kmath-jafama)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
@ -186,15 +172,13 @@ One can still use generic algebras though.
> **Features:** > **Features:**
> - [jafama-double](kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/) : Double ExtendedField implementations based on Jafama > - [jafama-double](kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/) : Double ExtendedField implementations based on Jafama
<hr/>
* ### [kmath-jupyter](kmath-jupyter) ### [kmath-jupyter](kmath-jupyter)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-kotlingrad](kmath-kotlingrad) ### [kmath-kotlingrad](kmath-kotlingrad)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
@ -203,21 +187,18 @@ One can still use generic algebras though.
> - [differentiable-mst-expression](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt) : MST based DifferentiableExpression. > - [differentiable-mst-expression](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt) : MST based DifferentiableExpression.
> - [scalars-adapters](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt) : Conversions between Kotlin∇'s SFun and MST > - [scalars-adapters](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt) : Conversions between Kotlin∇'s SFun and MST
<hr/>
* ### [kmath-memory](kmath-memory) ### [kmath-memory](kmath-memory)
> An API and basic implementation for arranging objects in a continuous memory block. > An API and basic implementation for arranging objects in a continuous memory block.
> >
> **Maturity**: DEVELOPMENT > **Maturity**: DEVELOPMENT
<hr/>
* ### [kmath-multik](kmath-multik) ### [kmath-multik](kmath-multik)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-nd4j](kmath-nd4j) ### [kmath-nd4j](kmath-nd4j)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
@ -227,27 +208,28 @@ One can still use generic algebras though.
> - [nd4jarrayrings](kmath-nd4j/#) : Rings over Nd4jArrayStructure of Int and Long > - [nd4jarrayrings](kmath-nd4j/#) : Rings over Nd4jArrayStructure of Int and Long
> - [nd4jarrayfields](kmath-nd4j/#) : Fields over Nd4jArrayStructure of Float and Double > - [nd4jarrayfields](kmath-nd4j/#) : Fields over Nd4jArrayStructure of Float and Double
<hr/>
* ### [kmath-optimization](kmath-optimization) ### [kmath-optimization](kmath-optimization)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-stat](kmath-stat) ### [kmath-stat](kmath-stat)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-symja](kmath-symja) ### [kmath-symja](kmath-symja)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
<hr/>
* ### [kmath-tensors](kmath-tensors) ### [kmath-tensorflow](kmath-tensorflow)
>
>
> **Maturity**: PROTOTYPE
### [kmath-tensors](kmath-tensors)
> >
> >
> **Maturity**: PROTOTYPE > **Maturity**: PROTOTYPE
@ -257,13 +239,11 @@ One can still use generic algebras though.
> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting. > - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
> - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc. > - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
<hr/>
* ### [kmath-viktor](kmath-viktor) ### [kmath-viktor](kmath-viktor)
> >
> >
> **Maturity**: DEVELOPMENT > **Maturity**: DEVELOPMENT
<hr/>
## Multi-platform support ## Multi-platform support
@ -302,8 +282,8 @@ repositories {
} }
dependencies { dependencies {
api("space.kscience:kmath-core:0.3.0-dev-17") api("space.kscience:kmath-core:$version")
// api("space.kscience:kmath-core-jvm:0.3.0-dev-17") for jvm-specific version // api("space.kscience:kmath-core-jvm:$version") for jvm-specific version
} }
``` ```

4
benchmarks/README.md Normal file
View File

@ -0,0 +1,4 @@
# Module benchmarks

View File

@ -52,6 +52,8 @@ kotlin {
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(project(":kmath-jafama")) implementation(project(":kmath-jafama"))
implementation(project(":kmath-multik")) implementation(project(":kmath-multik"))
implementation(projects.kmath.kmathTensorflow)
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
implementation("org.nd4j:nd4j-native:1.0.0-M1") implementation("org.nd4j:nd4j-native:1.0.0-M1")
// uncomment if your system supports AVX2 // uncomment if your system supports AVX2
// val os = System.getProperty("os.name") // val os = System.getProperty("os.name")
@ -122,6 +124,11 @@ benchmark {
include("JafamaBenchmark") include("JafamaBenchmark")
} }
configurations.register("tensorAlgebra") {
commonConfiguration()
include("TensorAlgebraBenchmark")
}
configurations.register("viktor") { configurations.register("viktor") {
commonConfiguration() commonConfiguration()
include("ViktorBenchmark") include("ViktorBenchmark")
@ -148,7 +155,7 @@ kotlin.sourceSets.all {
} }
} }
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> { tasks.withType<org.jetbrains.kotlin.gradle.dsl.KotlinJvmCompile> {
kotlinOptions { kotlinOptions {
jvmTarget = "11" jvmTarget = "11"
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy" freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"

View File

@ -15,7 +15,11 @@ import space.kscience.kmath.linear.invoke
import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.multik.multikAlgebra import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.tensorflow.produceWithTF
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.tensorAlgebra
import kotlin.random.Random import kotlin.random.Random
@State(Scope.Benchmark) @State(Scope.Benchmark)
@ -39,6 +43,16 @@ internal class DotBenchmark {
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() } val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
} }
@Benchmark
fun tfDot(blackhole: Blackhole) {
blackhole.consume(
DoubleField.produceWithTF {
matrix1 dot matrix1
}
)
}
@Benchmark @Benchmark
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace { fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
blackhole.consume(matrix1 dot matrix2) blackhole.consume(matrix1 dot matrix2)
@ -59,13 +73,13 @@ internal class DotBenchmark {
blackhole.consume(matrix1 dot matrix2) blackhole.consume(matrix1 dot matrix2)
} }
// @Benchmark @Benchmark
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) { fun tensorDot(blackhole: Blackhole) = with(DoubleField.tensorAlgebra) {
// blackhole.consume(matrix1 dot matrix2) blackhole.consume(matrix1 dot matrix2)
// } }
@Benchmark @Benchmark
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) { fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
blackhole.consume(matrix1 dot matrix2) blackhole.consume(matrix1 dot matrix2)
} }
@ -78,4 +92,9 @@ internal class DotBenchmark {
fun doubleDot(blackhole: Blackhole) = with(DoubleField.linearSpace) { fun doubleDot(blackhole: Blackhole) = with(DoubleField.linearSpace) {
blackhole.consume(matrix1 dot matrix2) blackhole.consume(matrix1 dot matrix2)
} }
@Benchmark
fun doubleTensorDot(blackhole: Blackhole) = DoubleTensorAlgebra.invoke {
blackhole.consume(matrix1 dot matrix2)
}
} }

View File

@ -0,0 +1,37 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.benchmarks
import kotlinx.benchmark.Benchmark
import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.linear.matrix
import space.kscience.kmath.linear.symmetric
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.tensors.core.tensorAlgebra
import kotlin.random.Random
@State(Scope.Benchmark)
internal class TensorAlgebraBenchmark {
companion object {
private val random = Random(12224)
private const val dim = 30
private val matrix = DoubleField.linearSpace.matrix(dim, dim).symmetric { _, _ -> random.nextDouble() }
}
@Benchmark
fun tensorSymEigSvd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
blackhole.consume(matrix.symEigSvd(1e-10))
}
@Benchmark
fun tensorSymEigJacobi(blackhole: Blackhole) = with(Double.tensorAlgebra) {
blackhole.consume(matrix.symEigJacobi(50, 1e-10))
}
}

View File

@ -1,22 +1,23 @@
plugins { plugins {
id("ru.mipt.npm.gradle.project") id("ru.mipt.npm.gradle.project")
id("org.jetbrains.kotlinx.kover") version "0.5.0-RC" id("org.jetbrains.kotlinx.kover") version "0.5.0"
} }
allprojects { allprojects {
repositories { repositories {
maven("https://repo.kotlin.link")
maven("https://oss.sonatype.org/content/repositories/snapshots") maven("https://oss.sonatype.org/content/repositories/snapshots")
mavenCentral() mavenCentral()
} }
group = "space.kscience" group = "space.kscience"
version = "0.3.0-dev-17" version = "0.3.1-dev-1"
} }
subprojects { subprojects {
if (name.startsWith("kmath")) apply<MavenPublishPlugin>() if (name.startsWith("kmath")) apply<MavenPublishPlugin>()
plugins.withId("org.jetbrains.dokka"){ plugins.withId("org.jetbrains.dokka") {
tasks.withType<org.jetbrains.dokka.gradle.DokkaTaskPartial> { tasks.withType<org.jetbrains.dokka.gradle.DokkaTaskPartial> {
dependsOn(tasks["assemble"]) dependsOn(tasks["assemble"])
@ -50,12 +51,24 @@ subprojects {
} }
} }
} }
plugins.withId("org.jetbrains.kotlin.multiplatform") {
configure<org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension> {
sourceSets {
val commonTest by getting {
dependencies {
implementation(projects.testUtils)
}
}
}
}
}
} }
readme.readmeTemplate = file("docs/templates/README-TEMPLATE.md") readme.readmeTemplate = file("docs/templates/README-TEMPLATE.md")
ksciencePublish { ksciencePublish {
github("kmath") github("kmath", addToRelease = false)
space() space()
sonatype() sonatype()
} }

View File

@ -1,4 +1,5 @@
plugins { plugins {
kotlin("jvm") version "1.7.0"
`kotlin-dsl` `kotlin-dsl`
`version-catalog` `version-catalog`
alias(npmlibs.plugins.kotlin.plugin.serialization) alias(npmlibs.plugins.kotlin.plugin.serialization)
@ -7,17 +8,19 @@ plugins {
java.targetCompatibility = JavaVersion.VERSION_11 java.targetCompatibility = JavaVersion.VERSION_11
repositories { repositories {
mavenLocal()
maven("https://repo.kotlin.link") maven("https://repo.kotlin.link")
mavenCentral() mavenCentral()
gradlePluginPortal() gradlePluginPortal()
} }
val toolsVersion: String by extra val toolsVersion = npmlibs.versions.tools.get()
val kotlinVersion = npmlibs.versions.kotlin.asProvider().get() val kotlinVersion = npmlibs.versions.kotlin.asProvider().get()
val benchmarksVersion = "0.4.2" val benchmarksVersion = npmlibs.versions.kotlinx.benchmark.get()
dependencies { dependencies {
api("ru.mipt.npm:gradle-tools:$toolsVersion") api("ru.mipt.npm:gradle-tools:$toolsVersion")
api(npmlibs.atomicfu.gradle)
//plugins form benchmarks //plugins form benchmarks
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion") api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion") api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")

View File

@ -1,14 +0,0 @@
#
# Copyright 2018-2021 KMath contributors.
# Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
#
kotlin.code.style=official
kotlin.mpp.stability.nowarn=true
kotlin.jupyter.add.scanner=false
org.gradle.configureondemand=true
org.gradle.parallel=true
toolsVersion=0.10.9-kotlin-1.6.10

View File

@ -3,17 +3,26 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
enableFeaturePreview("VERSION_CATALOGS")
dependencyResolutionManagement { dependencyResolutionManagement {
val projectProperties = java.util.Properties()
file("../gradle.properties").inputStream().use {
projectProperties.load(it)
}
val toolsVersion: String by extra projectProperties.forEach { key, value ->
extra.set(key.toString(), value)
}
val toolsVersion: String = projectProperties["toolsVersion"].toString()
repositories { repositories {
mavenLocal()
maven("https://repo.kotlin.link") maven("https://repo.kotlin.link")
mavenCentral() mavenCentral()
gradlePluginPortal()
} }
versionCatalogs { versionCatalogs {

View File

@ -319,7 +319,9 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra},
} }
else -> null else -> null
}?.let(type::cast) }?.let{
type.cast(it)
}
} }
/** /**

View File

@ -52,7 +52,7 @@ module definitions below. The module stability could have the following levels:
## Modules ## Modules
$modules ${modules}
## Multi-platform support ## Multi-platform support

4
examples/README.md Normal file
View File

@ -0,0 +1,4 @@
# Module examples

View File

@ -58,7 +58,7 @@ kotlin.sourceSets.all {
} }
} }
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> { tasks.withType<org.jetbrains.kotlin.gradle.dsl.KotlinJvmCompile> {
kotlinOptions { kotlinOptions {
jvmTarget = "11" jvmTarget = "11"
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy" freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"

View File

@ -13,7 +13,7 @@ import kotlin.math.pow
fun main() { fun main() {
//Define a function //Define a function
val function: UnivariateFunction<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 } val function: Function1D<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 }
//get the result of the integration //get the result of the integration
val result = DoubleField.gaussIntegrator.integrate(0.0..10.0, function = function) val result = DoubleField.gaussIntegrator.integrate(0.0..10.0, function = function)

View File

@ -5,8 +5,8 @@
package space.kscience.kmath.functions package space.kscience.kmath.functions
import space.kscience.kmath.interpolation.SplineInterpolator
import space.kscience.kmath.interpolation.interpolatePolynomials import space.kscience.kmath.interpolation.interpolatePolynomials
import space.kscience.kmath.interpolation.splineInterpolator
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.real.map import space.kscience.kmath.real.map
import space.kscience.kmath.real.step import space.kscience.kmath.real.step
@ -18,7 +18,7 @@ import space.kscience.plotly.scatter
@OptIn(UnstablePlotlyAPI::class) @OptIn(UnstablePlotlyAPI::class)
fun main() { fun main() {
val function: UnivariateFunction<Double> = { x -> val function: Function1D<Double> = { x ->
if (x in 30.0..50.0) { if (x in 30.0..50.0) {
1.0 1.0
} else { } else {
@ -28,7 +28,7 @@ fun main() {
val xs = 0.0..100.0 step 0.5 val xs = 0.0..100.0 step 0.5
val ys = xs.map(function) val ys = xs.map(function)
val polynomial: PiecewisePolynomial<Double> = SplineInterpolator.double.interpolatePolynomials(xs, ys) val polynomial: PiecewisePolynomial<Double> = DoubleField.splineInterpolator.interpolatePolynomials(xs, ys)
val polyFunction = polynomial.asFunction(DoubleField, 0.0) val polyFunction = polynomial.asFunction(DoubleField, 0.0)

View File

@ -2,14 +2,14 @@
# Copyright 2018-2021 KMath contributors. # Copyright 2018-2021 KMath contributors.
# Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. # Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
# #
kotlin.code.style=official kotlin.code.style=official
kotlin.mpp.stability.nowarn=true
kotlin.jupyter.add.scanner=false kotlin.jupyter.add.scanner=false
kotlin.mpp.stability.nowarn=true
kotlin.native.ignoreDisabledTargets=true
//kotlin.incremental.js.ir=true
org.gradle.configureondemand=true org.gradle.configureondemand=true
org.gradle.parallel=true org.gradle.parallel=true
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1G org.gradle.jvmargs=-Xmx4096m
toolsVersion=0.11.1-kotlin-1.6.10 toolsVersion=0.11.7-kotlin-1.7.0

Binary file not shown.

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.3.3-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.2-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -1,6 +1,6 @@
# Module kmath-ast # Module kmath-ast
Performance and visualization extensions to MST API. Extensions to MST API: transformations, dynamic compilation and visualization.
- [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser - [expression-language](src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt) : Expression language and its parser
- [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler - [mst-jvm-codegen](src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt) : Dynamic MST to JVM bytecode compiler
@ -10,17 +10,17 @@ Performance and visualization extensions to MST API.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-17`. The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0`.
**Gradle:** **Gradle Groovy:**
```gradle ```groovy
repositories { repositories {
maven { url 'https://repo.kotlin.link' } maven { url 'https://repo.kotlin.link' }
mavenCentral() mavenCentral()
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-ast:0.3.0-dev-17' implementation 'space.kscience:kmath-ast:0.3.0'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -31,10 +31,30 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-ast:0.3.0-dev-17") implementation("space.kscience:kmath-ast:0.3.0")
} }
``` ```
## Parsing expressions
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
Supported literals:
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`&mdash;all parsed either as `kotlin.Long` or `kotlin.Double`.
Supported binary operators (from the highest precedence to the lowest one):
1. `^`
2. `*`, `/`
3. `+`, `-`
Supported unary operator:
1. `-`, e.&nbsp;g. `-x`
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
1. `sin(x)`
2. `add(x, y)`
## Dynamic expression code generation ## Dynamic expression code generation
### On JVM ### On JVM
@ -42,48 +62,66 @@ dependencies {
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression<T>` with implemented `invoke` function. special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder: For example, the following code:
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.* import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.*
MstField { x + 2 }.compileToExpression(DoubleField)
```
... leads to generation of bytecode, which can be decompiled to the following Java class:
```java
package space.kscience.kmath.asm.generated;
import java.util.Map;
import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics;
import space.kscience.kmath.expressions.Expression;
import space.kscience.kmath.expressions.Symbol;
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
private final Object[] constants;
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
}
public AsmCompiledExpression_45045_0(Object[] constants) {
this.constants = constants;
}
}
"x^3-x+3".parseMath().compileToExpression(DoubleField)
``` ```
#### Known issues &mldr; leads to generation of bytecode, which can be decompiled to the following Java class:
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class ```java
loading overhead. import java.util.*;
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. import kotlin.jvm.functions.*;
import space.kscience.kmath.asm.internal.*;
import space.kscience.kmath.complex.*;
import space.kscience.kmath.expressions.*;
public final class CompiledExpression_45045_0 implements Expression<Complex> {
private final Object[] constants;
public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
}
}
```
For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
```java
import java.util.*;
import space.kscience.kmath.asm.internal.*;
import space.kscience.kmath.expressions.*;
public final class CompiledExpression_-386104628_0 implements DoubleExpression {
private final SymbolIndexer indexer;
public SymbolIndexer getIndexer() {
return this.indexer;
}
public double invoke(double[] arguments) {
double var2 = arguments[0];
return Math.pow(var2, 3.0D) - var2 + 3.0D;
}
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
return Math.pow(var2, 3.0D) - var2 + 3.0D;
}
}
```
Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
#### Limitations
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS ### On JS
@ -121,15 +159,15 @@ MstField { x + 2 }.compileToExpression(DoubleField)
An example of emitted Wasm IR in the form of WAT: An example of emitted Wasm IR in the form of WAT:
```lisp ```lisp
(func $executable (param $0 f64) (result f64) (func \$executable (param \$0 f64) (result f64)
(f64.add (f64.add
(local.get $0) (local.get \$0)
(f64.const 2) (f64.const 2)
) )
) )
``` ```
#### Known issues #### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments. - ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/). - WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
@ -161,10 +199,7 @@ public fun main() {
Result LaTeX: Result LaTeX:
<div style="background-color:white;"> $$\operatorname{exp}\\,\left(\sqrt{x}\right)-\frac{\frac{\operatorname{arcsin}\\,\left(2\\,x\right)}{2\times10^{10}+x^{3}}}{12}+x^{2/3}$$
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
</div>
Result MathML (can be used with MathJax or other renderers): Result MathML (can be used with MathJax or other renderers):

View File

@ -24,7 +24,7 @@ kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api("com.github.h0tk3y.betterParse:better-parse:0.4.2") api("com.github.h0tk3y.betterParse:better-parse:0.4.4")
api(project(":kmath-core")) api(project(":kmath-core"))
} }
} }
@ -57,7 +57,7 @@ tasks.dokkaHtml {
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1") if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
tasks.jvmTest { tasks.jvmTest {
jvmArgs = (jvmArgs ?: emptyList()) + listOf("-Dspace.kscience.kmath.ast.dump.generated.classes=1") jvmArgs("-Dspace.kscience.kmath.ast.dump.generated.classes=1")
} }
readme { readme {

View File

@ -1,11 +1,31 @@
# Module kmath-ast # Module kmath-ast
Performance and visualization extensions to MST API. Extensions to MST API: transformations, dynamic compilation and visualization.
${features} ${features}
${artifact} ${artifact}
## Parsing expressions
In this module there is a parser from human-readable strings like `"x^3-x+3"` (in the more specific [grammar](reference/ArithmeticsEvaluator.g4)) to MST instances.
Supported literals:
1. Constants and variables (consist of latin letters, digits and underscores, can't start with digit): `x`, `_Abc2`.
2. Numbers: `123`, `1.02`, `1e10`, `1e-10`, `1.0e+3`&mdash;all parsed either as `kotlin.Long` or `kotlin.Double`.
Supported binary operators (from the highest precedence to the lowest one):
1. `^`
2. `*`, `/`
3. `+`, `-`
Supported unary operator:
1. `-`, e.&nbsp;g. `-x`
Arbitrary unary and binary functions are also supported: names consist of latin letters, digits and underscores, can't start with digit. Examples:
1. `sin(x)`
2. `add(x, y)`
## Dynamic expression code generation ## Dynamic expression code generation
### On JVM ### On JVM
@ -13,48 +33,66 @@ ${artifact}
`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a `kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds a
special implementation of `Expression<T>` with implemented `invoke` function. special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder: For example, the following code:
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.* import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.*
MstField { x + 2 }.compileToExpression(DoubleField)
```
... leads to generation of bytecode, which can be decompiled to the following Java class:
```java
package space.kscience.kmath.asm.generated;
import java.util.Map;
import kotlin.jvm.functions.Function2;
import space.kscience.kmath.asm.internal.MapIntrinsics;
import space.kscience.kmath.expressions.Expression;
import space.kscience.kmath.expressions.Symbol;
public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
private final Object[] constants;
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
return (Double) ((Function2) this.constants[0]).invoke((Double) MapIntrinsics.getOrFail(arguments, "x"), 2);
}
public AsmCompiledExpression_45045_0(Object[] constants) {
this.constants = constants;
}
}
"x^3-x+3".parseMath().compileToExpression(DoubleField)
``` ```
#### Known issues &mldr; leads to generation of bytecode, which can be decompiled to the following Java class:
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class ```java
loading overhead. import java.util.*;
- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. import kotlin.jvm.functions.*;
import space.kscience.kmath.asm.internal.*;
import space.kscience.kmath.complex.*;
import space.kscience.kmath.expressions.*;
public final class CompiledExpression_45045_0 implements Expression<Complex> {
private final Object[] constants;
public Complex invoke(Map<Symbol, ? extends Complex> arguments) {
Complex var2 = (Complex)MapIntrinsics.getOrFail(arguments, "x");
return (Complex)((Function2)this.constants[0]).invoke(var2, (Complex)this.constants[1]);
}
}
```
For `LongRing`, `IntRing`, and `DoubleField` specialization is supported for better performance:
```java
import java.util.*;
import space.kscience.kmath.asm.internal.*;
import space.kscience.kmath.expressions.*;
public final class CompiledExpression_-386104628_0 implements DoubleExpression {
private final SymbolIndexer indexer;
public SymbolIndexer getIndexer() {
return this.indexer;
}
public double invoke(double[] arguments) {
double var2 = arguments[0];
return Math.pow(var2, 3.0D) - var2 + 3.0D;
}
public final Double invoke(Map<Symbol, ? extends Double> arguments) {
double var2 = ((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue();
return Math.pow(var2, 3.0D) - var2 + 3.0D;
}
}
```
Setting JVM system property `space.kscience.kmath.ast.dump.generated.classes` to `1` makes the translator dump class files to program's working directory, so they can be reviewed manually.
#### Limitations
- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid class loading overhead.
- This API is not supported by non-dynamic JVM implementations like TeaVM or GraalVM Native Image because they may not support class loaders.
### On JS ### On JS
@ -100,7 +138,7 @@ An example of emitted Wasm IR in the form of WAT:
) )
``` ```
#### Known issues #### Limitations
- ESTree expression compilation uses `eval` which can be unavailable in several environments. - ESTree expression compilation uses `eval` which can be unavailable in several environments.
- WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/). - WebAssembly isn't supported by old versions of browsers (see https://webassembly.org/roadmap/).
@ -132,10 +170,7 @@ public fun main() {
Result LaTeX: Result LaTeX:
<div style="background-color:white;"> $$\operatorname{exp}\\,\left(\sqrt{x}\right)-\frac{\frac{\operatorname{arcsin}\\,\left(2\\,x\right)}{2\times10^{10}+x^{3}}}{12}+x^{2/3}$$
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
</div>
Result MathML (can be used with MathJax or other renderers): Result MathML (can be used with MathJax or other renderers):

View File

@ -0,0 +1,177 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
/**
* MST form where all values belong to the type [T]. It is optimal for constant folding, dynamic compilation, etc.
*
* @param T the type.
*/
@UnstableKMathAPI
public sealed interface TypedMst<T> {
/**
* A node containing a unary operation.
*
* @param T the type.
* @property operation The identifier of operation.
* @property function The function implementing this operation.
* @property value The argument of this operation.
*/
public class Unary<T>(public val operation: String, public val function: (T) -> T, public val value: TypedMst<T>) :
TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Unary<*>
if (operation != other.operation) return false
if (value != other.value) return false
return true
}
override fun hashCode(): Int {
var result = operation.hashCode()
result = 31 * result + value.hashCode()
return result
}
override fun toString(): String = "Unary(operation=$operation, value=$value)"
}
/**
* A node containing binary operation.
*
* @param T the type.
* @property operation The identifier of operation.
* @property function The binary function implementing this operation.
* @property left The left operand.
* @property right The right operand.
*/
public class Binary<T>(
public val operation: String,
public val function: Function<T>,
public val left: TypedMst<T>,
public val right: TypedMst<T>,
) : TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Binary<*>
if (operation != other.operation) return false
if (left != other.left) return false
if (right != other.right) return false
return true
}
override fun hashCode(): Int {
var result = operation.hashCode()
result = 31 * result + left.hashCode()
result = 31 * result + right.hashCode()
return result
}
override fun toString(): String = "Binary(operation=$operation, left=$left, right=$right)"
}
/**
* The non-numeric constant value.
*
* @param T the type.
* @property value The held value.
* @property number The number this value corresponds.
*/
public class Constant<T>(public val value: T, public val number: Number?) : TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Constant<*>
if (value != other.value) return false
if (number != other.number) return false
return true
}
override fun hashCode(): Int {
var result = value?.hashCode() ?: 0
result = 31 * result + (number?.hashCode() ?: 0)
return result
}
override fun toString(): String = "Constant(value=$value, number=$number)"
}
/**
* The node containing a variable
*
* @param T the type.
* @property symbol The symbol of the variable.
*/
public class Variable<T>(public val symbol: Symbol) : TypedMst<T> {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Variable<*>
if (symbol != other.symbol) return false
return true
}
override fun hashCode(): Int = symbol.hashCode()
override fun toString(): String = "Variable(symbol=$symbol)"
}
}
/**
* Interprets the [TypedMst] node with this [Algebra] and [arguments].
*/
@UnstableKMathAPI
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = when (this) {
is TypedMst.Unary -> algebra.unaryOperation(operation, interpret(algebra, arguments))
is TypedMst.Binary -> when {
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null ->
algebra.leftSideNumberOperation(operation, left.number, right.interpret(algebra, arguments))
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null ->
algebra.rightSideNumberOperation(operation, left.interpret(algebra, arguments), right.number)
else -> algebra.binaryOperation(
operation,
left.interpret(algebra, arguments),
right.interpret(algebra, arguments),
)
}
is TypedMst.Constant -> value
is TypedMst.Variable -> arguments.getValue(symbol)
}
/**
* Interprets the [TypedMst] node with this [Algebra] and optional [arguments].
*/
@UnstableKMathAPI
public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = interpret(
algebra,
when (arguments.size) {
0 -> emptyMap()
1 -> mapOf(arguments[0])
else -> hashMapOf(*arguments)
},
)
/**
* Interpret this [TypedMst] node as expression.
*/
@UnstableKMathAPI
public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
interpret(algebra, arguments)
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.bindSymbolOrNull
/**
* Evaluates constants in given [MST] for given [algebra] at the same time with converting to [TypedMst].
*/
@UnstableKMathAPI
public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) {
is MST.Numeric -> TypedMst.Constant(
(algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
value,
)
is MST.Unary -> when (val arg = value.evaluateConstants(algebra)) {
is TypedMst.Constant<T> -> {
val value = algebra.unaryOperation(
operation,
arg.value,
)
TypedMst.Constant(value, if (value is Number) value else null)
}
else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
}
is MST.Binary -> {
val left = left.evaluateConstants(algebra)
val right = right.evaluateConstants(algebra)
when {
left is TypedMst.Constant<T> && right is TypedMst.Constant<T> -> {
val value = when {
algebra is NumericAlgebra && left.number != null -> algebra.leftSideNumberOperation(
operation,
left.number,
right.value,
)
algebra is NumericAlgebra && right.number != null -> algebra.rightSideNumberOperation(
operation,
left.value,
right.number,
)
else -> algebra.binaryOperation(
operation,
left.value,
right.value,
)
}
TypedMst.Constant(value, if (value is Number) value else null)
}
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
operation,
algebra.leftSideNumberOperationFunction(operation),
left,
right,
)
algebra is NumericAlgebra && right is TypedMst.Constant && right.number != null -> TypedMst.Binary(
operation,
algebra.rightSideNumberOperationFunction(operation),
left,
right,
)
else -> TypedMst.Binary(operation, algebra.binaryOperationFunction(operation), left, right)
}
}
is Symbol -> {
val boundSymbol = algebra.bindSymbolOrNull(this)
if (boundSymbol != null)
TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null)
else
TypedMst.Variable(this)
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.pi
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.fail
internal class TestFolding {
@Test
fun foldUnary() = assertEquals(
-1,
("-(1)".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldDeepUnary() = assertEquals(
1,
("-(-(1))".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldBinary() = assertEquals(
2,
("1*2".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldDeepBinary() = assertEquals(
10,
("1*2*5".parseMath().evaluateConstants(IntRing) as? TypedMst.Constant<Int> ?: fail()).value,
)
@Test
fun foldSymbol() = assertEquals(
DoubleField.pi,
("pi".parseMath().evaluateConstants(DoubleField) as? TypedMst.Constant<Double> ?: fail()).value,
)
@Test
fun foldNumeric() = assertEquals(
42.toByte(),
("42".parseMath().evaluateConstants(ByteRing) as? TypedMst.Constant<Byte> ?: fail()).value,
)
}

View File

@ -5,87 +5,48 @@
package space.kscience.kmath.estree package space.kscience.kmath.estree
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.estree.internal.ESTreeBuilder import space.kscience.kmath.estree.internal.ESTreeBuilder
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.internal.estree.BaseExpression import space.kscience.kmath.internal.estree.BaseExpression
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.bindSymbolOrNull
@PublishedApi
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
is Symbol -> {
val symbol = algebra.bindSymbolOrNull(node)
if (symbol != null)
constant(symbol)
else
variable(node.identity)
}
is Numeric -> constant(
(algebra as? NumericAlgebra<T>)?.number(node.value) ?: error("Numeric nodes are not supported by $this")
)
is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> constant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))
)
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
algebra.binaryOperationFunction(node.operation).invoke(
algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value)
)
)
algebra is NumericAlgebra && node.left is Numeric -> call(
algebra.leftSideNumberOperationFunction(node.operation),
visit(node.left),
visit(node.right),
)
algebra is NumericAlgebra && node.right is Numeric -> call(
algebra.rightSideNumberOperationFunction(node.operation),
visit(node.left),
visit(node.right),
)
else -> call(
algebra.binaryOperationFunction(node.operation),
visit(node.left),
visit(node.right),
)
}
}
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
}
/** /**
* Create a compiled expression with given [MST] and given [algebra]. * Create a compiled expression with given [MST] and given [algebra].
*/ */
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra) @OptIn(UnstableKMathAPI::class)
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
val typed = evaluateConstants(algebra)
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) {
is TypedMst.Constant -> constant(node.value)
is TypedMst.Variable -> variable(node.symbol)
is TypedMst.Unary -> call(node.function, visit(node.value))
is TypedMst.Binary -> call(
node.function,
visit(node.left),
visit(node.right),
)
}
return ESTreeBuilder { visit(typed) }.instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = public fun <T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = public fun <T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)

View File

@ -22,28 +22,20 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
} }
} }
@Suppress("UNUSED_VARIABLE")
val instance: Expression<T> by lazy { val instance: Expression<T> by lazy {
val node = Program( val node = Program(
sourceType = "script", sourceType = "script",
VariableDeclaration( ReturnStatement(bodyCallback())
kind = "var",
VariableDeclarator(
id = Identifier("executable"),
init = FunctionExpression(
params = arrayOf(Identifier("constants"), Identifier("arguments")),
body = BlockStatement(ReturnStatement(bodyCallback())),
),
),
),
) )
eval(generate(node)) val code = generate(node)
GeneratedExpression(js("executable"), constants.toTypedArray()) GeneratedExpression(js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray())
} }
private val constants = mutableListOf<Any>() private val constants = mutableListOf<Any>()
fun constant(value: Any?) = when { fun constant(value: Any?): BaseExpression = when {
value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" -> value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" ->
SimpleLiteral(value) SimpleLiteral(value)
@ -61,7 +53,8 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
} }
} }
fun variable(name: String): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name)) fun variable(name: Symbol): BaseExpression =
call(getOrFail, Identifier("arguments_0"), SimpleLiteral(name.identity))
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression( fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
optional = false, optional = false,

View File

@ -3,6 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:Suppress("unused")
package space.kscience.kmath.internal.estree package space.kscience.kmath.internal.estree
internal fun Program(sourceType: String, vararg body: dynamic) = object : Program { internal fun Program(sourceType: String, vararg body: dynamic) = object : Program {
@ -28,9 +30,10 @@ internal fun Identifier(name: String) = object : Identifier {
override var name = name override var name = name
} }
internal fun FunctionExpression(params: Array<dynamic>, body: BlockStatement) = object : FunctionExpression { internal fun FunctionExpression(id: Identifier?, params: Array<dynamic>, body: BlockStatement) = object : FunctionExpression {
override var params = params override var params = params
override var type = "FunctionExpression" override var type = "FunctionExpression"
override var id: Identifier? = id
override var body = body override var body = body
} }

View File

@ -5,8 +5,8 @@
package space.kscience.kmath.wasm.internal package space.kscience.kmath.wasm.internal
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.internal.binaryen.* import space.kscience.kmath.internal.binaryen.*
import space.kscience.kmath.internal.webassembly.Instance import space.kscience.kmath.internal.webassembly.Instance
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
@ -16,11 +16,12 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
private val spreader = eval("(obj, args) => obj(...args)") private val spreader = eval("(obj, args) => obj(...args)")
@OptIn(UnstableKMathAPI::class)
@Suppress("UnsafeCastFromDynamic") @Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder<T : Number, out E : Expression<T>>( internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
protected val binaryenType: Type, protected val binaryenType: Type,
protected val algebra: Algebra<T>, protected val algebra: Algebra<T>,
protected val target: MST, protected val target: TypedMst<T>,
) { ) {
protected val keys: MutableList<Symbol> = mutableListOf() protected val keys: MutableList<Symbol> = mutableListOf()
protected lateinit var ctx: BinaryenModule protected lateinit var ctx: BinaryenModule
@ -51,59 +52,41 @@ internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
Instance(c, js("{}")).exports.executable Instance(c, js("{}")).exports.executable
} }
protected open fun visitSymbol(node: Symbol): ExpressionRef { protected abstract fun visitNumber(number: Number): ExpressionRef
algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
var idx = keys.indexOf(node) protected open fun visitVariable(node: TypedMst.Variable<T>): ExpressionRef {
var idx = keys.indexOf(node.symbol)
if (idx == -1) { if (idx == -1) {
keys += node keys += node.symbol
idx = keys.lastIndex idx = keys.lastIndex
} }
return ctx.local.get(idx, binaryenType) return ctx.local.get(idx, binaryenType)
} }
protected abstract fun visitNumeric(node: Numeric): ExpressionRef protected open fun visitUnary(node: TypedMst.Unary<T>): ExpressionRef =
protected open fun visitUnary(node: Unary): ExpressionRef =
error("Unary operation ${node.operation} not defined in $this") error("Unary operation ${node.operation} not defined in $this")
protected open fun visitBinary(mst: Binary): ExpressionRef = protected open fun visitBinary(mst: TypedMst.Binary<T>): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this") error("Binary operation ${mst.operation} not defined in $this")
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") protected open fun createModule(): BinaryenModule = space.kscience.kmath.internal.binaryen.Module()
protected fun visit(node: MST): ExpressionRef = when (node) { protected fun visit(node: TypedMst<T>): ExpressionRef = when (node) {
is Symbol -> visitSymbol(node) is TypedMst.Constant -> visitNumber(
is Numeric -> visitNumeric(node) node.number ?: error("Object constants are not supported by pritimive ASM builder"),
)
is Unary -> when { is TypedMst.Variable -> visitVariable(node)
algebra is NumericAlgebra && node.value is Numeric -> visitNumeric( is TypedMst.Unary -> visitUnary(node)
Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))) is TypedMst.Binary -> visitBinary(node)
)
else -> visitUnary(node)
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
Numeric(
algebra.binaryOperationFunction(node.operation)
.invoke(
algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value)
)
)
)
else -> visitBinary(node)
}
} }
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) { internal class DoubleWasmBuilder(target: TypedMst<Double>) :
WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
override val instance by lazy { override val instance by lazy {
object : DoubleExpression { object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(keys) override val indexer = SimpleSymbolIndexer(keys)
@ -114,9 +97,9 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
override fun createModule() = readBinary(f64StandardFunctions) override fun createModule() = readBinary(f64StandardFunctions)
override fun visitNumeric(node: Numeric) = ctx.f64.const(node.value.toDouble()) override fun visitNumber(number: Number) = ctx.f64.const(number.toDouble())
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) { override fun visitUnary(node: TypedMst.Unary<Double>): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value)) GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value) GroupOps.PLUS_OPERATION -> visit(node.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value)) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
@ -137,7 +120,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
else -> super.visitUnary(node) else -> super.visitUnary(node)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: TypedMst.Binary<Double>): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
@ -148,7 +131,7 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpres
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) { internal class IntWasmBuilder(target: TypedMst<Int>) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
override val instance by lazy { override val instance by lazy {
object : IntExpression { object : IntExpression {
override val indexer = SimpleSymbolIndexer(keys) override val indexer = SimpleSymbolIndexer(keys)
@ -157,15 +140,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32
} }
} }
override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value.toInt()) override fun visitNumber(number: Number) = ctx.i32.const(number.toInt())
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) { override fun visitUnary(node: TypedMst.Unary<Int>): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value) GroupOps.PLUS_OPERATION -> visit(node.value)
else -> super.visitUnary(node) else -> super.visitUnary(node)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: TypedMst.Binary<Int>): ExpressionRef = when (mst.operation) {
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))

View File

@ -7,7 +7,8 @@
package space.kscience.kmath.wasm package space.kscience.kmath.wasm
import space.kscience.kmath.estree.compileWith import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
@ -21,8 +22,16 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance public fun MST.compileToExpression(algebra: IntRing): IntExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : IntExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: IntArray): Int = typed.value
} else
IntWasmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -31,7 +40,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBui
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
@ -49,7 +58,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleWasmBuilder(this).instance public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: DoubleArray): Double = typed.value
} else
DoubleWasmBuilder(typed).instance
}
/** /**
@ -59,7 +77,7 @@ public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = D
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
@ -69,4 +87,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)

View File

@ -8,10 +8,14 @@
package space.kscience.kmath.asm package space.kscience.kmath.asm
import space.kscience.kmath.asm.internal.* import space.kscience.kmath.asm.internal.*
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.ast.evaluateConstants
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.LongRing
/** /**
* Compiles given MST to an Expression using AST compiler. * Compiles given MST to an Expression using AST compiler.
@ -21,102 +25,64 @@ import space.kscience.kmath.operations.*
* @return the compiled expression. * @return the compiled expression.
* @author Alexander Nozik * @author Alexander Nozik
*/ */
@OptIn(UnstableKMathAPI::class)
@PublishedApi @PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> { internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun GenericAsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) { val typed = evaluateConstants(algebra)
is Symbol -> prepareVariable(node.identity) if (typed is TypedMst.Constant<T>) return Expression { typed.value }
is Unary -> variablesVisitor(node.value)
is Binary -> { fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) {
is TypedMst.Unary -> variablesVisitor(node.value)
is TypedMst.Binary -> {
variablesVisitor(node.left) variablesVisitor(node.left)
variablesVisitor(node.right) variablesVisitor(node.right)
} }
else -> Unit is TypedMst.Variable -> prepareVariable(node.symbol)
is TypedMst.Constant -> Unit
} }
fun GenericAsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) { fun GenericAsmBuilder<T>.expressionVisitor(node: TypedMst<T>): Unit = when (node) {
is Symbol -> { is TypedMst.Constant -> if (node.number != null)
val symbol = algebra.bindSymbolOrNull(node) loadNumberConstant(node.number)
else
loadObjectConstant(node.value)
if (symbol != null) is TypedMst.Variable -> loadVariable(node.symbol)
loadObjectConstant(symbol as Any) is TypedMst.Unary -> buildCall(node.function) { expressionVisitor(node.value) }
else
loadVariable(node.identity)
}
is Numeric -> if (algebra is NumericAlgebra) { is TypedMst.Binary -> buildCall(node.function) {
if (Number::class.java.isAssignableFrom(type)) expressionVisitor(node.left)
loadNumberConstant(algebra.number(node.value) as Number) expressionVisitor(node.right)
else
loadObjectConstant(algebra.number(node.value))
} else
error("Numeric nodes are not supported by $this")
is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
)
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
algebra.binaryOperationFunction(node.operation).invoke(
algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value),
)
)
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
algebra.leftSideNumberOperationFunction(node.operation),
) {
expressionVisitor(node.left)
expressionVisitor(node.right)
}
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
algebra.rightSideNumberOperationFunction(node.operation),
) {
expressionVisitor(node.left)
expressionVisitor(node.right)
}
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
expressionVisitor(node.left)
expressionVisitor(node.right)
}
} }
} }
return GenericAsmBuilder<T>( return GenericAsmBuilder<T>(
type, type,
buildName(this), buildName("${typed.hashCode()}_${type.simpleName}"),
{ variablesVisitor(this@compileWith) }, { variablesVisitor(typed) },
{ expressionVisitor(this@compileWith) }, { expressionVisitor(typed) },
).instance ).instance
} }
/** /**
* Create a compiled expression with given [MST] and given [algebra]. * Create a compiled expression with given [MST] and given [algebra].
*/ */
public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = public inline fun <reified T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
compileWith(T::class.java, algebra) compileWith(T::class.java, algebra)
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T = public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments] * Compile given MST to expression and evaluate it against [arguments]
*/ */
public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T = public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)
/** /**
@ -125,7 +91,16 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg argu
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance public fun MST.compileToExpression(algebra: IntRing): IntExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : IntExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: IntArray): Int = typed.value
} else
IntAsmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -134,7 +109,7 @@ public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuil
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -152,8 +127,16 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance public fun MST.compileToExpression(algebra: LongRing): LongExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant<Long>) object : LongExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: LongArray): Long = typed.value
} else
LongAsmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -162,7 +145,7 @@ public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmB
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long = public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
@ -181,7 +164,17 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>):
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression {
val typed = evaluateConstants(algebra)
return if (typed is TypedMst.Constant) object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(emptyList())
override fun invoke(arguments: DoubleArray): Double = typed.value
} else
DoubleAsmBuilder(typed).instance
}
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -190,7 +183,7 @@ public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = Dou
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra)(arguments)
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
@ -199,4 +192,4 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra)(*arguments)

View File

@ -49,5 +49,7 @@ internal abstract class AsmBuilder {
* ASM Type for [space.kscience.kmath.expressions.Symbol]. * ASM Type for [space.kscience.kmath.expressions.Symbol].
*/ */
val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol") val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol")
const val ARGUMENTS_NAME = "args"
} }
} }

View File

@ -56,7 +56,7 @@ internal class GenericAsmBuilder<T>(
/** /**
* Local variables indices are indices of symbols in this list. * Local variables indices are indices of symbols in this list.
*/ */
private val argumentsLocals = mutableListOf<String>() private val argumentsLocals = mutableListOf<Symbol>()
/** /**
* Subclasses, loads and instantiates [Expression] for given parameters. * Subclasses, loads and instantiates [Expression] for given parameters.
@ -253,10 +253,10 @@ internal class GenericAsmBuilder<T>(
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable]. * [loadVariable].
*/ */
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { fun prepareVariable(name: Symbol): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run if (name in argumentsLocals) return@run
load(1, MAP_TYPE) load(1, MAP_TYPE)
aconst(name) aconst(name.identity)
invokestatic( invokestatic(
MAP_INTRINSICS_TYPE.internalName, MAP_INTRINSICS_TYPE.internalName,
@ -280,7 +280,7 @@ internal class GenericAsmBuilder<T>(
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first. * with [prepareVariable] first.
*/ */
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType) fun loadVariable(name: Symbol): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) { inline fun buildCall(function: Function<T>, parameters: GenericAsmBuilder<T>.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }

View File

@ -11,6 +11,7 @@ import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type import org.objectweb.asm.Type
import org.objectweb.asm.Type.* import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.ast.TypedMst
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
@ -25,9 +26,9 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
classOfT: Class<*>, classOfT: Class<*>,
protected val classOfTPrimitive: Class<*>, protected val classOfTPrimitive: Class<*>,
expressionParent: Class<E>, expressionParent: Class<E>,
protected val target: MST, protected val target: TypedMst<T>,
) : AsmBuilder() { ) : AsmBuilder() {
private val className: String = buildName(target) private val className: String = buildName("${target.hashCode()}_${classOfT.simpleName}")
/** /**
* ASM type for [tType]. * ASM type for [tType].
@ -329,63 +330,39 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
} }
private fun visitVariables( private fun visitVariables(
node: MST, node: TypedMst<T>,
arrayMode: Boolean, arrayMode: Boolean,
alreadyLoaded: MutableList<Symbol> = mutableListOf() alreadyLoaded: MutableList<Symbol> = mutableListOf()
): Unit = when (node) { ): Unit = when (node) {
is Symbol -> when (node) { is TypedMst.Variable -> if (node.symbol !in alreadyLoaded) {
!in alreadyLoaded -> { alreadyLoaded += node.symbol
alreadyLoaded += node prepareVariable(node.symbol, arrayMode)
prepareVariable(node, arrayMode) } else Unit
}
else -> {
}
}
is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded) is TypedMst.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
is MST.Binary -> { is TypedMst.Binary -> {
visitVariables(node.left, arrayMode, alreadyLoaded) visitVariables(node.left, arrayMode, alreadyLoaded)
visitVariables(node.right, arrayMode, alreadyLoaded) visitVariables(node.right, arrayMode, alreadyLoaded)
} }
else -> Unit is TypedMst.Constant -> Unit
} }
private fun visitExpression(node: MST): Unit = when (node) { private fun visitExpression(node: TypedMst<T>): Unit = when (node) {
is Symbol -> { is TypedMst.Variable -> loadVariable(node.symbol)
val symbol = algebra.bindSymbolOrNull(node)
if (symbol != null) is TypedMst.Constant -> loadNumberConstant(
loadNumberConstant(symbol) node.number ?: error("Object constants are not supported by pritimive ASM builder"),
else )
loadVariable(node)
}
is MST.Numeric -> loadNumberConstant(algebra.number(node.value)) is TypedMst.Unary -> visitUnary(node)
is TypedMst.Binary -> visitBinary(node)
is MST.Unary -> if (node.value is MST.Numeric)
loadNumberConstant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
)
else
visitUnary(node)
is MST.Binary -> when {
node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
algebra.binaryOperationFunction(node.operation)(
algebra.number((node.left as MST.Numeric).value),
algebra.number((node.right as MST.Numeric).value),
),
)
else -> visitBinary(node)
}
} }
protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value) protected open fun visitUnary(node: TypedMst.Unary<T>) = visitExpression(node.value)
protected open fun visitBinary(node: MST.Binary) { protected open fun visitBinary(node: TypedMst.Binary<T>) {
visitExpression(node.left) visitExpression(node.left)
visitExpression(node.right) visitExpression(node.right)
} }
@ -404,14 +381,13 @@ internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, DoubleExpression>( internal class DoubleAsmBuilder(target: TypedMst<Double>) : PrimitiveAsmBuilder<Double, DoubleExpression>(
DoubleField, DoubleField,
java.lang.Double::class.java, java.lang.Double::class.java,
java.lang.Double.TYPE, java.lang.Double.TYPE,
DoubleExpression::class.java, DoubleExpression::class.java,
target, target,
) { ) {
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic( private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
MATH_TYPE.internalName, MATH_TYPE.internalName,
name, name,
@ -434,7 +410,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
false, false,
) )
override fun visitUnary(node: MST.Unary) { override fun visitUnary(node: TypedMst.Unary<Double>) {
super.visitUnary(node) super.visitUnary(node)
when (node.operation) { when (node.operation) {
@ -459,7 +435,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
} }
} }
override fun visitBinary(node: MST.Binary) { override fun visitBinary(node: TypedMst.Binary<Double>) {
super.visitBinary(node) super.visitBinary(node)
when (node.operation) { when (node.operation) {
@ -479,7 +455,7 @@ internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, Doubl
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class IntAsmBuilder(target: MST) : internal class IntAsmBuilder(target: TypedMst<Int>) :
PrimitiveAsmBuilder<Int, IntExpression>( PrimitiveAsmBuilder<Int, IntExpression>(
IntRing, IntRing,
Integer::class.java, Integer::class.java,
@ -487,7 +463,7 @@ internal class IntAsmBuilder(target: MST) :
IntExpression::class.java, IntExpression::class.java,
target target
) { ) {
override fun visitUnary(node: MST.Unary) { override fun visitUnary(node: TypedMst.Unary<Int>) {
super.visitUnary(node) super.visitUnary(node)
when (node.operation) { when (node.operation) {
@ -497,7 +473,7 @@ internal class IntAsmBuilder(target: MST) :
} }
} }
override fun visitBinary(node: MST.Binary) { override fun visitBinary(node: TypedMst.Binary<Int>) {
super.visitBinary(node) super.visitBinary(node)
when (node.operation) { when (node.operation) {
@ -510,14 +486,14 @@ internal class IntAsmBuilder(target: MST) :
} }
@UnstableKMathAPI @UnstableKMathAPI
internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpression>( internal class LongAsmBuilder(target: TypedMst<Long>) : PrimitiveAsmBuilder<Long, LongExpression>(
LongRing, LongRing,
java.lang.Long::class.java, java.lang.Long::class.java,
java.lang.Long.TYPE, java.lang.Long.TYPE,
LongExpression::class.java, LongExpression::class.java,
target, target,
) { ) {
override fun visitUnary(node: MST.Unary) { override fun visitUnary(node: TypedMst.Unary<Long>) {
super.visitUnary(node) super.visitUnary(node)
when (node.operation) { when (node.operation) {
@ -527,7 +503,7 @@ internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpre
} }
} }
override fun visitBinary(node: MST.Binary) { override fun visitBinary(node: TypedMst.Binary<Long>) {
super.visitBinary(node) super.visitBinary(node)
when (node.operation) { when (node.operation) {

View File

@ -55,15 +55,15 @@ internal inline fun MethodVisitor.instructionAdapter(block: InstructionAdapter.(
internal fun MethodVisitor.label(): Label = Label().also(::visitLabel) internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
/** /**
* Creates a class name for [Expression] subclassed to implement [mst] provided. * Creates a class name for [Expression] based with appending [marker] to reduce collisions.
* *
* These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there * These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there
* is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively.
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
internal tailrec fun buildName(mst: MST, collision: Int = 0): String { internal tailrec fun buildName(marker: String, collision: Int = 0): String {
val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision" val name = "space.kscience.kmath.asm.generated.CompiledExpression_${marker}_$collision"
try { try {
Class.forName(name) Class.forName(name)
@ -71,7 +71,7 @@ internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
return name return name
} }
return buildName(mst, collision + 1) return buildName(marker, collision + 1)
} }
@Suppress("FunctionName") @Suppress("FunctionName")

32
kmath-commons/README.md Normal file
View File

@ -0,0 +1,32 @@
# Module kmath-commons
Commons math binding for kmath
## Usage
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-commons:0.3.0`.
**Gradle Groovy:**
```groovy
repositories {
maven { url 'https://repo.kotlin.link' }
mavenCentral()
}
dependencies {
implementation 'space.kscience:kmath-commons:0.3.0'
}
```
**Gradle Kotlin DSL:**
```kotlin
repositories {
maven("https://repo.kotlin.link")
mavenCentral()
}
dependencies {
implementation("space.kscience:kmath-commons:0.3.0")
}
```

View File

@ -8,17 +8,17 @@ Complex and hypercomplex number systems in KMath.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-17`. The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0`.
**Gradle:** **Gradle Groovy:**
```gradle ```groovy
repositories { repositories {
maven { url 'https://repo.kotlin.link' } maven { url 'https://repo.kotlin.link' }
mavenCentral() mavenCentral()
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-complex:0.3.0-dev-17' implementation 'space.kscience:kmath-complex:0.3.0'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -29,6 +29,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-complex:0.3.0-dev-17") implementation("space.kscience:kmath-complex:0.3.0")
} }
``` ```

View File

@ -19,13 +19,15 @@ readme {
feature( feature(
id = "complex", id = "complex",
description = "Complex Numbers",
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt"
) ){
"Complex numbers operations"
}
feature( feature(
id = "quaternion", id = "quaternion",
description = "Quaternions",
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt"
) ){
"Quaternions and their composition"
}
} }

View File

@ -16,52 +16,132 @@ import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.MutableMemoryBuffer import space.kscience.kmath.structures.MutableMemoryBuffer
import kotlin.math.* import kotlin.math.*
/**
* Represents `double`-based quaternion.
*
* @property w The first component.
* @property x The second component.
* @property y The third component.
* @property z The fourth component.
*/
public class Quaternion(
public val w: Double,
public val x: Double,
public val y: Double,
public val z: Double,
) : Buffer<Double> {
init {
require(!w.isNaN()) { "w-component of quaternion is not-a-number" }
require(!x.isNaN()) { "x-component of quaternion is not-a-number" }
require(!y.isNaN()) { "y-component of quaternion is not-a-number" }
require(!z.isNaN()) { "z-component of quaternion is not-a-number" }
}
/**
* Returns a string representation of this quaternion.
*/
override fun toString(): String = "($w + $x * i + $y * j + $z * k)"
override val size: Int get() = 4
override fun get(index: Int): Double = when (index) {
0 -> w
1 -> x
2 -> y
3 -> z
else -> error("Index $index out of bounds [0,3]")
}
override fun iterator(): Iterator<Double> = listOf(w, x, y, z).iterator()
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as Quaternion
if (w != other.w) return false
if (x != other.x) return false
if (y != other.y) return false
if (z != other.z) return false
return true
}
override fun hashCode(): Int {
var result = w.hashCode()
result = 31 * result + x.hashCode()
result = 31 * result + y.hashCode()
result = 31 * result + z.hashCode()
return result
}
public companion object : MemorySpec<Quaternion> {
override val objectSize: Int get() = 32
override fun MemoryReader.read(offset: Int): Quaternion = Quaternion(
readDouble(offset),
readDouble(offset + 8),
readDouble(offset + 16),
readDouble(offset + 24)
)
override fun MemoryWriter.write(offset: Int, value: Quaternion) {
writeDouble(offset, value.w)
writeDouble(offset + 8, value.x)
writeDouble(offset + 16, value.y)
writeDouble(offset + 24, value.z)
}
}
}
public fun Quaternion(w: Number, x: Number = 0.0, y: Number = 0.0, z: Number = 0.0): Quaternion = Quaternion(
w.toDouble(),
x.toDouble(),
y.toDouble(),
z.toDouble(),
)
/** /**
* This quaternion's conjugate. * This quaternion's conjugate.
*/ */
public val Quaternion.conjugate: Quaternion public val Quaternion.conjugate: Quaternion
get() = QuaternionField { z - x * i - y * j - z * k } get() = Quaternion(w, -x, -y, -z)
/** /**
* This quaternion's reciprocal. * This quaternion's reciprocal.
*/ */
public val Quaternion.reciprocal: Quaternion public val Quaternion.reciprocal: Quaternion
get() { get() {
QuaternionField { val norm2 = (w * w + x * x + y * y + z * z)
val n = norm(this@reciprocal) return Quaternion(w / norm2, -x / norm2, -y / norm2, -z / norm2)
return conjugate / (n * n)
}
} }
/** public fun Quaternion.normalized(): Quaternion = with(QuaternionField){ this@normalized / norm(this@normalized) }
* Absolute value of the quaternion.
*/
public val Quaternion.r: Double
get() = sqrt(w * w + x * x + y * y + z * z)
/** /**
* A field of [Quaternion]. * A field of [Quaternion].
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>, public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Double>, PowerOperations<Quaternion>,
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> { ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
override val zero: Quaternion = 0.toQuaternion() override val zero: Quaternion = Quaternion(0.0)
override val one: Quaternion = 1.toQuaternion() override val one: Quaternion = Quaternion(1.0)
/** /**
* The `i` quaternion unit. * The `i` quaternion unit.
*/ */
public val i: Quaternion = Quaternion(0, 1) public val i: Quaternion = Quaternion(0.0, 1.0, 0.0, 0.0)
/** /**
* The `j` quaternion unit. * The `j` quaternion unit.
*/ */
public val j: Quaternion = Quaternion(0, 0, 1) public val j: Quaternion = Quaternion(0.0, 0.0, 1.0, 0.0)
/** /**
* The `k` quaternion unit. * The `k` quaternion unit.
*/ */
public val k: Quaternion = Quaternion(0, 0, 0, 1) public val k: Quaternion = Quaternion(0.0, 0.0, 0.0, 1.0)
override fun add(left: Quaternion, right: Quaternion): Quaternion = override fun add(left: Quaternion, right: Quaternion): Quaternion =
Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z) Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
@ -133,7 +213,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
override fun exp(arg: Quaternion): Quaternion { override fun exp(arg: Quaternion): Quaternion {
val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z
if (un == 0.0) return exp(arg.w).toQuaternion() if (un == 0.0) return Quaternion(exp(arg.w))
val n1 = sqrt(un) val n1 = sqrt(un)
val ea = exp(arg.w) val ea = exp(arg.w)
val n2 = ea * sin(n1) / n1 val n2 = ea * sin(n1) / n1
@ -158,7 +238,8 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z) return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z)
} }
override operator fun Number.plus(other: Quaternion): Quaternion = Quaternion(toDouble() + other.w, other.x, other.y, other.z) override operator fun Number.plus(other: Quaternion): Quaternion =
Quaternion(toDouble() + other.w, other.x, other.y, other.z)
override operator fun Number.minus(other: Quaternion): Quaternion = override operator fun Number.minus(other: Quaternion): Quaternion =
Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z) Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z)
@ -170,7 +251,12 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z) Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z)
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) override fun norm(arg: Quaternion): Double = sqrt(
arg.w.pow(2) +
arg.x.pow(2) +
arg.y.pow(2) +
arg.z.pow(2)
)
override fun bindSymbolOrNull(value: String): Quaternion? = when (value) { override fun bindSymbolOrNull(value: String): Quaternion? = when (value) {
"i" -> i "i" -> i
@ -179,7 +265,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
else -> null else -> null
} }
override fun number(value: Number): Quaternion = value.toQuaternion() override fun number(value: Number): Quaternion = Quaternion(value)
override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0 override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0
override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0 override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0
@ -189,76 +275,6 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0 override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0
} }
/**
* Represents `double`-based quaternion.
*
* @property w The first component.
* @property x The second component.
* @property y The third component.
* @property z The fourth component.
*/
@OptIn(UnstableKMathAPI::class)
public data class Quaternion(
val w: Double, val x: Double, val y: Double, val z: Double,
) {
public constructor(w: Number, x: Number, y: Number, z: Number) : this(
w.toDouble(),
x.toDouble(),
y.toDouble(),
z.toDouble(),
)
public constructor(w: Number, x: Number, y: Number) : this(w.toDouble(), x.toDouble(), y.toDouble(), 0.0)
public constructor(w: Number, x: Number) : this(w.toDouble(), x.toDouble(), 0.0, 0.0)
public constructor(w: Number) : this(w.toDouble(), 0.0, 0.0, 0.0)
public constructor(wx: Complex, yz: Complex) : this(wx.re, wx.im, yz.re, yz.im)
public constructor(wx: Complex) : this(wx.re, wx.im, 0, 0)
init {
require(!w.isNaN()) { "w-component of quaternion is not-a-number" }
require(!x.isNaN()) { "x-component of quaternion is not-a-number" }
require(!y.isNaN()) { "x-component of quaternion is not-a-number" }
require(!z.isNaN()) { "x-component of quaternion is not-a-number" }
}
/**
* Returns a string representation of this quaternion.
*/
override fun toString(): String = "($w + $x * i + $y * j + $z * k)"
public companion object : MemorySpec<Quaternion> {
override val objectSize: Int
get() = 32
override fun MemoryReader.read(offset: Int): Quaternion =
Quaternion(readDouble(offset), readDouble(offset + 8), readDouble(offset + 16), readDouble(offset + 24))
override fun MemoryWriter.write(offset: Int, value: Quaternion) {
writeDouble(offset, value.w)
writeDouble(offset + 8, value.x)
writeDouble(offset + 16, value.y)
writeDouble(offset + 24, value.z)
}
}
}
/**
* Creates a quaternion with real part equal to this real.
*
* @receiver the real part.
* @return a new quaternion.
*/
public fun Number.toQuaternion(): Quaternion = Quaternion(this)
/**
* Creates a quaternion with `w`-component equal to `re`-component of given complex and `x`-component equal to
* `im`-component of given complex.
*
* @receiver the complex number.
* @return a new quaternion.
*/
public fun Complex.toQuaternion(): Quaternion = Quaternion(this)
/** /**
* Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the * Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the
* specified [init] function. * specified [init] function.

View File

@ -6,10 +6,23 @@
package space.kscience.kmath.complex package space.kscience.kmath.complex
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.assertBufferEquals
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class QuaternionFieldTest { internal class QuaternionTest {
@Test
fun testNorm() {
assertEquals(2.0, QuaternionField.norm(Quaternion(1.0, 1.0, 1.0, 1.0)))
}
@Test
fun testInverse() = QuaternionField {
val q = Quaternion(1.0, 2.0, -3.0, 4.0)
assertBufferEquals(one, q * q.reciprocal, 1e-4)
}
@Test @Test
fun testAddition() { fun testAddition() {
assertEquals(Quaternion(42, 42), QuaternionField { Quaternion(16, 16) + Quaternion(26, 26) }) assertEquals(Quaternion(42, 42), QuaternionField { Quaternion(16, 16) + Quaternion(26, 26) })

View File

@ -15,17 +15,17 @@ performance calculations to code generation.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-17`. The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0`.
**Gradle:** **Gradle Groovy:**
```gradle ```groovy
repositories { repositories {
maven { url 'https://repo.kotlin.link' } maven { url 'https://repo.kotlin.link' }
mavenCentral() mavenCentral()
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-core:0.3.0-dev-17' implementation 'space.kscience:kmath-core:0.3.0'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -36,6 +36,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-core:0.3.0-dev-17") implementation("space.kscience:kmath-core:0.3.0")
} }
``` ```

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,6 @@
plugins { plugins {
kotlin("multiplatform") id("ru.mipt.npm.gradle.mpp")
id("ru.mipt.npm.gradle.common")
id("ru.mipt.npm.gradle.native") id("ru.mipt.npm.gradle.native")
// id("com.xcporter.metaview") version "0.0.5"
} }
kotlin.sourceSets { kotlin.sourceSets {

View File

@ -0,0 +1,59 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.domains
import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
@UnstableKMathAPI
public abstract class Domain1D<T : Comparable<T>>(public val range: ClosedRange<T>) : Domain<T> {
override val dimension: Int get() = 1
public operator fun contains(value: T): Boolean = range.contains(value)
override operator fun contains(point: Point<T>): Boolean {
require(point.size == 0)
return contains(point[0])
}
}
@UnstableKMathAPI
public class DoubleDomain1D(
@Suppress("CanBeParameter") public val doubleRange: ClosedFloatingPointRange<Double>,
) : Domain1D<Double>(doubleRange), DoubleDomain {
override fun getLowerBound(num: Int): Double {
require(num == 0)
return range.start
}
override fun getUpperBound(num: Int): Double {
require(num == 0)
return range.endInclusive
}
override fun volume(): Double = range.endInclusive - range.start
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as DoubleDomain1D
if (doubleRange != other.doubleRange) return false
return true
}
override fun hashCode(): Int = doubleRange.hashCode()
override fun toString(): String = doubleRange.toString()
}
@UnstableKMathAPI
public val Domain1D<Double>.center: Double
get() = (range.endInclusive + range.start) / 2

View File

@ -7,18 +7,28 @@ package space.kscience.kmath.domains
import space.kscience.kmath.linear.Point import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.indices import space.kscience.kmath.structures.indices
/** /**
* * A hyper-square (or hyper-cube) real-space domain. It is formed by a [Buffer] of [lower] boundaries
* HyperSquareDomain class. * and a [Buffer] of upper boundaries. Upper should be greater or equals than lower.
*
* @author Alexander Nozik
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public class HyperSquareDomain(private val lower: Buffer<Double>, private val upper: Buffer<Double>) : DoubleDomain { public class HyperSquareDomain(public val lower: Buffer<Double>, public val upper: Buffer<Double>) : DoubleDomain {
init {
require(lower.size == upper.size) {
"Domain borders size mismatch. Lower borders size is ${lower.size}, but upper borders size is ${upper.size}."
}
require(lower.indices.all { lower[it] <= upper[it] }) {
"Domain borders order mismatch. Lower borders must be less or equals than upper borders."
}
}
override val dimension: Int get() = lower.size override val dimension: Int get() = lower.size
public val center: DoubleBuffer get() = DoubleBuffer(dimension) { (lower[it] + upper[it]) / 2.0 }
override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i -> override operator fun contains(point: Point<Double>): Boolean = point.indices.all { i ->
point[i] in lower[i]..upper[i] point[i] in lower[i]..upper[i]
} }

View File

@ -1,33 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.domains
import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
@UnstableKMathAPI
public class UnivariateDomain(public val range: ClosedFloatingPointRange<Double>) : DoubleDomain {
override val dimension: Int get() = 1
public operator fun contains(d: Double): Boolean = range.contains(d)
override operator fun contains(point: Point<Double>): Boolean {
require(point.size == 0)
return contains(point[0])
}
override fun getLowerBound(num: Int): Double {
require(num == 0)
return range.start
}
override fun getUpperBound(num: Int): Double {
require(num == 0)
return range.endInclusive
}
override fun volume(): Double = range.endInclusive - range.start
}

View File

@ -34,12 +34,12 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> = override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
{ left, right -> { left, right ->
Expression { arguments -> Expression { arguments ->
algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments)) algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
} }
} }
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg -> override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) } Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
} }
} }
@ -164,8 +164,6 @@ public open class FunctionalExpressionExtendedField<T, out A : ExtendedField<T>>
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> = override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.binaryOperationFunction(operation) super<FunctionalExpressionField>.binaryOperationFunction(operation)
override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
} }
public inline fun <T, A : Group<T>> A.expressionInGroup( public inline fun <T, A : Group<T>> A.expressionInGroup(

View File

@ -24,7 +24,7 @@ public sealed interface MST {
public data class Numeric(val value: Number) : MST public data class Numeric(val value: Number) : MST
/** /**
* A node containing an unary operation. * A node containing a unary operation.
* *
* @property operation the identifier of operation. * @property operation the identifier of operation.
* @property value the argument of this operation. * @property value the argument of this operation.
@ -34,7 +34,7 @@ public sealed interface MST {
/** /**
* A node containing binary operation. * A node containing binary operation.
* *
* @property operation the identifier operation. * @property operation the identifier of operation.
* @property left the left operand. * @property left the left operand.
* @property right the right operand. * @property right the right operand.
*/ */

View File

@ -272,7 +272,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: Aut
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>, x: AutoDiffValue<T>,
y: Double, y: Double,
): AutoDiffValue<T> = derive(const { x.value.pow(y)}) { z -> ): AutoDiffValue<T> = derive(const { x.value.pow(y) }) { z ->
x.d += z.d * y * x.value.pow(y - 1) x.d += z.d * y * x.value.pow(y - 1)
} }
@ -343,10 +343,7 @@ public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: Au
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>( public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
context: F, context: F,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>, ) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
SimpleAutoDiffField<T, F>(context, bindings) {
override fun bindSymbol(value: String): AutoDiffValue<T> = super<SimpleAutoDiffField>.bindSymbol(value)
override fun number(value: Number): AutoDiffValue<T> = const { number(value) } override fun number(value: Number): AutoDiffValue<T> = const { number(value) }

View File

@ -3,43 +3,71 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/ */
package space.kscience.kmath.misc package space.kscience.kmath.misc
import kotlin.comparisons.*
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.VirtualBuffer
/** /**
* Return a new list filled with buffer indices. Indice order is defined by sorting associated buffer value. * Return a new array filled with buffer indices. Indices order is defined by sorting associated buffer value.
* This feature allows to sort buffer values without reordering its content. * This feature allows sorting buffer values without reordering its content.
* *
* @return List of buffer indices, sorted by associated value. * @return Buffer indices, sorted by associated value.
*/ */
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V: Comparable<V>> Buffer<V>.permSort() : IntArray = _permSortWith(compareBy<Int> { get(it) }) public fun <V : Comparable<V>> Buffer<V>.indicesSorted(): IntArray = permSortIndicesWith(compareBy { get(it) })
/**
* Create a zero-copy virtual buffer that contains the same elements but in ascending order
*/
@OptIn(UnstableKMathAPI::class)
public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> {
val permutations = indicesSorted()
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V: Comparable<V>> Buffer<V>.permSortDescending() : IntArray = _permSortWith(compareByDescending<Int> { get(it) }) public fun <V : Comparable<V>> Buffer<V>.indicesSortedDescending(): IntArray =
permSortIndicesWith(compareByDescending { get(it) })
/**
* Create a zero-copy virtual buffer that contains the same elements but in descending order
*/
@OptIn(UnstableKMathAPI::class)
public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> {
val permutations = indicesSortedDescending()
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V, C: Comparable<C>> Buffer<V>.permSortBy(selector: (V) -> C) : IntArray = _permSortWith(compareBy<Int> { selector(get(it)) }) public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray =
permSortIndicesWith(compareBy { selector(get(it)) })
@OptIn(UnstableKMathAPI::class)
public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> {
val permutations = indicesSortedBy(selector)
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V, C: Comparable<C>> Buffer<V>.permSortByDescending(selector: (V) -> C) : IntArray = _permSortWith(compareByDescending<Int> { selector(get(it)) }) public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray =
permSortIndicesWith(compareByDescending { selector(get(it)) })
@OptIn(UnstableKMathAPI::class)
public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> {
val permutations = indicesSortedByDescending(selector)
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V> Buffer<V>.permSortWith(comparator : Comparator<V>) : IntArray = _permSortWith { i1, i2 -> comparator.compare(get(i1), get(i2)) } public fun <V> Buffer<V>.indicesSortedWith(comparator: Comparator<V>): IntArray =
permSortIndicesWith { i1, i2 -> comparator.compare(get(i1), get(i2)) }
@PerformancePitfall private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray {
@UnstableKMathAPI if (size < 2) return IntArray(size) { 0 }
private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray {
if (size < 2) return IntArray(size)
/* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indice /* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indices
* arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire * arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire
* cached array, then fill remaining indices manually. Not done for now, because: * cached array, then fill remaining indices manually. Not done for now, because:
* 1. doing it right would require some statistics about common used buffer sizes. * 1. doing it right would require some statistics about common used buffer sizes.
@ -53,3 +81,12 @@ private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray
*/ */
return packedIndices.sortedWith(comparator).toIntArray() return packedIndices.sortedWith(comparator).toIntArray()
} }
/**
* Checks that the [Buffer] is sorted (ascending) and throws [IllegalArgumentException] if it is not.
*/
public fun <T : Comparable<T>> Buffer<T>.requireSorted() {
for (i in 0..(size - 2)) {
require(get(i + 1) >= get(i)) { "The buffer is not sorted at index $i" }
}
}

View File

@ -10,25 +10,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import kotlin.reflect.KClass import kotlin.reflect.KClass
/**
* An exception is thrown when the expected and actual shape of NDArray differ.
*
* @property expected the expected shape.
* @property actual the actual shape.
*/
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
public typealias Shape = IntArray
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
public interface WithShape {
public val shape: Shape
public val indices: ShapeIndexer get() = DefaultStrides(shape)
}
/** /**
* The base interface for all ND-algebra implementations. * The base interface for all ND-algebra implementations.
* *
@ -194,7 +175,7 @@ public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, Gro
override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> = override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
zip(left, right) { aValue, bValue -> multiply(aValue, bValue) } zip(left, right) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions with context receivers
/** /**
* Multiplies an ND structure by an element of it. * Multiplies an ND structure by an element of it.

View File

@ -47,7 +47,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
zipInline(left.toBufferND(), right.toBufferND(), transform) zipInline(left.toBufferND(), right.toBufferND(), transform)
public companion object { public companion object {
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = ::Strides
} }
} }

View File

@ -32,18 +32,23 @@ public open class BufferND<out T>(
/** /**
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferND] * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferND]
*/ */
public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer( public inline fun <T, R : Any> StructureND<T>.mapToBuffer(
factory: BufferFactory<R> = Buffer.Companion::auto, factory: BufferFactory<R>,
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): BufferND<R> { ): BufferND<R> = if (this is BufferND<T>)
return if (this is BufferND<T>) BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) })
BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) }) else {
else { val strides = DefaultStrides(shape)
val strides = DefaultStrides(shape) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
} }
/**
* Transform structure to a new structure using inferred [BufferFactory]
*/
public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
crossinline transform: (T) -> R,
): BufferND<R> = mapToBuffer(Buffer.Companion::auto, transform)
/** /**
* Represents [MutableStructureND] over [MutableBuffer]. * Represents [MutableStructureND] over [MutableBuffer].
* *

View File

@ -0,0 +1,35 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.nd
/**
* An exception is thrown when the expected and actual shape of NDArray differ.
*
* @property expected the expected shape.
* @property actual the actual shape.
*/
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
public class IndexOutOfShapeException(public val shape: Shape, public val index: IntArray) :
RuntimeException("Index ${index.contentToString()} is out of shape ${shape.contentToString()}")
public typealias Shape = IntArray
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
public interface WithShape {
public val shape: Shape
public val indices: ShapeIndexer get() = DefaultStrides(shape)
}
internal fun requireIndexInShape(index: IntArray, shape: Shape) {
if (index.size != shape.size) throw IndexOutOfShapeException(index, shape)
shape.forEachIndexed { axis, axisShape ->
if (index[axis] !in 0 until axisShape) throw IndexOutOfShapeException(index, shape)
}
}

View File

@ -10,7 +10,7 @@ import kotlin.native.concurrent.ThreadLocal
/** /**
* A converter from linear index to multivariate index * A converter from linear index to multivariate index
*/ */
public interface ShapeIndexer: Iterable<IntArray>{ public interface ShapeIndexer : Iterable<IntArray> {
public val shape: Shape public val shape: Shape
/** /**
@ -44,7 +44,7 @@ public interface ShapeIndexer: Iterable<IntArray>{
/** /**
* Linear transformation of indexes * Linear transformation of indexes
*/ */
public abstract class Strides: ShapeIndexer { public abstract class Strides : ShapeIndexer {
/** /**
* Array strides * Array strides
*/ */
@ -66,7 +66,7 @@ public abstract class Strides: ShapeIndexer {
/** /**
* Simple implementation of [Strides]. * Simple implementation of [Strides].
*/ */
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() { public class DefaultStrides(override val shape: IntArray) : Strides() {
override val linearSize: Int get() = strides[shape.size] override val linearSize: Int get() = strides[shape.size]
/** /**
@ -112,10 +112,16 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
/** /**
* Cached builder for default strides * Cached builder for default strides
*/ */
@Deprecated("Replace by Strides(shape)")
public operator fun invoke(shape: IntArray): Strides = public operator fun invoke(shape: IntArray): Strides =
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
} }
} }
@ThreadLocal @ThreadLocal
private val defaultStridesCache = HashMap<IntArray, Strides>() private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
public fun Strides(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }

View File

@ -138,6 +138,10 @@ private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>)
override fun equals(other: Any?): Boolean = false override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0 override fun hashCode(): Int = 0
override fun toString(): String {
return StructureND.toString(structure)
}
} }
/** /**

View File

@ -101,8 +101,8 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
val bufferRepr: String = when (structure.shape.size) { val bufferRepr: String = when (structure.shape.size) {
1 -> (0 until structure.shape[0]).map { structure[it] } 1 -> (0 until structure.shape[0]).map { structure[it] }
.joinToString(prefix = "[", postfix = "]", separator = ", ") .joinToString(prefix = "[", postfix = "]", separator = ", ")
2 -> (0 until structure.shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i -> 2 -> (0 until structure.shape[0]).joinToString(prefix = "[\n", postfix = "\n]", separator = ",\n") { i ->
(0 until structure.shape[1]).joinToString(prefix = "[", postfix = "]", separator = ", ") { j -> (0 until structure.shape[1]).joinToString(prefix = " [", postfix = "]", separator = ", ") { j ->
structure[i, j].toString() structure[i, j].toString()
} }
} }

View File

@ -0,0 +1,30 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
public open class VirtualStructureND<T>(
override val shape: Shape,
public val producer: (IntArray) -> T,
) : StructureND<T> {
override fun get(index: IntArray): T {
requireIndexInShape(index, shape)
return producer(index)
}
}
@UnstableKMathAPI
public class VirtualDoubleStructureND(
shape: Shape,
producer: (IntArray) -> Double,
) : VirtualStructureND<Double>(shape, producer)
@UnstableKMathAPI
public class VirtualIntStructureND(
shape: Shape,
producer: (IntArray) -> Int,
) : VirtualStructureND<Int>(shape, producer)

View File

@ -0,0 +1,32 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.nd
public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" }
return VirtualStructureND(shape) { index ->
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
if (indexAxis == axis) {
(index[indexAxis] + step).mod(shape[indexAxis])
} else {
index[indexAxis]
}
}
get(newIndex)
}
}
public fun <T> StructureND<T>.roll(pair: Pair<Int, Int>, vararg others: Pair<Int, Int>): StructureND<T> {
val axisMap: Map<Int, Int> = mapOf(pair, *others)
require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" }
return VirtualStructureND(shape) { index ->
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
val offset = axisMap[indexAxis] ?: 0
(index[indexAxis] + offset).mod(shape[indexAxis])
}
get(newIndex)
}
}

View File

@ -5,7 +5,6 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
@ -53,7 +52,7 @@ public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
*/ */
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline( private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
buffer: Buffer<T>, buffer: Buffer<T>,
crossinline block: A.(T) -> T crossinline block: A.(T) -> T,
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) } ): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
/** /**
@ -61,7 +60,7 @@ private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
*/ */
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline( private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
buffer: Buffer<T>, buffer: Buffer<T>,
crossinline block: A.(index: Int, arg: T) -> T crossinline block: A.(index: Int, arg: T) -> T,
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) } ): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
/** /**
@ -70,7 +69,7 @@ private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline( private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
l: Buffer<T>, l: Buffer<T>,
r: Buffer<T>, r: Buffer<T>,
crossinline block: A.(l: T, r: T) -> T crossinline block: A.(l: T, r: T) -> T,
): Buffer<T> { ): Buffer<T> {
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" } require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) } return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
@ -127,13 +126,13 @@ public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buff
mapInline(arg) { atanh(it) } mapInline(arg) { atanh(it) }
public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> = public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> =
mapInline(arg) {it.pow(pow) } mapInline(arg) { it.pow(pow) }
public open class BufferRingOps<T, A: Ring<T>>( public open class BufferRingOps<T, A : Ring<T>>(
override val elementAlgebra: A, override val elementAlgebra: A,
override val bufferFactory: BufferFactory<T>, override val bufferFactory: BufferFactory<T>,
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{ ) : BufferAlgebra<T, A>, RingOps<Buffer<T>> {
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r } override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r } override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
@ -152,10 +151,11 @@ public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
public open class BufferFieldOps<T, A : Field<T>>( public open class BufferFieldOps<T, A : Field<T>>(
elementAlgebra: A, elementAlgebra: A,
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> { ) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>,
ScaleOperations<Buffer<T>> {
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r } // override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r } // override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r } override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) } override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
@ -168,7 +168,7 @@ public open class BufferFieldOps<T, A : Field<T>>(
public class BufferField<T, A : Field<T>>( public class BufferField<T, A : Field<T>>(
elementAlgebra: A, elementAlgebra: A,
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,
override val size: Int override val size: Int,
) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize { ) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero } override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }

View File

@ -105,6 +105,16 @@ public interface Buffer<out T> {
*/ */
public val Buffer<*>.indices: IntRange get() = 0 until size public val Buffer<*>.indices: IntRange get() = 0 until size
public fun <T> Buffer<T>.first(): T {
require(size > 0) { "Can't get the first element of empty buffer" }
return get(0)
}
public fun <T> Buffer<T>.last(): T {
require(size > 0) { "Can't get the last element of empty buffer" }
return get(size - 1)
}
/** /**
* Immutable wrapper for [MutableBuffer]. * Immutable wrapper for [MutableBuffer].
* *

View File

@ -6,14 +6,13 @@
package space.kscience.kmath.misc package space.kscience.kmath.misc
import space.kscience.kmath.misc.PermSortTest.Platform.* import space.kscience.kmath.misc.PermSortTest.Platform.*
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import space.kscience.kmath.structures.IntBuffer import space.kscience.kmath.structures.IntBuffer
import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.asBuffer
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertContentEquals import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class PermSortTest { class PermSortTest {
@ -29,9 +28,9 @@ class PermSortTest {
@Test @Test
fun testOnEmptyBuffer() { fun testOnEmptyBuffer() {
val emptyBuffer = IntBuffer(0) {it} val emptyBuffer = IntBuffer(0) {it}
var permutations = emptyBuffer.permSort() var permutations = emptyBuffer.indicesSorted()
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result") assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
permutations = emptyBuffer.permSortDescending() permutations = emptyBuffer.indicesSortedDescending()
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result") assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
} }
@ -47,25 +46,25 @@ class PermSortTest {
@Test @Test
fun testPermSortBy() { fun testPermSortBy() {
val permutations = platforms.permSortBy { it.name } val permutations = platforms.indicesSortedBy { it.name }
val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM) val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM)
assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name") assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name")
} }
@Test @Test
fun testPermSortByDescending() { fun testPermSortByDescending() {
val permutations = platforms.permSortByDescending { it.name } val permutations = platforms.indicesSortedByDescending { it.name }
val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID) val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID)
assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name") assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name")
} }
@Test @Test
fun testPermSortWith() { fun testPermSortWith() {
var permutations = platforms.permSortWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) } var permutations = platforms.indicesSortedWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) }
val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID) val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID)
assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator") assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator")
permutations = platforms.permSortWith(compareByDescending { it.name.length }) permutations = platforms.indicesSortedWith(compareByDescending { it.name.length })
assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator") assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator")
} }
@ -75,7 +74,7 @@ class PermSortTest {
println("Test randomization seed: $seed") println("Test randomization seed: $seed")
val buffer = Random(seed).buffer(bufferSize) val buffer = Random(seed).buffer(bufferSize)
val indices = buffer.permSort() val indices = buffer.indicesSorted()
assertEquals(bufferSize, indices.size) assertEquals(bufferSize, indices.size)
// Ensure no doublon is present in indices // Ensure no doublon is present in indices
@ -87,7 +86,7 @@ class PermSortTest {
assertTrue(current <= next, "Permutation indices not properly sorted") assertTrue(current <= next, "Permutation indices not properly sorted")
} }
val descIndices = buffer.permSortDescending() val descIndices = buffer.indicesSortedDescending()
assertEquals(bufferSize, descIndices.size) assertEquals(bufferSize, descIndices.size)
// Ensure no doublon is present in indices // Ensure no doublon is present in indices
assertEquals(descIndices.toSet().size, descIndices.size) assertEquals(descIndices.toSet().size, descIndices.size)

View File

@ -0,0 +1,28 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.nd
import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test
import kotlin.test.assertEquals
class NdOperationsTest {
@Test
fun roll() {
val structure = DoubleField.ndAlgebra.structureND(5, 5) { index ->
index.sumOf { it.toDouble() }
}
println(StructureND.toString(structure))
val rolled = structure.roll(0,-1)
println(StructureND.toString(rolled))
assertEquals(4.0, rolled[0, 0])
}
}

View File

@ -0,0 +1,32 @@
# Module kmath-coroutines
## Usage
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-coroutines:0.3.0`.
**Gradle Groovy:**
```groovy
repositories {
maven { url 'https://repo.kotlin.link' }
mavenCentral()
}
dependencies {
implementation 'space.kscience:kmath-coroutines:0.3.0'
}
```
**Gradle Kotlin DSL:**
```kotlin
repositories {
maven("https://repo.kotlin.link")
mavenCentral()
}
dependencies {
implementation("space.kscience:kmath-coroutines:0.3.0")
}
```

View File

@ -18,12 +18,12 @@ public class LazyStructureND<out T>(
) : StructureND<T> { ) : StructureND<T> {
private val cache: MutableMap<IntArray, Deferred<T>> = HashMap() private val cache: MutableMap<IntArray, Deferred<T>> = HashMap()
public fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) { public fun async(index: IntArray): Deferred<T> = cache.getOrPut(index) {
scope.async(context = Dispatchers.Math) { function(index) } scope.async(context = Dispatchers.Math) { function(index) }
} }
public suspend fun await(index: IntArray): T = deferred(index).await() public suspend fun await(index: IntArray): T = async(index).await()
override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() } override operator fun get(index: IntArray): T = runBlocking { async(index).await() }
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override fun elements(): Sequence<Pair<IntArray, T>> { override fun elements(): Sequence<Pair<IntArray, T>> {
@ -33,8 +33,8 @@ public class LazyStructureND<out T>(
} }
} }
public fun <T> StructureND<T>.deferred(index: IntArray): Deferred<T> = public fun <T> StructureND<T>.async(index: IntArray): Deferred<T> =
if (this is LazyStructureND<T>) deferred(index) else CompletableDeferred(get(index)) if (this is LazyStructureND<T>) this@async.async(index) else CompletableDeferred(get(index))
public suspend fun <T> StructureND<T>.await(index: IntArray): T = public suspend fun <T> StructureND<T>.await(index: IntArray): T =
if (this is LazyStructureND<T>) await(index) else get(index) if (this is LazyStructureND<T>) await(index) else get(index)

View File

@ -0,0 +1,32 @@
# Module kmath-dimensions
A proof of concept module for adding type-safe dimensions to structures
## Usage
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-dimensions:0.3.0`.
**Gradle Groovy:**
```groovy
repositories {
maven { url 'https://repo.kotlin.link' }
mavenCentral()
}
dependencies {
implementation 'space.kscience:kmath-dimensions:0.3.0'
}
```
**Gradle Kotlin DSL:**
```kotlin
repositories {
maven("https://repo.kotlin.link")
mavenCentral()
}
dependencies {
implementation("space.kscience:kmath-dimensions:0.3.0")
}
```

View File

@ -9,17 +9,17 @@ EJML based linear algebra implementation.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-17`. The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0`.
**Gradle:** **Gradle Groovy:**
```gradle ```groovy
repositories { repositories {
maven { url 'https://repo.kotlin.link' } maven { url 'https://repo.kotlin.link' }
mavenCentral() mavenCentral()
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-ejml:0.3.0-dev-17' implementation 'space.kscience:kmath-ejml:0.3.0'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -30,6 +30,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-ejml:0.3.0-dev-17") implementation("space.kscience:kmath-ejml:0.3.0")
} }
``` ```

View File

@ -271,7 +271,9 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, DoubleField, DMatrix
} }
else -> null else -> null
}?.let(type::cast) }?.let{
type.cast(it)
}
} }
/** /**
@ -505,7 +507,9 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, FloatField, FMatrixRM
} }
else -> null else -> null
}?.let(type::cast) }?.let{
type.cast(it)
}
} }
/** /**
@ -734,7 +738,9 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, DoubleField, DMatrix
} }
else -> null else -> null
}?.let(type::cast) }?.let{
type.cast(it)
}
} }
/** /**
@ -963,7 +969,9 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace<Float, FloatField, FMatrixSp
} }
else -> null else -> null
}?.let(type::cast) }?.let{
type.cast(it)
}
} }
/** /**

View File

@ -3,6 +3,8 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:OptIn(PerformancePitfall::class)
package space.kscience.kmath.ejml package space.kscience.kmath.ejml
import org.ejml.data.DMatrixRMaj import org.ejml.data.DMatrixRMaj
@ -18,11 +20,11 @@ import kotlin.random.Random
import kotlin.random.asJavaRandom import kotlin.random.asJavaRandom
import kotlin.test.* import kotlin.test.*
@OptIn(PerformancePitfall::class) internal fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
assertTrue { StructureND.contentEquals(expected, actual) } assertTrue { StructureND.contentEquals(expected, actual) }
} }
@OptIn(UnstableKMathAPI::class)
internal class EjmlMatrixTest { internal class EjmlMatrixTest {
private val random = Random(0) private val random = Random(0)

View File

@ -9,17 +9,17 @@ Specialization of KMath APIs for Double numbers.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-17`. The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0`.
**Gradle:** **Gradle Groovy:**
```gradle ```groovy
repositories { repositories {
maven { url 'https://repo.kotlin.link' } maven { url 'https://repo.kotlin.link' }
mavenCentral() mavenCentral()
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-for-real:0.3.0-dev-17' implementation 'space.kscience:kmath-for-real:0.3.0'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -30,6 +30,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-for-real:0.3.0-dev-17") implementation("space.kscience:kmath-for-real:0.3.0")
} }
``` ```

View File

@ -21,19 +21,22 @@ readme {
feature( feature(
id = "DoubleVector", id = "DoubleVector",
description = "Numpy-like operations for Buffers/Points",
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt"
) ){
"Numpy-like operations for Buffers/Points"
}
feature( feature(
id = "DoubleMatrix", id = "DoubleMatrix",
description = "Numpy-like operations for 2d real structures",
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt"
) ){
"Numpy-like operations for 2d real structures"
}
feature( feature(
id = "grids", id = "grids",
description = "Uniform grid generators",
ref = "src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt"
) ){
"Uniform grid generators"
}
} }

View File

@ -11,17 +11,17 @@ Functions and interpolations.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-17`. The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0`.
**Gradle:** **Gradle Groovy:**
```gradle ```groovy
repositories { repositories {
maven { url 'https://repo.kotlin.link' } maven { url 'https://repo.kotlin.link' }
mavenCentral() mavenCentral()
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-functions:0.3.0-dev-17' implementation 'space.kscience:kmath-functions:0.3.0'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -32,6 +32,6 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-functions:0.3.0-dev-17") implementation("space.kscience:kmath-functions:0.3.0")
} }
``` ```

View File

@ -7,6 +7,6 @@ package space.kscience.kmath.functions
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
public typealias UnivariateFunction<T> = (T) -> T public typealias Function1D<T> = (T) -> T
public typealias MultivariateFunction<T> = (Buffer<T>) -> T public typealias FunctionND<T> = (Buffer<T>) -> T

View File

@ -28,6 +28,8 @@ public fun <T : Comparable<T>> PiecewisePolynomial<T>.integrate(algebra: Field<T
/** /**
* Compute definite integral of given [PiecewisePolynomial] piece by piece in a given [range] * Compute definite integral of given [PiecewisePolynomial] piece by piece in a given [range]
* Requires [UnivariateIntegrationNodes] or [IntegrationRange] and [IntegrandMaxCalls] * Requires [UnivariateIntegrationNodes] or [IntegrationRange] and [IntegrandMaxCalls]
*
* TODO use context receiver for algebra
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun <T : Comparable<T>> PiecewisePolynomial<T>.integrate( public fun <T : Comparable<T>> PiecewisePolynomial<T>.integrate(
@ -98,6 +100,7 @@ public object DoubleSplineIntegrator : UnivariateIntegrator<Double> {
} }
} }
@Suppress("unused")
@UnstableKMathAPI @UnstableKMathAPI
public inline val DoubleField.splineIntegrator: UnivariateIntegrator<Double> public inline val DoubleField.splineIntegrator: UnivariateIntegrator<Double>
get() = DoubleSplineIntegrator get() = DoubleSplineIntegrator

View File

@ -9,6 +9,7 @@ package space.kscience.kmath.interpolation
import space.kscience.kmath.data.XYColumnarData import space.kscience.kmath.data.XYColumnarData
import space.kscience.kmath.functions.PiecewisePolynomial import space.kscience.kmath.functions.PiecewisePolynomial
import space.kscience.kmath.functions.asFunction
import space.kscience.kmath.functions.value import space.kscience.kmath.functions.value
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
@ -59,3 +60,33 @@ public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
val pointSet = XYColumnarData.of(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer()) val pointSet = XYColumnarData.of(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
return interpolatePolynomials(pointSet) return interpolatePolynomials(pointSet)
} }
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
x: Buffer<T>,
y: Buffer<T>,
): (T) -> T? = interpolatePolynomials(x, y).asFunction(algebra)
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
data: Map<T, T>,
): (T) -> T? = interpolatePolynomials(data).asFunction(algebra)
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
data: List<Pair<T, T>>,
): (T) -> T? = interpolatePolynomials(data).asFunction(algebra)
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
x: Buffer<T>,
y: Buffer<T>,
defaultValue: T,
): (T) -> T = interpolatePolynomials(x, y).asFunction(algebra, defaultValue)
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
data: Map<T, T>,
defaultValue: T,
): (T) -> T = interpolatePolynomials(data).asFunction(algebra, defaultValue)
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolate(
data: List<Pair<T, T>>,
defaultValue: T,
): (T) -> T = interpolatePolynomials(data).asFunction(algebra, defaultValue)

View File

@ -22,6 +22,7 @@ internal fun <T : Comparable<T>> insureSorted(points: XYColumnarData<*, T, *>) {
* Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java * Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java
*/ */
public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> { public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
override fun interpolatePolynomials(points: XYColumnarData<T, T, T>): PiecewisePolynomial<T> = algebra { override fun interpolatePolynomials(points: XYColumnarData<T, T, T>): PiecewisePolynomial<T> = algebra {
require(points.size > 0) { "Point array should not be empty" } require(points.size > 0) { "Point array should not be empty" }
@ -37,3 +38,6 @@ public class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T
} }
} }
} }
public val <T : Comparable<T>> Field<T>.linearInterpolator: LinearInterpolator<T>
get() = LinearInterpolator(this)

View File

@ -63,8 +63,8 @@ public class SplineInterpolator<T : Comparable<T>>(
//Shift coefficients to represent absolute polynomial instead of one with an offset //Shift coefficients to represent absolute polynomial instead of one with an offset
val polynomial = Polynomial( val polynomial = Polynomial(
a - b * x0 + c * x02 - d * x03, a - b * x0 + c * x02 - d * x03,
b - 2*c*x0 + 3*d*x02, b - 2 * c * x0 + 3 * d * x02,
c - 3*d*x0, c - 3 * d * x0,
d d
) )
cOld = c cOld = c
@ -72,8 +72,12 @@ public class SplineInterpolator<T : Comparable<T>>(
} }
} }
} }
public companion object {
public val double: SplineInterpolator<Double> = SplineInterpolator(DoubleField, ::DoubleBuffer)
}
} }
public fun <T : Comparable<T>> Field<T>.splineInterpolator(
bufferFactory: MutableBufferFactory<T>,
): SplineInterpolator<T> = SplineInterpolator(this, bufferFactory)
public val DoubleField.splineInterpolator: SplineInterpolator<Double>
get() = SplineInterpolator(this, ::DoubleBuffer)

View File

@ -5,8 +5,6 @@
package space.kscience.kmath.interpolation package space.kscience.kmath.interpolation
import space.kscience.kmath.functions.PiecewisePolynomial
import space.kscience.kmath.functions.asFunction
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -21,8 +19,8 @@ internal class LinearInterpolatorTest {
3.0 to 4.0 3.0 to 4.0
) )
val polynomial: PiecewisePolynomial<Double> = LinearInterpolator(DoubleField).interpolatePolynomials(data) //val polynomial: PiecewisePolynomial<Double> = DoubleField.linearInterpolator.interpolatePolynomials(data)
val function = polynomial.asFunction(DoubleField) val function = DoubleField.linearInterpolator.interpolate(data)
assertEquals(null, function(-1.0)) assertEquals(null, function(-1.0))
assertEquals(0.5, function(0.5)) assertEquals(0.5, function(0.5))
assertEquals(2.0, function(1.5)) assertEquals(2.0, function(1.5))

View File

@ -5,8 +5,6 @@
package space.kscience.kmath.interpolation package space.kscience.kmath.interpolation
import space.kscience.kmath.functions.PiecewisePolynomial
import space.kscience.kmath.functions.asFunction
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import kotlin.math.PI import kotlin.math.PI
import kotlin.math.sin import kotlin.math.sin
@ -21,9 +19,10 @@ internal class SplineInterpolatorTest {
x to sin(x) x to sin(x)
} }
val polynomial: PiecewisePolynomial<Double> = SplineInterpolator.double.interpolatePolynomials(data) //val polynomial: PiecewisePolynomial<Double> = DoubleField.splineInterpolator.interpolatePolynomials(data)
val function = DoubleField.splineInterpolator.interpolate(data, Double.NaN)
val function = polynomial.asFunction(DoubleField, Double.NaN)
assertEquals(Double.NaN, function(-1.0)) assertEquals(Double.NaN, function(-1.0))
assertEquals(sin(0.5), function(0.5), 0.1) assertEquals(sin(0.5), function(0.5), 0.1)
assertEquals(sin(1.5), function(1.5), 0.1) assertEquals(sin(1.5), function(1.5), 0.1)

32
kmath-geometry/README.md Normal file
View File

@ -0,0 +1,32 @@
# Module kmath-geometry
## Usage
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-geometry:0.3.0`.
**Gradle Groovy:**
```groovy
repositories {
maven { url 'https://repo.kotlin.link' }
mavenCentral()
}
dependencies {
implementation 'space.kscience:kmath-geometry:0.3.0'
}
```
**Gradle Kotlin DSL:**
```kotlin
repositories {
maven("https://repo.kotlin.link")
mavenCentral()
}
dependencies {
implementation("space.kscience:kmath-geometry:0.3.0")
}
```

View File

@ -6,7 +6,7 @@ plugins {
kotlin.sourceSets.commonMain { kotlin.sourceSets.commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(projects.kmath.kmathComplex)
} }
} }

View File

@ -6,12 +6,10 @@
package space.kscience.kmath.geometry package space.kscience.kmath.geometry
import space.kscience.kmath.linear.Point import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.math.sqrt import kotlin.math.sqrt
@OptIn(UnstableKMathAPI::class)
public interface Vector2D : Point<Double>, Vector { public interface Vector2D : Point<Double>, Vector {
public val x: Double public val x: Double
public val y: Double public val y: Double
@ -29,7 +27,6 @@ public interface Vector2D : Point<Double>, Vector {
public val Vector2D.r: Double public val Vector2D.r: Double
get() = Euclidean2DSpace { norm() } get() = Euclidean2DSpace { norm() }
@Suppress("FunctionName")
public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y) public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)
private data class Vector2DImpl( private data class Vector2DImpl(

View File

@ -6,12 +6,11 @@
package space.kscience.kmath.geometry package space.kscience.kmath.geometry
import space.kscience.kmath.linear.Point import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer
import kotlin.math.sqrt import kotlin.math.sqrt
@OptIn(UnstableKMathAPI::class)
public interface Vector3D : Point<Double>, Vector { public interface Vector3D : Point<Double>, Vector {
public val x: Double public val x: Double
public val y: Double public val y: Double
@ -31,6 +30,19 @@ public interface Vector3D : Point<Double>, Vector {
@Suppress("FunctionName") @Suppress("FunctionName")
public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z) public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
public fun Buffer<Double>.asVector3D(): Vector3D = object : Vector3D {
init {
require(this@asVector3D.size == 3) { "Buffer of size 3 is required for Vector3D" }
}
override val x: Double get() = this@asVector3D[0]
override val y: Double get() = this@asVector3D[1]
override val z: Double get() = this@asVector3D[2]
override fun toString(): String = this@asVector3D.toString()
}
public val Vector3D.r: Double get() = Euclidean3DSpace { norm() } public val Vector3D.r: Double get() = Euclidean3DSpace { norm() }
private data class Vector3DImpl( private data class Vector3DImpl(

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.geometry package space.kscience.kmath.geometry
//TODO move vector to receiver
/** /**
* Project vector onto a line. * Project vector onto a line.
* @param vector to project * @param vector to project

View File

@ -0,0 +1,116 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.geometry
import space.kscience.kmath.complex.Quaternion
import space.kscience.kmath.complex.QuaternionField
import space.kscience.kmath.complex.reciprocal
import space.kscience.kmath.linear.LinearSpace
import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.linear.matrix
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import kotlin.math.pow
import kotlin.math.sqrt
internal fun Vector3D.toQuaternion(): Quaternion = Quaternion(0.0, x, y, z)
/**
* Angle in radians denoted by this quaternion rotation
*/
public val Quaternion.theta: Double get() = kotlin.math.acos(w) * 2
/**
* An axis of quaternion rotation
*/
public val Quaternion.vector: Vector3D
get() {
val sint2 = sqrt(1 - w * w)
return object : Vector3D {
override val x: Double get() = this@vector.x / sint2
override val y: Double get() = this@vector.y / sint2
override val z: Double get() = this@vector.z / sint2
override fun toString(): String = listOf(x, y, z).toString()
}
}
/**
* Rotate a vector in a [Euclidean3DSpace]
*/
public fun Euclidean3DSpace.rotate(vector: Vector3D, q: Quaternion): Vector3D = with(QuaternionField) {
val p = vector.toQuaternion()
(q * p * q.reciprocal).vector
}
/**
* Use a composition of quaternions to create a rotation
*/
public fun Euclidean3DSpace.rotate(vector: Vector3D, composition: QuaternionField.() -> Quaternion): Vector3D =
rotate(vector, QuaternionField.composition())
public fun Euclidean3DSpace.rotate(vector: Vector3D, matrix: Matrix<Double>): Vector3D {
require(matrix.colNum == 3 && matrix.rowNum == 3) { "Square 3x3 rotation matrix is required" }
return with(DoubleField.linearSpace) { matrix.dot(vector).asVector3D() }
}
/**
* Convert a [Quaternion] to a rotation matrix
*/
@OptIn(UnstableKMathAPI::class)
public fun Quaternion.toRotationMatrix(
linearSpace: LinearSpace<Double, *> = DoubleField.linearSpace,
): Matrix<Double> {
val s = QuaternionField.norm(this).pow(-2)
return linearSpace.matrix(3, 3)(
1.0 - 2 * s * (y * y + z * z), 2 * s * (x * y - z * w), 2 * s * (x * z + y * w),
2 * s * (x * y + z * w), 1.0 - 2 * s * (x * x + z * z), 2 * s * (y * z - x * w),
2 * s * (x * z - y * w), 2 * s * (y * z + x * w), 1.0 - 2 * s * (x * x + y * y)
)
}
/**
* taken from https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/
*/
public fun Quaternion.Companion.fromRotationMatrix(matrix: Matrix<Double>): Quaternion {
require(matrix.colNum == 3 && matrix.rowNum == 3) { "Rotation matrix should be 3x3 but is ${matrix.rowNum}x${matrix.colNum}" }
val trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
return if (trace > 0) {
val s = sqrt(trace + 1.0) * 2 // S=4*qw
Quaternion(
w = 0.25 * s,
x = (matrix[2, 1] - matrix[1, 2]) / s,
y = (matrix[0, 2] - matrix[2, 0]) / s,
z = (matrix[1, 0] - matrix[0, 1]) / s,
)
} else if ((matrix[0, 0] > matrix[1, 1]) && (matrix[0, 0] > matrix[2, 2])) {
val s = sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 // S=4*qx
Quaternion(
w = (matrix[2, 1] - matrix[1, 2]) / s,
x = 0.25 * s,
y = (matrix[0, 1] + matrix[1, 0]) / s,
z = (matrix[0, 2] + matrix[2, 0]) / s,
)
} else if (matrix[1, 1] > matrix[2, 2]) {
val s = sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 // S=4*qy
Quaternion(
w = (matrix[0, 2] - matrix[2, 0]) / s,
x = (matrix[0, 1] + matrix[1, 0]) / s,
y = 0.25 * s,
z = (matrix[1, 2] + matrix[2, 1]) / s,
)
} else {
val s = sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 // S=4*qz
Quaternion(
w = (matrix[1, 0] - matrix[0, 1]) / s,
x = (matrix[0, 2] + matrix[2, 0]) / s,
y = (matrix[1, 2] + matrix[2, 1]) / s,
z = 0.25 * s,
)
}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.geometry
import space.kscience.kmath.complex.Quaternion
import space.kscience.kmath.complex.normalized
import space.kscience.kmath.testutils.assertBufferEquals
import kotlin.test.Test
class RotationTest {
@Test
fun rotations() = with(Euclidean3DSpace) {
val vector = Vector3D(1.0, 1.0, 1.0)
val q = Quaternion(1.0, 2.0, -3.0, 4.0).normalized()
val rotatedByQ = rotate(vector, q)
val matrix = q.toRotationMatrix()
val rotatedByM = rotate(vector,matrix)
assertBufferEquals(rotatedByQ, rotatedByM, 1e-4)
}
@Test
fun rotationConversion() {
val q = Quaternion(1.0, 2.0, -3.0, 4.0).normalized()
val matrix = q.toRotationMatrix()
assertBufferEquals(q, Quaternion.fromRotationMatrix(matrix))
}
}

View File

@ -0,0 +1,32 @@
# Module kmath-histograms
## Usage
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-histograms:0.3.0`.
**Gradle Groovy:**
```groovy
repositories {
maven { url 'https://repo.kotlin.link' }
mavenCentral()
}
dependencies {
implementation 'space.kscience:kmath-histograms:0.3.0'
}
```
**Gradle Kotlin DSL:**
```kotlin
repositories {
maven("https://repo.kotlin.link")
mavenCentral()
}
dependencies {
implementation("space.kscience:kmath-histograms:0.3.0")
}
```

View File

@ -1,22 +1,22 @@
plugins { plugins {
kotlin("multiplatform") id("ru.mipt.npm.gradle.mpp")
id("ru.mipt.npm.gradle.common")
id("ru.mipt.npm.gradle.native") id("ru.mipt.npm.gradle.native")
} }
kscience { //apply(plugin = "kotlinx-atomicfu")
useAtomic()
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
api(npmlibs.atomicfu)
} }
} }
commonTest { commonTest {
dependencies { dependencies {
implementation(project(":kmath-for-real")) implementation(project(":kmath-for-real"))
implementation(projects.kmath.kmathStat)
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.0")
} }
} }
} }

View File

@ -18,7 +18,8 @@ public interface Counter<T : Any> {
public val value: T public val value: T
public companion object { public companion object {
public fun double(): ObjectCounter<Double> = ObjectCounter(DoubleField) public fun ofDouble(): ObjectCounter<Double> = ObjectCounter(DoubleField)
public fun <T: Any> of(group: Group<T>): ObjectCounter<T> = ObjectCounter(group)
} }
} }

View File

@ -1,130 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain
import space.kscience.kmath.domains.HyperSquareDomain
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.*
import kotlin.math.floor
public class DoubleHistogramSpace(
private val lower: Buffer<Double>,
private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 },
) : IndexedHistogramSpace<Double, Double> {
init {
// argument checks
require(lower.size == upper.size) { "Dimension mismatch in histogram lower and upper limits." }
require(lower.size == binNums.size) { "Dimension mismatch in bin count." }
require(!lower.indices.any { upper[it] - lower[it] < 0 }) { "Range for one of axis is not strictly positive" }
}
public val dimension: Int get() = lower.size
override val shape: IntArray = IntArray(binNums.size) { binNums[it] + 2 }
override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape)
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
/**
* Get internal [StructureND] bin index for given axis
*/
private fun getIndex(axis: Int, value: Double): Int = when {
value >= upper[axis] -> binNums[axis] + 1 // overflow
value < lower[axis] -> 0 // underflow
else -> floor((value - lower[axis]) / binSize[axis]).toInt()
}
override fun getIndex(point: Buffer<Double>): IntArray = IntArray(dimension) {
getIndex(it, point[it])
}
@OptIn(UnstableKMathAPI::class)
override fun getDomain(index: IntArray): Domain<Double> {
val lowerBoundary = index.mapIndexed { axis, i ->
when (i) {
0 -> Double.NEGATIVE_INFINITY
shape[axis] - 1 -> upper[axis]
else -> lower[axis] + (i.toDouble()) * binSize[axis]
}
}.asBuffer()
val upperBoundary = index.mapIndexed { axis, i ->
when (i) {
0 -> lower[axis]
shape[axis] - 1 -> Double.POSITIVE_INFINITY
else -> lower[axis] + (i.toDouble() + 1) * binSize[axis]
}
}.asBuffer()
return HyperSquareDomain(lowerBoundary, upperBoundary)
}
override fun produceBin(index: IntArray, value: Double): Bin<Double> {
val domain = getDomain(index)
return DomainBin(domain, value)
}
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
val ndCounter = StructureND.auto(shape) { Counter.double() }
val hBuilder = HistogramBuilder<Double> { point, value ->
val index = getIndex(point)
ndCounter[index].add(value.toDouble())
}
hBuilder.apply(builder)
val values: BufferND<Double> = ndCounter.mapToBuffer { it.value }
return IndexedHistogram(this, values)
}
override fun IndexedHistogram<Double, Double>.unaryMinus(): IndexedHistogram<Double, Double> = this * (-1)
public companion object {
/**
* Use it like
* ```
*FastHistogram.fromRanges(
* (-1.0..1.0),
* (-1.0..1.0)
*)
*```
*/
public fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): DoubleHistogramSpace = DoubleHistogramSpace(
ranges.map(ClosedFloatingPointRange<Double>::start).asBuffer(),
ranges.map(ClosedFloatingPointRange<Double>::endInclusive).asBuffer()
)
/**
* Use it like
* ```
*FastHistogram.fromRanges(
* (-1.0..1.0) to 50,
* (-1.0..1.0) to 32
*)
*```
*/
public fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): DoubleHistogramSpace =
DoubleHistogramSpace(
ListBuffer(
ranges
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
.map(ClosedFloatingPointRange<Double>::start)
),
ListBuffer(
ranges
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
.map(ClosedFloatingPointRange<Double>::endInclusive)
),
ranges.map(Pair<ClosedFloatingPointRange<Double>, Int>::second).toIntArray()
)
}
}

View File

@ -13,14 +13,23 @@ import space.kscience.kmath.structures.asBuffer
/** /**
* The binned data element. Could be a histogram bin with a number of counts or an artificial construct. * The binned data element. Could be a histogram bin with a number of counts or an artificial construct.
*/ */
public interface Bin<in T : Any> : Domain<T> { public interface Bin<in T : Any, out V> : Domain<T> {
/** /**
* The value of this bin. * The value of this bin.
*/ */
public val value: Number public val binValue: V
} }
public interface Histogram<in T : Any, out B : Bin<T>> { /**
* A simple histogram bin based on domain
*/
public data class DomainBin<in T : Comparable<T>, D : Domain<T>, out V>(
public val domain: D,
override val binValue: V,
) : Bin<T, V>, Domain<T> by domain
public interface Histogram<in T : Any, out V, out B : Bin<T, V>> {
/** /**
* Find existing bin, corresponding to given coordinates * Find existing bin, corresponding to given coordinates
*/ */
@ -32,29 +41,38 @@ public interface Histogram<in T : Any, out B : Bin<T>> {
public val dimension: Int public val dimension: Int
public val bins: Iterable<B> public val bins: Iterable<B>
public companion object {
//A discoverability root
}
} }
public fun interface HistogramBuilder<in T : Any> { public interface HistogramBuilder<in T : Any, V : Any> {
/** /**
* Increment appropriate bin * The default value increment for a bin
*/ */
public fun putValue(point: Point<out T>, value: Number) public val defaultValue: V
/**
* Increment appropriate bin with given value
*/
public fun putValue(point: Point<out T>, value: V = defaultValue)
} }
public fun <T : Any, B : Bin<T>> HistogramBuilder<T>.put(point: Point<out T>): Unit = putValue(point, 1.0) public fun <T : Any> HistogramBuilder<T, *>.put(point: Point<out T>): Unit = putValue(point)
public fun <T : Any> HistogramBuilder<T>.put(vararg point: T): Unit = put(point.asBuffer()) public fun <T : Any> HistogramBuilder<T, *>.put(vararg point: T): Unit = put(point.asBuffer())
public fun HistogramBuilder<Double>.put(vararg point: Number): Unit = public fun HistogramBuilder<Double, *>.put(vararg point: Number): Unit =
put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
public fun HistogramBuilder<Double>.put(vararg point: Double): Unit = put(DoubleBuffer(point)) public fun HistogramBuilder<Double, *>.put(vararg point: Double): Unit = put(DoubleBuffer(point))
public fun <T : Any> HistogramBuilder<T>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) } public fun <T : Any> HistogramBuilder<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
/** /**
* Pass a sequence builder into histogram * Pass a sequence builder into histogram
*/ */
public fun <T : Any> HistogramBuilder<T>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit = public fun <T : Any> HistogramBuilder<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit =
fill(sequence(block).asIterable()) fill(sequence(block).asIterable())

View File

@ -0,0 +1,64 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain1D
import space.kscience.kmath.domains.center
import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.asSequence
import space.kscience.kmath.structures.Buffer
/**
* A univariate bin based on a range
*
* @property binValue The value of histogram including weighting
*/
@UnstableKMathAPI
public data class Bin1D<T : Comparable<T>, out V>(
public val domain: Domain1D<T>,
override val binValue: V,
) : Bin<T, V>, ClosedRange<T> by domain.range {
override val dimension: Int get() = 1
override fun contains(point: Buffer<T>): Boolean = point.size == 1 && contains(point[0])
}
@OptIn(UnstableKMathAPI::class)
public interface Histogram1D<T : Comparable<T>, V> : Histogram<T, V, Bin1D<T, V>> {
override val dimension: Int get() = 1
public operator fun get(value: T): Bin1D<T, V>?
override operator fun get(point: Buffer<T>): Bin1D<T, V>? = get(point[0])
}
public interface Histogram1DBuilder<in T : Any, V : Any> : HistogramBuilder<T, V> {
/**
* Thread safe put operation
*/
public fun putValue(at: T, value: V = defaultValue)
override fun putValue(point: Point<out T>, value: V) {
require(point.size == 1) { "Only points with single value could be used in Histogram1D" }
putValue(point[0], value)
}
}
@UnstableKMathAPI
public fun Histogram1DBuilder<Double, *>.fill(items: Iterable<Double>): Unit =
items.forEach(this::putValue)
@UnstableKMathAPI
public fun Histogram1DBuilder<Double, *>.fill(array: DoubleArray): Unit =
array.forEach(this::putValue)
@UnstableKMathAPI
public fun <T : Any> Histogram1DBuilder<T, *>.fill(buffer: Buffer<T>): Unit =
buffer.asSequence().forEach(this::putValue)
@OptIn(UnstableKMathAPI::class)
public val Bin1D<Double, *>.center: Double get() = domain.center

View File

@ -0,0 +1,76 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain
import space.kscience.kmath.linear.Point
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.FieldOpsND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke
/**
* @param T the type of the argument space
* @param V the type of bin value
*/
public class HistogramND<T : Comparable<T>, D : Domain<T>, V : Any>(
public val group: HistogramGroupND<T, D, V>,
internal val values: StructureND<V>,
) : Histogram<T, V, DomainBin<T, D, V>> {
override fun get(point: Point<T>): DomainBin<T, D, V>? {
val index = group.getIndexOrNull(point) ?: return null
return group.produceBin(index, values[index])
}
override val dimension: Int get() = group.shape.size
override val bins: Iterable<DomainBin<T, D, V>>
get() = DefaultStrides(group.shape).asSequence().map {
group.produceBin(it, values[it])
}.asIterable()
}
/**
* A space for producing histograms with values in a NDStructure
*/
public interface HistogramGroupND<T : Comparable<T>, D : Domain<T>, V : Any> :
Group<HistogramND<T, D, V>>, ScaleOperations<HistogramND<T, D, V>> {
public val shape: Shape
public val valueAlgebraND: FieldOpsND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
/**
* Resolve index of the bin including given [point]. Return null if point is outside histogram area
*/
public fun getIndexOrNull(point: Point<T>): IntArray?
/**
* Get a bin domain represented by given index
*/
public fun getDomain(index: IntArray): Domain<T>?
public fun produceBin(index: IntArray, value: V): DomainBin<T, D, V>
public fun produce(builder: HistogramBuilder<T, V>.() -> Unit): HistogramND<T, D, V>
override fun add(left: HistogramND<T, D, V>, right: HistogramND<T, D, V>): HistogramND<T, D, V> {
require(left.group == this && right.group == this) {
"A histogram belonging to a different group cannot be operated."
}
return HistogramND(this, valueAlgebraND { left.values + right.values })
}
override fun scale(a: HistogramND<T, D, V>, value: Double): HistogramND<T, D, V> {
require(a.group == this) { "A histogram belonging to a different group cannot be operated." }
return HistogramND(this, valueAlgebraND { a.values * value })
}
override val zero: HistogramND<T, D, V> get() = produce { }
}

View File

@ -1,82 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain
import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.FieldND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke
/**
* A simple histogram bin based on domain
*/
public data class DomainBin<in T : Comparable<T>>(
public val domain: Domain<T>,
override val value: Number,
) : Bin<T>, Domain<T> by domain
@OptIn(UnstableKMathAPI::class)
public class IndexedHistogram<T : Comparable<T>, V : Any>(
public val context: IndexedHistogramSpace<T, V>,
public val values: StructureND<V>,
) : Histogram<T, Bin<T>> {
override fun get(point: Point<T>): Bin<T>? {
val index = context.getIndex(point) ?: return null
return context.produceBin(index, values[index])
}
override val dimension: Int get() = context.shape.size
override val bins: Iterable<Bin<T>>
get() = DefaultStrides(context.shape).asSequence().map {
context.produceBin(it, values[it])
}.asIterable()
}
/**
* A space for producing histograms with values in a NDStructure
*/
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
//public val valueSpace: Space<V>
public val shape: Shape
public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
/**
* Resolve index of the bin including given [point]
*/
public fun getIndex(point: Point<T>): IntArray?
/**
* Get a bin domain represented by given index
*/
public fun getDomain(index: IntArray): Domain<T>?
public fun produceBin(index: IntArray, value: V): Bin<T>
public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V>
override fun add(left: IndexedHistogram<T, V>, right: IndexedHistogram<T, V>): IndexedHistogram<T, V> {
require(left.context == this) { "Can't operate on a histogram produced by external space" }
require(right.context == this) { "Can't operate on a histogram produced by external space" }
return IndexedHistogram(this, histogramValueSpace { left.values + right.values })
}
override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> {
require(a.context == this) { "Can't operate on a histogram produced by external space" }
return IndexedHistogram(this, histogramValueSpace { a.values * value })
}
override val zero: IndexedHistogram<T, V> get() = produce { }
}

View File

@ -0,0 +1,156 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.histogram
import space.kscience.kmath.domains.DoubleDomain1D
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer
import kotlin.math.floor
@OptIn(UnstableKMathAPI::class)
public class UniformHistogram1D<V : Any>(
public val group: UniformHistogram1DGroup<V, *>,
internal val values: Map<Int, V>,
) : Histogram1D<Double, V> {
private val startPoint get() = group.startPoint
private val binSize get() = group.binSize
private fun produceBin(index: Int, value: V): Bin1D<Double, V> {
val domain = DoubleDomain1D((startPoint + index * binSize)..(startPoint + (index + 1) * binSize))
return Bin1D(domain, value)
}
override val bins: Iterable<Bin1D<Double, V>> get() = values.map { produceBin(it.key, it.value) }
override fun get(value: Double): Bin1D<Double, V>? {
val index: Int = group.getIndex(value)
val v = values[index]
return v?.let { produceBin(index, it) }
}
}
/**
* An algebra for uniform histograms in 1D real space
*/
public class UniformHistogram1DGroup<V : Any, A>(
public val valueAlgebra: A,
public val binSize: Double,
public val startPoint: Double = 0.0,
) : Group<Histogram1D<Double, V>>, ScaleOperations<Histogram1D<Double, V>> where A : Ring<V>, A : ScaleOperations<V> {
override val zero: UniformHistogram1D<V> = UniformHistogram1D(this, emptyMap())
/**
* Get index of a bin
*/
@PublishedApi
internal fun getIndex(at: Double): Int = floor((at - startPoint) / binSize).toInt()
override fun add(
left: Histogram1D<Double, V>,
right: Histogram1D<Double, V>,
): UniformHistogram1D<V> = valueAlgebra {
val leftUniform = produceFrom(left)
val rightUniform = produceFrom(right)
val keys = leftUniform.values.keys + rightUniform.values.keys
UniformHistogram1D(
this@UniformHistogram1DGroup,
keys.associateWith {
(leftUniform.values[it] ?: valueAlgebra.zero) + (rightUniform.values[it] ?: valueAlgebra.zero)
}
)
}
override fun Histogram1D<Double, V>.unaryMinus(): UniformHistogram1D<V> = valueAlgebra {
UniformHistogram1D(this@UniformHistogram1DGroup, produceFrom(this@unaryMinus).values.mapValues { -it.value })
}
override fun scale(
a: Histogram1D<Double, V>,
value: Double,
): UniformHistogram1D<V> = UniformHistogram1D(
this@UniformHistogram1DGroup,
produceFrom(a).values.mapValues { valueAlgebra.scale(it.value, value) }
)
/**
* Fill histogram.
*/
public inline fun produce(block: Histogram1DBuilder<Double, V>.() -> Unit): UniformHistogram1D<V> {
val map = HashMap<Int, V>()
val builder = object : Histogram1DBuilder<Double, V> {
override val defaultValue: V get() = valueAlgebra.zero
override fun putValue(at: Double, value: V) {
val index = getIndex(at)
map[index] = with(valueAlgebra) { (map[index] ?: zero) + one }
}
}
builder.block()
return UniformHistogram1D(this, map)
}
/**
* Re-bin given histogram to be compatible if exiting bin fully falls inside existing bin, this bin value
* is increased by one. If not, all bins including values from this bin are increased by fraction
* (conserving the norming).
*/
@OptIn(UnstableKMathAPI::class)
public fun produceFrom(
histogram: Histogram1D<Double, V>,
): UniformHistogram1D<V> = if ((histogram as? UniformHistogram1D)?.group == this) {
histogram
} else {
val map = HashMap<Int, V>()
histogram.bins.forEach { bin ->
val range = bin.domain.range
val indexOfLeft = getIndex(range.start)
val indexOfRight = getIndex(range.endInclusive)
val numBins = indexOfRight - indexOfLeft + 1
for (i in indexOfLeft..indexOfRight) {
map[indexOfLeft] = with(valueAlgebra) {
(map[indexOfLeft] ?: zero) + bin.binValue / numBins
}
}
}
UniformHistogram1D(this, map)
}
}
public fun <V : Any, A> Histogram.Companion.uniform1D(
valueAlgebra: A,
binSize: Double,
startPoint: Double = 0.0,
): UniformHistogram1DGroup<V, A> where A : Ring<V>, A : ScaleOperations<V> =
UniformHistogram1DGroup(valueAlgebra, binSize, startPoint)
@UnstableKMathAPI
public fun <V : Any> UniformHistogram1DGroup<V, *>.produce(
buffer: Buffer<Double>,
): UniformHistogram1D<V> = produce { fill(buffer) }
/**
* Map of bin centers to bin values
*/
@OptIn(UnstableKMathAPI::class)
public val <V : Any> UniformHistogram1D<V>.binValues: Map<Double, V>
get() = bins.associate { it.center to it.binValue }
//TODO add normalized values inside Field-based histogram spaces with context receivers
///**
// * Map of bin centers to normalized bin values (bin size as normalization)
// */
//@OptIn(UnstableKMathAPI::class)
//public val <V : Any> UniformHistogram1D<V>.binValuesNormalized: Map<Double, V>
// get() = group.valueAlgebra {
// bins.associate { it.center to it.binValue / group.binSize }
// }

Some files were not shown because too many files have changed in this diff Show More