diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 98f4d2ec2..f2c76a354 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,8 +2,8 @@ name: Gradle build on: push: + branches: [ dev, master ] pull_request: - types: [opened, edited] jobs: build: @@ -11,23 +11,22 @@ jobs: matrix: os: [ macOS-latest, windows-latest ] runs-on: ${{matrix.os}} - timeout-minutes: 30 + timeout-minutes: 40 steps: - name: Checkout the repo uses: actions/checkout@v2 - name: Set up JDK 11 uses: DeLaGuardo/setup-graalvm@4.0 with: - graalvm: 21.1.0 + graalvm: 21.2.0 java: java11 arch: amd64 - - name: Add msys to path - if: matrix.os == 'windows-latest' - run: SETX PATH "%PATH%;C:\msys64\mingw64\bin" - name: Cache gradle uses: actions/cache@v2 with: - path: ~/.gradle/caches + path: | + ~/.gradle/caches + ~/.gradle/wrapper key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} restore-keys: | ${{ runner.os }}-gradle- @@ -38,5 +37,7 @@ jobs: key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} restore-keys: | ${{ runner.os }}-gradle- + - name: Gradle Wrapper Validation + uses: gradle/wrapper-validation-action@v1.0.4 - name: Build - run: ./gradlew build --no-daemon --stacktrace + run: ./gradlew build --build-cache --no-daemon --stacktrace diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index 86fdac6a6..23ed54357 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -2,23 +2,27 @@ name: Dokka publication on: push: - branches: - - master + branches: [ master ] jobs: build: runs-on: ubuntu-20.04 + timeout-minutes: 40 steps: - - name: Checkout the repo - uses: actions/checkout@v2 - - name: Set up JDK 11 - uses: actions/setup-java@v1 + - uses: actions/checkout@v2 + - uses: DeLaGuardo/setup-graalvm@4.0 with: - java-version: 11 - - name: Build - run: ./gradlew dokkaHtmlMultiModule --no-daemon --no-parallel --stacktrace - - name: Deploy to GitHub Pages - uses: JamesIves/github-pages-deploy-action@4.1.0 + graalvm: 21.2.0 + java: java11 + arch: amd64 + - uses: actions/cache@v2 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} + restore-keys: | + ${{ runner.os }}-gradle- + - run: ./gradlew dokkaHtmlMultiModule --build-cache --no-daemon --no-parallel --stacktrace + - uses: JamesIves/github-pages-deploy-action@4.1.0 with: branch: gh-pages folder: build/dokka/htmlMultiModule diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c5c110e89..fa3cb700c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -3,8 +3,7 @@ name: Gradle publish on: workflow_dispatch: release: - types: - - created + types: [ created ] jobs: publish: @@ -20,16 +19,15 @@ jobs: - name: Set up JDK 11 uses: DeLaGuardo/setup-graalvm@4.0 with: - graalvm: 21.1.0 + graalvm: 21.2.0 java: java11 arch: amd64 - - name: Add msys to path - if: matrix.os == 'windows-latest' - run: SETX PATH "%PATH%;C:\msys64\mingw64\bin" - name: Cache gradle uses: actions/cache@v2 with: - path: ~/.gradle/caches + path: | + ~/.gradle/caches + ~/.gradle/wrapper key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} restore-keys: | ${{ runner.os }}-gradle- @@ -40,22 +38,18 @@ jobs: key: ${{ runner.os }}-gradle-${{ hashFiles('*.gradle.kts') }} restore-keys: | ${{ runner.os }}-gradle- + - name: Gradle Wrapper Validation + uses: gradle/wrapper-validation-action@v1.0.4 - name: Publish Windows Artifacts if: matrix.os == 'windows-latest' + shell: cmd run: > - ./gradlew release --no-daemon - -Ppublishing.enabled=true - -Ppublishing.github.user=${{ secrets.PUBLISHING_GITHUB_USER }} - -Ppublishing.github.token=${{ secrets.PUBLISHING_GITHUB_TOKEN }} - -Ppublishing.space.user=${{ secrets.PUBLISHING_SPACE_USER }} - -Ppublishing.space.token=${{ secrets.PUBLISHING_SPACE_TOKEN }} + ./gradlew release --no-daemon --build-cache -Ppublishing.enabled=true + -Ppublishing.space.user=${{ secrets.SPACE_APP_ID }} + -Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }} - name: Publish Mac Artifacts if: matrix.os == 'macOS-latest' run: > - ./gradlew release --no-daemon - -Ppublishing.enabled=true - -Ppublishing.platform=macosX64 - -Ppublishing.github.user=${{ secrets.PUBLISHING_GITHUB_USER }} - -Ppublishing.github.token=${{ secrets.PUBLISHING_GITHUB_TOKEN }} - -Ppublishing.space.user=${{ secrets.PUBLISHING_SPACE_USER }} - -Ppublishing.space.token=${{ secrets.PUBLISHING_SPACE_TOKEN }} + ./gradlew release --no-daemon --build-cache -Ppublishing.enabled=true -Ppublishing.platform=macosX64 + -Ppublishing.space.user=${{ secrets.SPACE_APP_ID }} + -Ppublishing.space.token=${{ secrets.SPACE_APP_SECRET }} diff --git a/.idea/copyright/kmath.xml b/.idea/copyright/kmath.xml deleted file mode 100644 index 17e44e4d0..000000000 --- a/.idea/copyright/kmath.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - diff --git a/.idea/copyright/profiles_settings.xml b/.idea/copyright/profiles_settings.xml deleted file mode 100644 index b538bdf41..000000000 --- a/.idea/copyright/profiles_settings.xml +++ /dev/null @@ -1,21 +0,0 @@ - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/scopes/Apply_copyright.xml b/.idea/scopes/Apply_copyright.xml deleted file mode 100644 index 0eb589133..000000000 --- a/.idea/scopes/Apply_copyright.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - diff --git a/CHANGELOG.md b/CHANGELOG.md index 12540821e..6733c1211 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,10 @@ - Extended operations for ND4J fields - Jupyter Notebook integration module (kmath-jupyter) - `@PerformancePitfall` annotation to mark possibly slow API +- Unified architecture for Integration and Optimization using features. - `BigInt` operation performance improvement and fixes by @zhelenskiy (#328) - Integration between `MST` and Symja `IExpr` +- Complex power ### Changed - Exponential operations merged with hyperbolic functions @@ -36,8 +38,17 @@ - Remove Any restriction on polynomials - Add `out` variance to type parameters of `StructureND` and its implementations where possible - Rename `DifferentiableMstExpression` to `KotlingradExpression` +- `FeatureSet` now accepts only `Feature`. It is possible to override keys and use interfaces. +- Use `Symbol` factory function instead of `StringSymbol` +- New discoverability pattern: `.algebra.` +- Adjusted commons-math API for linear solvers to match conventions. +- Buffer algebra does not require size anymore +- Operations -> Ops +- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes. +- Tensor algebra takes read-only structures as input and inherits AlgebraND ### Deprecated +- Specialized `DoubleBufferAlgebra` ### Removed - Nearest in Domain. To be implemented in geometry package. @@ -47,6 +58,7 @@ - Expression algebra builders - Complex and Quaternion no longer are elements. - Second generic from DifferentiableExpression +- Algebra elements are completely removed. Use algebra contexts instead. ### Fixed - Ring inherits RingOperations, not GroupOperations diff --git a/README.md b/README.md index 015988cd3..db069d4e0 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,10 @@ # KMath -Could be pronounced as `key-math`. The **K**otlin **Math**ematics library was initially intended as a Kotlin-based analog to -Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture -designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could -be achieved with [kmath-for-real](/kmath-for-real) extension module. +Could be pronounced as `key-math`. The **K**otlin **Math**ematics library was initially intended as a Kotlin-based +analog to Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior +architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like +experience could be achieved with [kmath-for-real](/kmath-for-real) extension module. [Documentation site (**WIP**)](https://mipt-npm.github.io/kmath/) @@ -21,26 +21,33 @@ be achieved with [kmath-for-real](/kmath-for-real) extension module. # Goal -* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native). +* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native) + . * Provide basic multiplatform implementations for those abstractions (without significant performance optimization). * Provide bindings and wrappers with those abstractions for popular optimized platform libraries. ## Non-goals -* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API. +* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in API. * Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. * Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually. -* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like -for `Double` in the core. For that we will have specialization modules like `kmath-for-real`, which will give better -experience for those, who want to work with specific types. +* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like + for `Double` in the core. For that we will have specialization modules like `kmath-for-real`, which will give better + experience for those, who want to work with specific types. ## Features and stability -KMath is a modular library. Different modules provide different features with different API stability guarantees. All core modules are released with the same version, but with different API change policy. The features are described in module definitions below. The module stability could have following levels: +KMath is a modular library. Different modules provide different features with different API stability guarantees. All +core modules are released with the same version, but with different API change policy. The features are described in +module definitions below. The module stability could have the following levels: -* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could break any moment. You can still use it, but be sure to fix the specific version. -* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked with `@UnstableKmathAPI` or other stability warning annotations. -* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor versions, but not in patch versions. API is protected with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. +* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could + break any moment. You can still use it, but be sure to fix the specific version. +* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked + with `@UnstableKmathAPI` or other stability warning annotations. +* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor + versions, but not in patch versions. API is protected + with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. * **STABLE**. The API stabilized. Breaking changes are allowed only in major releases. @@ -132,7 +139,7 @@ KMath is a modular library. Different modules provide different features with di objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high performance calculations to code generation. > - [domains](kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains) : Domains -> - [autodif](kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation +> - [autodiff](kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
@@ -161,7 +168,7 @@ performance calculations to code generation.
* ### [kmath-for-real](kmath-for-real) -> Extension module that should be used to achieve numpy-like behavior. +> Extension module that should be used to achieve numpy-like behavior. All operations are specialized to work with `Double` numbers without declaring algebraic contexts. One can still use generic algebras though. > @@ -222,8 +229,8 @@ One can still use generic algebras though. > **Maturity**: EXPERIMENTAL > > **Features:** -> - [differentiable-mst-expression](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt) : MST based DifferentiableExpression. -> - [differentiable-mst-expression](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt) : Conversions between Kotlin∇'s SFun and MST +> - [differentiable-mst-expression](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt) : MST based DifferentiableExpression. +> - [scalars-adapters](kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt) : Conversions between Kotlin∇'s SFun and MST
@@ -264,7 +271,7 @@ One can still use generic algebras though. > > **Features:** > - [tensor algebra](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.) -> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting. +> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting. > - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
@@ -278,30 +285,33 @@ One can still use generic algebras though. ## Multi-platform support -KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the -[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features -are delegated to platform-specific implementations even if they could be provided in the common module for performance -reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and +KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the +[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features +are delegated to platform-specific implementations even if they could be provided in the common module for performance +reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and feedback are also welcome. ## Performance -Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve -both performance and flexibility. +Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve both +performance and flexibility. -We expect to focus on creating convenient universal API first and then work on increasing performance for specific -cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized -native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be +We expect to focus on creating convenient universal API first and then work on increasing performance for specific +cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized +native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy. ## Requirements -KMath currently relies on JDK 11 for compilation and execution of Kotlin-JVM part. We recommend to use GraalVM-CE 11 for execution in order to get better performance. +KMath currently relies on JDK 11 for compilation and execution of Kotlin-JVM part. We recommend to use GraalVM-CE 11 for +execution to get better performance. ### Repositories -Release and development artifacts are accessible from mipt-npm [Space](https://www.jetbrains.com/space/) repository `https://maven.pkg.jetbrains.space/mipt-npm/p/sci/maven` (see documentation of -[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details). The repository could be reached through [repo.kotlin.link](https://repo.kotlin.link) proxy: +Release and development artifacts are accessible from mipt-npm [Space](https://www.jetbrains.com/space/) +repository `https://maven.pkg.jetbrains.space/mipt-npm/p/sci/maven` (see documentation of +[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details). The repository could +be reached through [repo.kotlin.link](https://repo.kotlin.link) proxy: ```kotlin repositories { @@ -309,8 +319,8 @@ repositories { } dependencies { - api("space.kscience:kmath-core:0.3.0-dev-13") - // api("space.kscience:kmath-core-jvm:0.3.0-dev-13") for jvm-specific version + api("space.kscience:kmath-core:0.3.0-dev-14") + // api("space.kscience:kmath-core-jvm:0.3.0-dev-14") for jvm-specific version } ``` @@ -318,7 +328,7 @@ Gradle `6.0+` is required for multiplatform artifacts. ## Contributing -The project requires a lot of additional work. The most important thing we need is a feedback about what features are -required the most. Feel free to create feature requests. We are also welcome to code contributions, -especially in issues marked with +The project requires a lot of additional work. The most important thing we need is a feedback about what features are +required the most. Feel free to create feature requests. We are also welcome to code contributions, especially in issues +marked with [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label. diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index 2198ac5d6..cca3d312d 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -36,6 +36,7 @@ kotlin { implementation(project(":kmath-dimensions")) implementation(project(":kmath-for-real")) implementation(project(":kmath-jafama")) + implementation(project(":kmath-tensors")) implementation("org.jetbrains.kotlinx:kotlinx-benchmark-runtime:0.3.1") } } @@ -47,7 +48,8 @@ kotlin { implementation(project(":kmath-nd4j")) implementation(project(":kmath-kotlingrad")) implementation(project(":kmath-viktor")) - implementation("org.nd4j:nd4j-native:1.0.0-beta7") + implementation(projects.kmathMultik) + implementation("org.nd4j:nd4j-native:1.0.0-M1") // uncomment if your system supports AVX2 // val os = System.getProperty("os.name") // @@ -81,6 +83,11 @@ benchmark { include("BufferBenchmark") } + configurations.register("nd") { + commonConfiguration() + include("NDFieldBenchmark") + } + configurations.register("dot") { commonConfiguration() include("DotBenchmark") @@ -105,6 +112,16 @@ benchmark { commonConfiguration() include("JafamaBenchmark") } + + configurations.register("viktor") { + commonConfiguration() + include("ViktorBenchmark") + } + + configurations.register("viktorLog") { + commonConfiguration() + include("ViktorLogBenchmark") + } } // Fix kotlinx-benchmarks bug diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ArrayBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ArrayBenchmark.kt index ff933997f..17983e88c 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ArrayBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ArrayBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BigIntBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BigIntBenchmark.kt index 749cd5e75..f2b2d4d7a 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BigIntBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BigIntBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -11,7 +11,10 @@ import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.* +import space.kscience.kmath.operations.BigIntField +import space.kscience.kmath.operations.JBigIntegerField +import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.parseBigInteger import java.math.BigInteger @@ -19,12 +22,24 @@ import java.math.BigInteger @State(Scope.Benchmark) internal class BigIntBenchmark { + val kmSmallNumber = BigIntField.number(100) + val jvmSmallNumber = JBigIntegerField.number(100) val kmNumber = BigIntField.number(Int.MAX_VALUE) val jvmNumber = JBigIntegerField.number(Int.MAX_VALUE) - val largeKmNumber = BigIntField { number(11).pow(100_000U) } - val largeJvmNumber: BigInteger = JBigIntegerField { number(11).pow(100_000) } + val kmLargeNumber = BigIntField { number(11).pow(100_000U) } + val jvmLargeNumber: BigInteger = JBigIntegerField { number(11).pow(100_000) } val bigExponent = 50_000 + @Benchmark + fun kmSmallAdd(blackhole: Blackhole) = BigIntField { + blackhole.consume(kmSmallNumber + kmSmallNumber + kmSmallNumber) + } + + @Benchmark + fun jvmSmallAdd(blackhole: Blackhole) = JBigIntegerField { + blackhole.consume(jvmSmallNumber + jvmSmallNumber + jvmSmallNumber) + } + @Benchmark fun kmAdd(blackhole: Blackhole) = BigIntField { blackhole.consume(kmNumber + kmNumber + kmNumber) @@ -37,12 +52,12 @@ internal class BigIntBenchmark { @Benchmark fun kmAddLarge(blackhole: Blackhole) = BigIntField { - blackhole.consume(largeKmNumber + largeKmNumber + largeKmNumber) + blackhole.consume(kmLargeNumber + kmLargeNumber + kmLargeNumber) } @Benchmark fun jvmAddLarge(blackhole: Blackhole) = JBigIntegerField { - blackhole.consume(largeJvmNumber + largeJvmNumber + largeJvmNumber) + blackhole.consume(jvmLargeNumber + jvmLargeNumber + jvmLargeNumber) } @Benchmark @@ -52,7 +67,7 @@ internal class BigIntBenchmark { @Benchmark fun kmMultiplyLarge(blackhole: Blackhole) = BigIntField { - blackhole.consume(largeKmNumber*largeKmNumber) + blackhole.consume(kmLargeNumber*kmLargeNumber) } @Benchmark @@ -62,7 +77,7 @@ internal class BigIntBenchmark { @Benchmark fun jvmMultiplyLarge(blackhole: Blackhole) = JBigIntegerField { - blackhole.consume(largeJvmNumber*largeJvmNumber) + blackhole.consume(jvmLargeNumber*jvmLargeNumber) } @Benchmark diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BufferBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BufferBenchmark.kt index 39819d407..5cf194b67 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BufferBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/BufferBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt index 2c5a03a97..64f9b5dff 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -11,9 +11,10 @@ import kotlinx.benchmark.Scope import kotlinx.benchmark.State import space.kscience.kmath.commons.linear.CMLinearSpace import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM -import space.kscience.kmath.linear.LinearSpace import space.kscience.kmath.linear.invoke +import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.structures.Buffer import kotlin.random.Random @State(Scope.Benchmark) @@ -23,8 +24,12 @@ internal class DotBenchmark { const val dim = 1000 //creating invertible matrix - val matrix1 = LinearSpace.real.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } - val matrix2 = LinearSpace.real.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + val matrix1 = DoubleField.linearSpace.buildMatrix(dim, dim) { _, _ -> + random.nextDouble() + } + val matrix2 = DoubleField.linearSpace.buildMatrix(dim, dim) { _, _ -> + random.nextDouble() + } val cmMatrix1 = CMLinearSpace { matrix1.toCM() } val cmMatrix2 = CMLinearSpace { matrix2.toCM() } @@ -34,37 +39,32 @@ internal class DotBenchmark { } @Benchmark - fun cmDot(blackhole: Blackhole) { - CMLinearSpace.run { - blackhole.consume(cmMatrix1 dot cmMatrix2) - } + fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace { + blackhole.consume(matrix1 dot matrix2) } @Benchmark - fun ejmlDot(blackhole: Blackhole) { - EjmlLinearSpaceDDRM { - blackhole.consume(ejmlMatrix1 dot ejmlMatrix2) - } + fun cmDot(blackhole: Blackhole) = CMLinearSpace { + blackhole.consume(cmMatrix1 dot cmMatrix2) } @Benchmark - fun ejmlDotWithConversion(blackhole: Blackhole) { - EjmlLinearSpaceDDRM { - blackhole.consume(matrix1 dot matrix2) - } + fun ejmlDot(blackhole: Blackhole) = EjmlLinearSpaceDDRM { + blackhole.consume(ejmlMatrix1 dot ejmlMatrix2) } @Benchmark - fun bufferedDot(blackhole: Blackhole) { - LinearSpace.auto(DoubleField).invoke { - blackhole.consume(matrix1 dot matrix2) - } + fun ejmlDotWithConversion(blackhole: Blackhole) = EjmlLinearSpaceDDRM { + blackhole.consume(matrix1 dot matrix2) } @Benchmark - fun realDot(blackhole: Blackhole) { - LinearSpace.real { - blackhole.consume(matrix1 dot matrix2) - } + fun bufferedDot(blackhole: Blackhole) = with(DoubleField.linearSpace(Buffer.Companion::auto)) { + blackhole.consume(matrix1 dot matrix2) + } + + @Benchmark + fun doubleDot(blackhole: Blackhole) = with(DoubleField.linearSpace) { + blackhole.consume(matrix1 dot matrix2) } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt index 0294f924b..63e1511bd 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -75,8 +75,9 @@ internal class ExpressionsInterpretersBenchmark { private val algebra = DoubleField private const val times = 1_000_000 - private val functional = DoubleField.expressionInExtendedField { - bindSymbol(x) * number(2.0) + number(2.0) / bindSymbol(x) - number(16.0) / sin(bindSymbol(x)) + private val functional = DoubleField.expression { + val x = bindSymbol(Symbol.x) + x * number(2.0) + 2.0 / x - 16.0 / sin(x) } private val node = MstExtendedField { diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/JafamaBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/JafamaBenchmark.kt index 24a730375..9c6551302 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/JafamaBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/JafamaBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -13,6 +13,8 @@ import space.kscience.kmath.jafama.JafamaDoubleField import space.kscience.kmath.jafama.StrictJafamaDoubleField import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.random.Random @State(Scope.Benchmark) @@ -31,9 +33,10 @@ internal class JafamaBenchmark { fun strictJafama(blackhole: Blackhole) = invokeBenchmarks(blackhole) { x -> StrictJafamaDoubleField { x * power(x, 4) * exp(x) / cos(x) + sin(x) } } - - private inline fun invokeBenchmarks(blackhole: Blackhole, expr: (Double) -> Double) { - val rng = Random(0) - repeat(1000000) { blackhole.consume(expr(rng.nextDouble())) } - } +} + +private inline fun invokeBenchmarks(blackhole: Blackhole, expr: (Double) -> Double) { + contract { callsInPlace(expr, InvocationKind.AT_LEAST_ONCE) } + val rng = Random(0) + repeat(1000000) { blackhole.consume(expr(rng.nextDouble())) } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/MatrixInverseBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/MatrixInverseBenchmark.kt index 7bb32af28..5d331af9a 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/MatrixInverseBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/MatrixInverseBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -10,13 +10,12 @@ import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State import space.kscience.kmath.commons.linear.CMLinearSpace -import space.kscience.kmath.commons.linear.inverse +import space.kscience.kmath.commons.linear.lupSolver import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM -import space.kscience.kmath.linear.InverseMatrixFeature -import space.kscience.kmath.linear.LinearSpace -import space.kscience.kmath.linear.inverseWithLup import space.kscience.kmath.linear.invoke -import space.kscience.kmath.nd.getFeature +import space.kscience.kmath.linear.linearSpace +import space.kscience.kmath.linear.lupSolver +import space.kscience.kmath.operations.algebra import kotlin.random.Random @State(Scope.Benchmark) @@ -25,7 +24,7 @@ internal class MatrixInverseBenchmark { private val random = Random(1224) private const val dim = 100 - private val space = LinearSpace.real + private val space = Double.algebra.linearSpace //creating invertible matrix private val u = space.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } @@ -35,20 +34,20 @@ internal class MatrixInverseBenchmark { @Benchmark fun kmathLupInversion(blackhole: Blackhole) { - blackhole.consume(LinearSpace.real.inverseWithLup(matrix)) + blackhole.consume(Double.algebra.linearSpace.lupSolver().inverse(matrix)) } @Benchmark fun cmLUPInversion(blackhole: Blackhole) { - with(CMLinearSpace) { - blackhole.consume(inverse(matrix)) + CMLinearSpace { + blackhole.consume(lupSolver().inverse(matrix)) } } @Benchmark fun ejmlInverse(blackhole: Blackhole) { - with(EjmlLinearSpaceDDRM) { - blackhole.consume(matrix.getFeature>()?.inverse) + EjmlLinearSpaceDDRM { + blackhole.consume(matrix.toEjml().inverse()) } } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt index 5e0c6735f..b5af5aa19 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -9,45 +9,97 @@ import kotlinx.benchmark.Benchmark import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State -import space.kscience.kmath.nd.* +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.ones +import org.jetbrains.kotlinx.multik.ndarray.data.DN +import org.jetbrains.kotlinx.multik.ndarray.data.DataType +import space.kscience.kmath.multik.multikAlgebra +import space.kscience.kmath.nd.BufferedFieldOpsND +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.one +import space.kscience.kmath.nd4j.nd4j import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.tensors.core.DoubleTensor +import space.kscience.kmath.tensors.core.one +import space.kscience.kmath.tensors.core.tensorAlgebra +import space.kscience.kmath.viktor.viktorAlgebra @State(Scope.Benchmark) internal class NDFieldBenchmark { @Benchmark - fun autoFieldAdd(blackhole: Blackhole) { - with(autoField) { - var res: StructureND = one - repeat(n) { res += one } - blackhole.consume(res) - } + fun autoFieldAdd(blackhole: Blackhole) = with(autoField) { + var res: StructureND = one(shape) + repeat(n) { res += 1.0 } + blackhole.consume(res) } @Benchmark - fun specializedFieldAdd(blackhole: Blackhole) { - with(specializedField) { - var res: StructureND = one - repeat(n) { res += 1.0 } - blackhole.consume(res) - } + fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) { + var res: StructureND = one(shape) + repeat(n) { res += 1.0 } + blackhole.consume(res) } - @Benchmark - fun boxingFieldAdd(blackhole: Blackhole) { - with(genericField) { - var res: StructureND = one - repeat(n) { res += 1.0 } - blackhole.consume(res) - } + fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) { + var res: StructureND = one(shape) + repeat(n) { res += 1.0 } + blackhole.consume(res) } + @Benchmark + fun multikAdd(blackhole: Blackhole) = with(multikField) { + var res: StructureND = one(shape) + repeat(n) { res += 1.0 } + blackhole.consume(res) + } + + @Benchmark + fun viktorAdd(blackhole: Blackhole) = with(viktorField) { + var res: StructureND = one(shape) + repeat(n) { res += 1.0 } + blackhole.consume(res) + } + + @Benchmark + fun tensorAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) { + var res: DoubleTensor = one(shape) + repeat(n) { res = res + 1.0 } + blackhole.consume(res) + } + + @Benchmark + fun tensorInPlaceAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) { + val res: DoubleTensor = one(shape) + repeat(n) { res += 1.0 } + blackhole.consume(res) + } + + @Benchmark + fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikAlgebra) { + val res = Multik.ones(shape, DataType.DoubleDataType).wrap() + repeat(n) { res += 1.0 } + blackhole.consume(res) + } + +// @Benchmark +// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) { +// var res: StructureND = one(dim, dim) +// repeat(n) { res += 1.0 } +// blackhole.consume(res) +// } + private companion object { private const val dim = 1000 private const val n = 100 - private val autoField = AlgebraND.auto(DoubleField, dim, dim) - private val specializedField = AlgebraND.real(dim, dim) - private val genericField = AlgebraND.field(DoubleField, Buffer.Companion::boxing, dim, dim) + private val shape = intArrayOf(dim, dim) + private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) + private val specializedField = DoubleField.ndAlgebra + private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing) + private val nd4jField = DoubleField.nd4j + private val multikField = DoubleField.multikAlgebra + private val viktorField = DoubleField.viktorAlgebra } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt index d2359a791..6b4d5759b 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -10,19 +10,17 @@ import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State import org.jetbrains.bio.viktor.F64Array -import space.kscience.kmath.nd.AlgebraND -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.auto -import space.kscience.kmath.nd.real +import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.viktor.ViktorNDField +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.viktor.ViktorFieldND @State(Scope.Benchmark) internal class ViktorBenchmark { @Benchmark fun automaticFieldAddition(blackhole: Blackhole) { with(autoField) { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @@ -31,7 +29,7 @@ internal class ViktorBenchmark { @Benchmark fun realFieldAddition(blackhole: Blackhole) { with(realField) { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @@ -40,7 +38,7 @@ internal class ViktorBenchmark { @Benchmark fun viktorFieldAddition(blackhole: Blackhole) { with(viktorField) { - var res = one + var res = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) } @@ -57,10 +55,11 @@ internal class ViktorBenchmark { private companion object { private const val dim = 1000 private const val n = 100 + private val shape = Shape(dim, dim) // automatically build context most suited for given type. - private val autoField = AlgebraND.auto(DoubleField, dim, dim) - private val realField = AlgebraND.real(dim, dim) - private val viktorField = ViktorNDField(dim, dim) + private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) + private val realField = DoubleField.ndAlgebra + private val viktorField = ViktorFieldND(dim, dim) } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt index eac8634f5..a9d1e68fc 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ViktorLogBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks @@ -10,19 +10,21 @@ import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State import org.jetbrains.bio.viktor.F64Array -import space.kscience.kmath.nd.AlgebraND -import space.kscience.kmath.nd.auto -import space.kscience.kmath.nd.real +import space.kscience.kmath.nd.BufferedFieldOpsND +import space.kscience.kmath.nd.Shape +import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.one import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.structures.Buffer import space.kscience.kmath.viktor.ViktorFieldND @State(Scope.Benchmark) internal class ViktorLogBenchmark { @Benchmark fun realFieldLog(blackhole: Blackhole) { - with(realNdField) { - val fortyTwo = produce { 42.0 } - var res = one + with(realField) { + val fortyTwo = structureND(shape) { 42.0 } + var res = one(shape) repeat(n) { res = ln(fortyTwo) } blackhole.consume(res) } @@ -31,7 +33,7 @@ internal class ViktorLogBenchmark { @Benchmark fun viktorFieldLog(blackhole: Blackhole) { with(viktorField) { - val fortyTwo = produce { 42.0 } + val fortyTwo = structureND(shape) { 42.0 } var res = one repeat(n) { res = ln(fortyTwo) } blackhole.consume(res) @@ -49,10 +51,11 @@ internal class ViktorLogBenchmark { private companion object { private const val dim = 1000 private const val n = 100 + private val shape = Shape(dim, dim) // automatically build context most suited for given type. - private val autoField = AlgebraND.auto(DoubleField, dim, dim) - private val realNdField = AlgebraND.real(dim, dim) - private val viktorField = ViktorFieldND(intArrayOf(dim, dim)) + private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) + private val realField = DoubleField.ndAlgebra + private val viktorField = ViktorFieldND(dim, dim) } } diff --git a/build.gradle.kts b/build.gradle.kts index 93cd67d47..c2347f7be 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,3 +1,5 @@ +import java.net.URL + plugins { id("ru.mipt.npm.gradle.project") kotlin("jupyter.api") apply false @@ -7,15 +9,17 @@ allprojects { repositories { maven("https://clojars.org/repo") maven("https://jitpack.io") + maven("http://logicrunch.research.it.uu.se/maven") { isAllowInsecureProtocol = true } + maven("https://oss.sonatype.org/content/repositories/snapshots") mavenCentral() } group = "space.kscience" - version = "0.3.0-dev-14" + version = "0.3.0-dev-17" } subprojects { @@ -23,32 +27,46 @@ subprojects { afterEvaluate { tasks.withType { - dependsOn(tasks.getByName("assemble")) + dependsOn(tasks["assemble"]) dokkaSourceSets.all { - val readmeFile = File(this@subprojects.projectDir, "README.md") - if (readmeFile.exists()) includes.from(readmeFile.absolutePath) - externalDocumentationLink("http://ejml.org/javadoc/") + val readmeFile = this@subprojects.projectDir.resolve("README.md") + if (readmeFile.exists()) includes.from(readmeFile) + val kotlinDirPath = "src/$name/kotlin" + val kotlinDir = file(kotlinDirPath) + + if (kotlinDir.exists()) sourceLink { + localDirectory.set(kotlinDir) + + remoteUrl.set( + URL("https://github.com/mipt-npm/${rootProject.name}/tree/master/${this@subprojects.name}/$kotlinDirPath") + ) + } + externalDocumentationLink("https://commons.apache.org/proper/commons-math/javadocs/api-3.6.1/") externalDocumentationLink("https://deeplearning4j.org/api/latest/") externalDocumentationLink("https://axelclk.bitbucket.io/symja/javadoc/") - externalDocumentationLink("https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/") - externalDocumentationLink("https://breandan.net/kotlingrad/kotlingrad/", "https://breandan.net/kotlingrad/kotlingrad/kotlingrad/package-list") + + externalDocumentationLink( + "https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/", + "https://kotlin.github.io/kotlinx.coroutines/package-list", + ) + + externalDocumentationLink( + "https://breandan.net/kotlingrad/kotlingrad/", + "https://breandan.net/kotlingrad/kotlingrad/kotlingrad/package-list", + ) } } } } -readme { - readmeTemplate = file("docs/templates/README-TEMPLATE.md") -} +readme.readmeTemplate = file("docs/templates/README-TEMPLATE.md") ksciencePublish { - github("kmath") - space() - sonatype() + vcs("https://github.com/mipt-npm/kmath") + space(publish = true) + sonatype(publish = true) } -apiValidation { - nonPublicMarkers.add("space.kscience.kmath.misc.UnstableKMathAPI") -} +apiValidation.nonPublicMarkers.add("space.kscience.kmath.misc.UnstableKMathAPI") diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts index fe69b05c6..36a1ffd9e 100644 --- a/buildSrc/build.gradle.kts +++ b/buildSrc/build.gradle.kts @@ -11,7 +11,7 @@ repositories { dependencies { api("org.jetbrains.kotlinx:kotlinx-serialization-json:1.1.0") - api("ru.mipt.npm:gradle-tools:0.10.0") + api("ru.mipt.npm:gradle-tools:0.10.2") api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:0.3.1") } diff --git a/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/JmhReport.kt b/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/JmhReport.kt index eaa0f59d8..6859de845 100644 --- a/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/JmhReport.kt +++ b/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/JmhReport.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks diff --git a/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt b/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt index b55e1320e..72c9ff0ad 100644 --- a/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt +++ b/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.benchmarks diff --git a/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt b/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt index 5da7d0f67..cfebf61e7 100644 --- a/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt +++ b/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:Suppress("KDocUnresolvedReference") @@ -14,12 +14,12 @@ private fun Appendable.appendEjmlVector(type: String, ejmlMatrixType: String) { @Language("kotlin") val text = """/** * [EjmlVector] specialization for [$type]. */ -public class Ejml${type}Vector(public override val origin: M) : EjmlVector<$type, M>(origin) { +public class Ejml${type}Vector(override val origin: M) : EjmlVector<$type, M>(origin) { init { require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } } - public override operator fun get(index: Int): $type = origin[0, index] + override operator fun get(index: Int): $type = origin[0, index] }""" appendLine(text) appendLine() @@ -29,8 +29,8 @@ private fun Appendable.appendEjmlMatrix(type: String, ejmlMatrixType: String) { val text = """/** * [EjmlMatrix] specialization for [$type]. */ -public class Ejml${type}Matrix(public override val origin: M) : EjmlMatrix<$type, M>(origin) { - public override operator fun get(i: Int, j: Int): $type = origin[i, j] +public class Ejml${type}Matrix(override val origin: M) : EjmlMatrix<$type, M>(origin) { + override operator fun get(i: Int, j: Int): $type = origin[i, j] }""" appendLine(text) appendLine() @@ -54,23 +54,23 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, /** * The [${kmathAlgebra}] reference. */ - public override val elementAlgebra: $kmathAlgebra get() = $kmathAlgebra + override val elementAlgebra: $kmathAlgebra get() = $kmathAlgebra @Suppress("UNCHECKED_CAST") - public override fun Matrix<${type}>.toEjml(): Ejml${type}Matrix<${ejmlMatrixType}> = when { + override fun Matrix<${type}>.toEjml(): Ejml${type}Matrix<${ejmlMatrixType}> = when { this is Ejml${type}Matrix<*> && origin is $ejmlMatrixType -> this as Ejml${type}Matrix<${ejmlMatrixType}> else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } } @Suppress("UNCHECKED_CAST") - public override fun Point<${type}>.toEjml(): Ejml${type}Vector<${ejmlMatrixType}> = when { + override fun Point<${type}>.toEjml(): Ejml${type}Vector<${ejmlMatrixType}> = when { this is Ejml${type}Vector<*> && origin is $ejmlMatrixType -> this as Ejml${type}Vector<${ejmlMatrixType}> else -> Ejml${type}Vector(${ejmlMatrixType}(size, 1).also { (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } }) } - public override fun buildMatrix( + override fun buildMatrix( rows: Int, columns: Int, initializer: ${kmathAlgebra}.(i: Int, j: Int) -> ${type}, @@ -80,7 +80,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, } }.wrapMatrix() - public override fun buildVector( + override fun buildVector( size: Int, initializer: ${kmathAlgebra}.(Int) -> ${type}, ): Ejml${type}Vector<${ejmlMatrixType}> = Ejml${type}Vector(${ejmlMatrixType}(size, 1).also { @@ -90,21 +90,21 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, private fun T.wrapMatrix() = Ejml${type}Matrix(this) private fun T.wrapVector() = Ejml${type}Vector(this) - public override fun Matrix<${type}>.unaryMinus(): Matrix<${type}> = this * elementAlgebra { -one } + override fun Matrix<${type}>.unaryMinus(): Matrix<${type}> = this * elementAlgebra { -one } - public override fun Matrix<${type}>.dot(other: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> { + override fun Matrix<${type}>.dot(other: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> { val out = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.mult(toEjml().origin, other.toEjml().origin, out) return out.wrapMatrix() } - public override fun Matrix<${type}>.dot(vector: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> { + override fun Matrix<${type}>.dot(vector: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> { val out = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.mult(toEjml().origin, vector.toEjml().origin, out) return out.wrapVector() } - public override operator fun Matrix<${type}>.minus(other: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> { + override operator fun Matrix<${type}>.minus(other: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> { val out = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.add( @@ -123,19 +123,19 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, return out.wrapMatrix() } - public override operator fun Matrix<${type}>.times(value: ${type}): Ejml${type}Matrix<${ejmlMatrixType}> { + override operator fun Matrix<${type}>.times(value: ${type}): Ejml${type}Matrix<${ejmlMatrixType}> { val res = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.scale(value, toEjml().origin, res) return res.wrapMatrix() } - public override fun Point<${type}>.unaryMinus(): Ejml${type}Vector<${ejmlMatrixType}> { + override fun Point<${type}>.unaryMinus(): Ejml${type}Vector<${ejmlMatrixType}> { val res = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.changeSign(toEjml().origin, res) return res.wrapVector() } - public override fun Matrix<${type}>.plus(other: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> { + override fun Matrix<${type}>.plus(other: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> { val out = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.add( @@ -154,7 +154,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, return out.wrapMatrix() } - public override fun Point<${type}>.plus(other: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> { + override fun Point<${type}>.plus(other: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> { val out = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.add( @@ -173,7 +173,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, return out.wrapVector() } - public override fun Point<${type}>.minus(other: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> { + override fun Point<${type}>.minus(other: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> { val out = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.add( @@ -192,18 +192,18 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, return out.wrapVector() } - public override fun ${type}.times(m: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> = m * this + override fun ${type}.times(m: Matrix<${type}>): Ejml${type}Matrix<${ejmlMatrixType}> = m * this - public override fun Point<${type}>.times(value: ${type}): Ejml${type}Vector<${ejmlMatrixType}> { + override fun Point<${type}>.times(value: ${type}): Ejml${type}Vector<${ejmlMatrixType}> { val res = ${ejmlMatrixType}(1, 1) CommonOps_${ops}.scale(value, toEjml().origin, res) return res.wrapVector() } - public override fun ${type}.times(v: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> = v * this + override fun ${type}.times(v: Point<${type}>): Ejml${type}Vector<${ejmlMatrixType}> = v * this @UnstableKMathAPI - public override fun getFeature(structure: Matrix<${type}>, type: KClass): F? { + override fun computeFeature(structure: Matrix<${type}>, type: KClass): F? { structure.getFeature(type)?.let { return it } val origin = structure.toEjml().origin @@ -240,10 +240,10 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, } override val q: Matrix<${type}> by lazy { - qr.getQ(null, false).wrapMatrix() + OrthogonalFeature + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) } - override val r: Matrix<${type}> by lazy { qr.getR(null, false).wrapMatrix() + UFeature } + override val r: Matrix<${type}> by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } } CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<${type}> { @@ -251,7 +251,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, val cholesky = DecompositionFactory_${ops}.chol(structure.rowNum, true).apply { decompose(origin.copy()) } - cholesky.getT(null).wrapMatrix() + LFeature + cholesky.getT(null).wrapMatrix().withFeature(LFeature) } } @@ -261,11 +261,11 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, } override val l: Matrix<${type}> by lazy { - lup.getLower(null).wrapMatrix() + LFeature + lup.getLower(null).wrapMatrix().withFeature(LFeature) } override val u: Matrix<${type}> by lazy { - lup.getUpper(null).wrapMatrix() + UFeature + lup.getUpper(null).wrapMatrix().withFeature(UFeature) } override val p: Matrix<${type}> by lazy { lup.getRowPivot(null).wrapMatrix() } @@ -275,10 +275,10 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, } override val q: Matrix<${type}> by lazy { - qr.getQ(null, false).wrapMatrix() + OrthogonalFeature + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) } - override val r: Matrix<${type}> by lazy { qr.getR(null, false).wrapMatrix() + UFeature } + override val r: Matrix<${type}> by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } } CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<${type}> { @@ -286,7 +286,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, val cholesky = DecompositionFactory_${ops}.cholesky().apply { decompose(origin.copy()) } - (cholesky.getT(null) as ${ejmlMatrixParentTypeMatrix}).wrapMatrix() + LFeature + (cholesky.getT(null) as ${ejmlMatrixParentTypeMatrix}).wrapMatrix().withFeature(LFeature) } } @@ -297,11 +297,11 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, } override val l: Matrix<${type}> by lazy { - lu.getLower(null).wrapMatrix() + LFeature + lu.getLower(null).wrapMatrix().withFeature(LFeature) } override val u: Matrix<${type}> by lazy { - lu.getUpper(null).wrapMatrix() + UFeature + lu.getUpper(null).wrapMatrix().withFeature(UFeature) } override val inverse: Matrix<${type}> by lazy { @@ -362,7 +362,7 @@ fun ejmlCodegen(outputFile: String): Unit = File(outputFile).run { writer().use { it.appendLine("/*") it.appendLine(" * Copyright 2018-2021 KMath contributors.") - it.appendLine(" * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.") + it.appendLine(" * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.") it.appendLine(" */") it.appendLine() it.appendLine("/* This file is generated with buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt */") diff --git a/docs/algebra.md b/docs/algebra.md index 84693bb81..20158a125 100644 --- a/docs/algebra.md +++ b/docs/algebra.md @@ -1,85 +1,45 @@ # Algebraic Structures and Algebraic Elements -The mathematical operations in KMath are generally separated from mathematical objects. This means that to perform an -operation, say `+`, one needs two objects of a type `T` and an algebra context, which draws appropriate operation up, -say `Space`. Next one needs to run the actual operation in the context: +The mathematical operations in KMath are generally separated from mathematical objects. This means that to perform an +operation, say `+`, one needs two objects of a type `T` and an algebra context, which draws appropriate operation up, +say `Group`. Next one needs to run the actual operation in the context: ```kotlin import space.kscience.kmath.operations.* val a: T = ... val b: T = ... -val space: Space = ... +val group: Group = ... -val c = space { a + b } +val c = group { a + b } ``` -At first glance, this distinction seems to be a needless complication, but in fact one needs to remember that in -mathematics, one could draw up different operations on same objects. For example, one could use different types of +At first glance, this distinction seems to be a needless complication, but in fact one needs to remember that in +mathematics, one could draw up different operations on same objects. For example, one could use different types of geometry for vectors. ## Algebraic Structures -Mathematical contexts have the following hierarchy: +Primary mathematical contexts have the following hierarchy: -**Algebra** ← **Space** ← **Ring** ← **Field** +`Field <: Ring <: Group <: Algebra` These interfaces follow real algebraic structures: -- [Space](https://mathworld.wolfram.com/VectorSpace.html) defines addition, its neutral element (i.e. 0) and scalar -multiplication; -- [Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and its neutral element (i.e. 1); +- [Group](https://mathworld.wolfram.com/Group.html) defines addition, its identity element (i.e., 0) and additive + inverse (-x); +- [Ring](http://mathworld.wolfram.com/Ring.html) adds multiplication and its identity element (i.e., 1); - [Field](http://mathworld.wolfram.com/Field.html) adds division operation. A typical implementation of `Field` is the `DoubleField` which works on doubles, and `VectorSpace` for `Space`. In some cases algebra context can hold additional operations like `exp` or `sin`, and then it inherits appropriate -interface. Also, contexts may have operations, which produce elements outside of the context. For example, `Matrix.dot` -operation produces a matrix with new dimensions, which can be incompatible with initial matrix in terms of linear -operations. - -## Algebraic Element - -To achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving -contexts KMath submits special type objects called `MathElement`. A `MathElement` is basically some object coupled to -a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts, -but it also holds reference to the `ComplexField` singleton, which allows performing direct operations on `Complex` -numbers without explicit involving the context like: - -```kotlin -import space.kscience.kmath.operations.* - -// Using elements -val c1 = Complex(1.0, 1.0) -val c2 = Complex(1.0, -1.0) -val c3 = c1 + c2 + 3.0.toComplex() - -// Using context -val c4 = ComplexField { c1 + i - 2.0 } -``` - -Both notations have their pros and cons. - -The hierarchy for algebraic elements follows the hierarchy for the corresponding algebraic structures. - -**MathElement** ← **SpaceElement** ← **RingElement** ← **FieldElement** - -`MathElement` is the generic common ancestor of the class with context. - -One major distinction between algebraic elements and algebraic contexts is that elements have three type -parameters: - -1. The type of elements, the field operates on. -2. The self-type of the element returned from operation (which has to be an algebraic element). -3. The type of the algebra over first type-parameter. - -The middle type is needed for of algebra members do not store context. For example, it is impossible to add a context -to regular `Double`. The element performs automatic conversions from context types and back. One should use context -operations in all performance-critical places. The performance of element operations is not guaranteed. +interface. Also, contexts may have operations, which produce elements outside the context. For example, `Matrix.dot` +operation produces a matrix with new dimensions, which can be incompatible with initial matrix in linear operations. ## Spaces and Fields -KMath submits both contexts and elements for builtin algebraic structures: +KMath introduces contexts for builtin algebraic structures: ```kotlin import space.kscience.kmath.operations.* @@ -102,13 +62,13 @@ val c2 = ComplexField { c1 - 1.0 } // Returns: Complex(re=0.0, im=2.0) val c3 = ComplexField { c1 - i * 2.0 } ``` -**Note**: In theory it is possible to add behaviors directly to the context, but as for now Kotlin does not support -that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and +**Note**: In theory it is possible to add behaviors directly to the context, but as for now Kotlin does not support +that. Watch [KT-10468](https://youtrack.jetbrains.com/issue/KT-10468) and [KEEP-176](https://github.com/Kotlin/KEEP/pull/176) for updates. ## Nested fields -Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex +Contexts allow one to build more complex structures. For example, it is possible to create a `Matrix` from complex elements like so: ```kotlin @@ -118,8 +78,9 @@ val element = NDElement.complex(shape = intArrayOf(2, 2)) { index: IntArray -> ``` The `element` in this example is a member of the `Field` of 2D structures, each element of which is a member of its own -`ComplexField`. It is important one does not need to create a special n-d class to hold complex -numbers and implement operations on it, one just needs to provide a field for its elements. +`ComplexField`. It is important one does not need to create a special n-d class to hold complex numbers and implement +operations on it, one just needs to provide a field for its elements. -**Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts like +**Note**: Fields themselves do not solve the problem of JVM boxing, but it is possible to solve with special contexts +like `MemorySpec`. diff --git a/docs/buffers.md b/docs/buffers.md index 679bd4e78..e7573497e 100644 --- a/docs/buffers.md +++ b/docs/buffers.md @@ -1,17 +1,20 @@ # Buffers -Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write (with `MutableBuffer`). -There are different types of buffers: +Buffer is one of main building blocks of kmath. It is a basic interface allowing random-access read and write ( +with `MutableBuffer`). There are different types of buffers: -* Primitive buffers wrapping like `RealBuffer` which are wrapping primitive arrays. +* Primitive buffers wrapping like `DoubleBuffer` which are wrapping primitive arrays. * Boxing `ListBuffer` wrapping a list * Functionally defined `VirtualBuffer` which does not hold a state itself, but provides a function to calculate value * `MemoryBuffer` allows direct allocation of objects in continuous memory block. -Some kmath features require a `BufferFactory` class to operate properly. A general convention is to use functions defined in -`Buffer` and `MutableBuffer` companion classes. For example factory `Buffer.Companion::auto` in most cases creates the most suitable -buffer for given reified type (for types with custom memory buffer it still better to use their own `MemoryBuffer.create()` factory). +Some kmath features require a `BufferFactory` class to operate properly. A general convention is to use functions +defined in +`Buffer` and `MutableBuffer` companion classes. For example factory `Buffer.Companion::auto` in most cases creates the +most suitable buffer for given reified type (for types with custom memory buffer it still better to use their +own `MemoryBuffer.create()` factory). ## Buffer performance -One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers instead +One should avoid using default boxing buffer wherever it is possible. Try to use primitive buffers or memory buffers +instead . diff --git a/docs/codestyle.md b/docs/codestyle.md index 541dc4973..73ba5f754 100644 --- a/docs/codestyle.md +++ b/docs/codestyle.md @@ -1,26 +1,20 @@ # Coding Conventions -KMath code follows general [Kotlin conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but -with a number of small changes and clarifications. +Generally, KMath code follows general [Kotlin coding conventions](https://kotlinlang.org/docs/reference/coding-conventions.html), but with a number of small changes and clarifications. ## Utility Class Naming -Filename should coincide with a name of one of the classes contained in the file or start with small letter and -describe its contents. +Filename should coincide with a name of one of the classes contained in the file or start with small letter and describe its contents. -The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that -file names should start with a capital letter even if file does not contain classes. Yet starting utility classes and -aggregators with a small letter seems to be a good way to visually separate those files. +The code convention [here](https://kotlinlang.org/docs/reference/coding-conventions.html#source-file-names) says that file names should start with a capital letter even if file does not contain classes. Yet starting utility classes and aggregators with a small letter seems to be a good way to visually separate those files. This convention could be changed in future in a non-breaking way. ## Private Variable Naming -Private variables' names may start with underscore `_` for of the private mutable variable is shadowed by the public -read-only value with the same meaning. +Private variables' names may start with underscore `_` for of the private mutable variable is shadowed by the public read-only value with the same meaning. -This rule does not permit underscores in names, but it is sometimes useful to "underscore" the fact that public and -private versions draw up the same entity. It is allowed only for private variables. +This rule does not permit underscores in names, but it is sometimes useful to "underscore" the fact that public and private versions draw up the same entity. It is allowed only for private variables. This convention could be changed in future in a non-breaking way. @@ -30,5 +24,4 @@ Use one-liners when they occupy single code window line both for functions and p `val b: String get() = "fff"`. The same should be performed with multiline expressions when they could be cleanly separated. -There is no universal consensus whenever use `fun a() = ...` or `fun a() { return ... }`. Yet from reader outlook -one-lines seem to better show that the property or function is easily calculated. +There is no universal consensus whenever use `fun a() = ...` or `fun a() { return ... }`. Yet from reader outlook one-lines seem to better show that the property or function is easily calculated. diff --git a/docs/contexts.md b/docs/contexts.md index 58b198046..c26333860 100644 --- a/docs/contexts.md +++ b/docs/contexts.md @@ -2,18 +2,17 @@ ## The problem -A known problem for implementing mathematics in statically-typed languages (but not only in them) is that different -sets of mathematical operators can be defined on the same mathematical objects. Sometimes there is no single way to -treat some operations, including basic arithmetic operations, on a Java/Kotlin `Number`. Sometimes there are different ways to -define the same structure, such as Euclidean and elliptic geometry vector spaces over real vectors. Another problem arises when -one wants to add some kind of behavior to an existing entity. In dynamic languages those problems are usually solved -by adding dynamic context-specific behaviors at runtime, but this solution has a lot of drawbacks. +A known problem for implementing mathematics in statically-typed languages (but not only in them) is that different sets +of mathematical operators can be defined on the same mathematical objects. Sometimes there is no single way to treat +some operations, including basic arithmetic operations, on a Java/Kotlin `Number`. Sometimes there are different ways to +define the same structure, such as Euclidean and elliptic geometry vector spaces over real vectors. Another problem +arises when one wants to add some kind of behavior to an existing entity. In dynamic languages those problems are +usually solved by adding dynamic context-specific behaviors at runtime, but this solution has a lot of drawbacks. ## Context-oriented approach -One possible solution to these problems is to divorce numerical representations from behaviors. -For example in Kotlin one can define a separate class which represents some entity without any operations, -ex. a complex number: +One possible solution to these problems is to divorce numerical representations from behaviors. For example in Kotlin +one can define a separate class representing some entity without any operations, ex. a complex number: ```kotlin data class Complex(val re: Double, val im: Double) @@ -28,9 +27,10 @@ object ComplexOperations { } ``` -In Java, applying such external operations could be very cumbersome, but Kotlin has a unique feature which allows us -implement this naturally: [extensions with receivers](https://kotlinlang.org/docs/reference/extensions.html#extension-functions). -In Kotlin, an operation on complex number could be implemented as: +In Java, applying such external operations could be cumbersome, but Kotlin has a unique feature that allows us +implement this +naturally: [extensions with receivers](https://kotlinlang.org/docs/reference/extensions.html#extension-functions). In +Kotlin, an operation on complex number could be implemented as: ```kotlin with(ComplexOperations) { c1 + c2 - c3 } @@ -52,20 +52,20 @@ In KMath, contexts are not only responsible for operations, but also for raw obj ### Type classes -An obvious candidate to get more or less the same functionality is the type class, which allows one to bind a behavior to -a specific type without modifying the type itself. On the plus side, type classes do not require explicit context +An obvious candidate to get more or less the same functionality is the type class, which allows one to bind a behavior +to a specific type without modifying the type itself. On the plus side, type classes do not require explicit context declaration, so the code looks cleaner. On the minus side, if there are different sets of behaviors for the same types, -it is impossible to combine them into one module. Also, unlike type classes, context can have parameters or even -state. For example in KMath, sizes and strides for `NDElement` or `Matrix` could be moved to context to optimize -performance in case of a large amount of structures. +it is impossible to combine them into one module. Also, unlike type classes, context can have parameters or even state. +For example in KMath, sizes and strides for `NDElement` or `Matrix` could be moved to context to optimize performance in +case of a large amount of structures. ### Wildcard imports and importing-on-demand -Sometimes, one may wish to use a single context throughout a file. In this case, is possible to import all members -from a package or file, via `import context.complex.*`. Effectively, this is the same as enclosing an entire file -with a single context. However when using multiple contexts, this technique can introduce operator ambiguity, due to -namespace pollution. If there are multiple scoped contexts which define the same operation, it is still possible to -to import specific operations as needed, without using an explicit context with extension functions, for example: +Sometimes, one may wish to use a single context throughout a file. In this case, is possible to import all members from +a package or file, via `import context.complex.*`. Effectively, this is the same as enclosing an entire file with a +single context. However, when using multiple contexts, this technique can introduce operator ambiguity, due to namespace +pollution. If there are multiple scoped contexts that define the same operation, it is still possible to import +specific operations as needed, without using an explicit context with extension functions, for example: ``` import context.complex.op1 diff --git a/docs/diagrams/core.puml b/docs/diagrams/core.puml new file mode 100644 index 000000000..87f8f2e2d --- /dev/null +++ b/docs/diagrams/core.puml @@ -0,0 +1,1020 @@ +@startuml +interface "ColumnarData" { + size: Int +} +interface "XYColumnarData" { + x: Buffer + y: Buffer +} +interface "XYErrorColumnarData" { + yErr: Buffer +} +interface "XYZColumnarData" { + z: Buffer +} +interface "Domain" { + dimension: Int +} +interface "DoubleDomain" { + +} +class "HyperSquareDomain" { + lower: Buffer + upper: Buffer +} +class "UnconstrainedDomain" { + dimension: Int +} +class "UnivariateDomain" { + range: ClosedFloatingPointRange +} +interface "DifferentiableExpression" { + +} +interface "SpecialDifferentiableExpression" { + +} +abstract "FirstDerivativeExpression" { + +} +interface "AutoDiffProcessor" { + +} +interface "Expression" { + +} +interface "ExpressionAlgebra" { + +} +abstract "FunctionalExpressionAlgebra" { + algebra: A +} +class "FunctionalExpressionGroup" { + algebra: A +} +class "FunctionalExpressionRing" { + algebra: A +} +class "FunctionalExpressionField" { + algebra: A +} +class "FunctionalExpressionExtendedField" { + algebra: A +} +interface "MST" { + +} +class "Numeric" { + value: Number +} +class "Unary" { + operation: String + value: MST +} +class "Binary" { + operation: String + left: MST + right: MST +} +class "InnerAlgebra" { + algebra: Algebra + arguments: Map +} +class "MstNumericAlgebra" { + number() + bindSymbolOrNull() + bindSymbol() + unaryOperationFunction() + binaryOperationFunction() +} +class "MstGroup" { + zero: MST.Numericnumber() + bindSymbolOrNull() + add() + unaryPlus() + unaryMinus() + minus() + scale() + binaryOperationFunction() + unaryOperationFunction() +} +class "MstRing" { + zero: MST.Numeric + one: MST.Numericnumber() + bindSymbolOrNull() + add() + scale() + multiply() + unaryPlus() + unaryMinus() + minus() + binaryOperationFunction() + unaryOperationFunction() +} +class "MstField" { + zero: MST.Numeric + one: MST.NumericbindSymbolOrNull() + number() + add() + scale() + multiply() + divide() + unaryPlus() + unaryMinus() + minus() + binaryOperationFunction() + unaryOperationFunction() +} +class "MstExtendedField" { + zero: MST.Numeric + one: MST.NumericbindSymbolOrNull() + number() + sin() + cos() + tan() + asin() + acos() + atan() + sinh() + cosh() + tanh() + asinh() + acosh() + atanh() + add() + sqrt() + scale() + multiply() + divide() + unaryPlus() + unaryMinus() + minus() + power() + exp() + ln() + binaryOperationFunction() + unaryOperationFunction() +} +class "MstLogicAlgebra" { + bindSymbolOrNull() + const() + not() + and() + or() + xor() +} +class "AutoDiffValue" { + value: T +} +class "DerivationResult" { + value: T + derivativeValues: Map + context: Field +} +class "SimpleAutoDiffField" { + context: F + bindings: Map +} +class "AutoDiffVariableWithDerivative" { + identity: String + value: T + d: T +} +class "SimpleAutoDiffExpression" { + field: F + function: SimpleAutoDiffField +} +class "SimpleAutoDiffExtendedField" { + context: F + bindings: Map +} +interface "Symbol" { + identity: String +} +class "StringSymbol" { + identity: String +} +interface "SymbolIndexer" { + symbols: List +} +class "SimpleSymbolIndexer" { + symbols: List +} +class "BufferedLinearSpace" { + elementAlgebra: A + bufferFactory: BufferFactory +} +interface "LinearSolver" { + +} +interface "LinearSpace" { + elementAlgebra: A +} +class "LupDecomposition" { + context: LinearSpace + elementContext: Field + lu: Matrix + pivot: IntArray + even: Boolean +} +class "MatrixBuilder" { + linearSpace: LinearSpace + rows: Int + columns: Int +} +class "SymmetricMatrixFeature" { + +} +interface "MatrixFeature" { + +} +interface "DiagonalFeature" { + +} +class "ZeroFeature" { + +} +class "UnitFeature" { + +} +interface "InverseMatrixFeature" { + inverse: Matrix +} +interface "DeterminantFeature" { + determinant: T +} +class "LFeature" { + +} +class "UFeature" { + +} +interface "LUDecompositionFeature" { + l: Matrix + u: Matrix +} +interface "LupDecompositionFeature" { + l: Matrix + u: Matrix + p: Matrix +} +class "OrthogonalFeature" { + +} +interface "QRDecompositionFeature" { + q: Matrix + r: Matrix +} +interface "CholeskyDecompositionFeature" { + l: Matrix +} +interface "SingularValueDecompositionFeature" { + u: Matrix + s: Matrix + v: Matrix + singularValues: Point +} +class "MatrixWrapper" { + origin: Matrix + features: FeatureSet +} +class "TransposedFeature" { + original: Matrix +} +class "VirtualMatrix" { + rowNum: Int + colNum: Int + generator: (i:Int,j:Int)->T +} +class "UnstableKMathAPI" { + +} +class "PerformancePitfall" { + message: String +} +interface "Featured" { + +} +interface "Feature" { + key: FeatureKey +} +class "FeatureSet" { + features: Map +} +interface "Loggable" { + +} +class "ShapeMismatchException" { + expected: IntArray + actual: IntArray +} +interface "AlgebraND" { + shape: IntArray + elementContext: C +} +interface "GroupND" { + +} +interface "RingND" { + +} +interface "FieldND" { + +} +interface "BufferAlgebraND" { + strides: Strides + bufferFactory: BufferFactory + buffer: Buffer +} +class "BufferedGroupND" { + shape: IntArray + elementContext: A + bufferFactory: BufferFactory +} +class "BufferedRingND" { + shape: IntArray + elementContext: R + bufferFactory: BufferFactory +} +class "BufferedFieldND" { + shape: IntArray + elementContext: R + bufferFactory: BufferFactory +} +class "BufferND" { + strides: Strides + buffer: Buffer +} +class "MutableBufferND" { + strides: Strides + mutableBuffer: MutableBuffer +} +class "DoubleFieldND" { + shape: IntArray +} +class "ShortRingND" { + shape: IntArray +} +interface "Structure1D" { + dimension: Int +} +interface "MutableStructure1D" { + +} +class "Structure1DWrapper" { + structure: StructureND +} +class "MutableStructure1DWrapper" { + structure: MutableStructureND +} +class "Buffer1DWrapper" { + buffer: Buffer +} +class "MutableBuffer1DWrapper" { + buffer: MutableBuffer +} +interface "Structure2D" { + rowNum: Int + colNum: Int + shape: IntArray + rows: List + columns: List +} +interface "MutableStructure2D" { + rows: List + columns: List +} +class "Structure2DWrapper" { + structure: StructureND +} +class "MutableStructure2DWrapper" { + structure: MutableStructureND +} +interface "StructureFeature" { + +} +interface "StructureND" { + shape: IntArray + dimension: Int +} +interface "MutableStructureND" { + +} +interface "Strides" { + shape: IntArray + strides: IntArray + linearSize: Int +} +class "DefaultStrides" { + shape: IntArray +} +class "KMathContext" { + +} +interface "Algebra" { + +} +interface "GroupOperations" { + +} +interface "Group" { + zero: T +} +interface "RingOperations" { + +} +interface "Ring" { + one: T +} +interface "FieldOperations" { + +} +interface "Field" { + +} +interface "AlgebraElement" { + context: C +} +interface "GroupElement" { + +} +interface "RingElement" { + +} +interface "FieldElement" { + +} +class "BigIntField" { + zero: BigInt + one: BigIntnumber() + unaryMinus() + add() + scale() + multiply() + divide() + unaryPlus() + unaryMinus() +} +class "BigInt" { + sign: Byte + magnitude: Magnitude +} +interface "BufferAlgebra" { + bufferFactory: BufferFactory + elementAlgebra: A +} +class "BufferField" { + bufferFactory: BufferFactory + elementAlgebra: A + size: Int +} +interface "LogicAlgebra" { + +} +class "BooleanAlgebra" { + const() + not() + and() + or() + xor() +} +interface "ExtendedFieldOperations" { + +} +interface "ExtendedField" { + +} +class "DoubleField" { + zero: Double + one: Doublenumber() + binaryOperationFunction() + add() + multiply() + divide() + scale() + sin() + cos() + tan() + acos() + asin() + atan() + sinh() + cosh() + tanh() + asinh() + acosh() + atanh() + sqrt() + power() + exp() + ln() + norm() + unaryMinus() + plus() + minus() + times() + div() +} +class "FloatField" { + zero: Float + one: Floatnumber() + binaryOperationFunction() + add() + scale() + multiply() + divide() + sin() + cos() + tan() + acos() + asin() + atan() + sinh() + cosh() + tanh() + asinh() + acosh() + atanh() + sqrt() + power() + exp() + ln() + norm() + unaryMinus() + plus() + minus() + times() + div() +} +class "IntRing" { + zero: Int + one: Intnumber() + add() + multiply() + norm() + unaryMinus() + plus() + minus() + times() +} +class "ShortRing" { + zero: Short + one: Shortnumber() + add() + multiply() + norm() + unaryMinus() + plus() + minus() + times() +} +class "ByteRing" { + zero: Byte + one: Bytenumber() + add() + multiply() + norm() + unaryMinus() + plus() + minus() + times() +} +class "LongRing" { + zero: Long + one: Longnumber() + add() + multiply() + norm() + unaryMinus() + plus() + minus() + times() +} +interface "NumericAlgebra" { + +} +interface "ScaleOperations" { + +} +interface "NumbersAddOperations" { + +} +interface "TrigonometricOperations" { + +} +interface "PowerOperations" { + +} +interface "ExponentialOperations" { + +} +interface "Norm" { + +} +interface "Buffer" { + size: Int +} +interface "MutableBuffer" { + +} +class "ListBuffer" { + list: List +} +class "MutableListBuffer" { + list: MutableList +} +class "ArrayBuffer" { + array: Array +} +class "ReadOnlyBuffer" { + buffer: MutableBuffer +} +class "VirtualBuffer" { + size: Int + generator: (Int)->T +} +class "BufferAccessor2D" { + rowNum: Int + colNum: Int + factory: MutableBufferFactory +} +class "Row" { + buffer: MutableBuffer + rowIndex: Int +} +class "DoubleBuffer" { + array: DoubleArray +} +class "DoubleBufferFieldOperations" { + unaryMinus() + add() + multiply() + divide() + sin() + cos() + tan() + asin() + acos() + atan() + sinh() + cosh() + tanh() + asinh() + acosh() + atanh() + power() + exp() + ln() +} +class "DoubleL2Norm" { + norm() +} +class "DoubleBufferField" { + size: Int +} +enum "ValueFlag" { + NAN + MISSING + NEGATIVE_INFINITY + POSITIVE_INFINITY +} +interface "FlaggedBuffer" { + +} +class "FlaggedDoubleBuffer" { + values: DoubleArray + flags: ByteArray +} +class "FloatBuffer" { + array: FloatArray +} +class "IntBuffer" { + array: IntArray +} +class "LongBuffer" { + array: LongArray +} +class "MemoryBuffer" { + memory: Memory + spec: MemorySpec +} +class "MutableMemoryBuffer" { + memory: Memory + spec: MemorySpec +} +class "ShortBuffer" { + array: ShortArray +} +class "ExpressionFieldTest" { + x +} +class "InterpretTest" { + +} +class "SimpleAutoDiffTest" { + x + y + z +} +class "DoubleLUSolverTest" { + +} +class "MatrixTest" { + +} +class "CumulativeKtTest" { + +} +class "BigIntAlgebraTest" { + +} +class "BigIntConstructorTest" { + +} +class "BigIntConversionsTest" { + +} +class "BigIntOperationsTest" { + +} +class "DoubleFieldTest" { + +} +class "NDFieldTest" { + +} +class "NumberNDFieldTest" { + algebra + array1 + array2 +} +class "L2Norm" { + norm() +} +interface "AlgebraicVerifier" { + algebra: A +} +class "FieldVerifier" { + algebra: A + a: T + b: T + c: T + x: Number +} +class "RingVerifier" { + algebra: A + a: T + b: T + c: T + x: Number +} +class "SpaceVerifier" { + algebra: S + a: T + b: T + c: T + x: Number +} +class "JBigIntegerField" { + zero: BigInteger + one: BigIntegernumber() + add() + minus() + multiply() + unaryMinus() +} +abstract "JBigDecimalFieldBase" { + mathContext: MathContext +} +class "JBigDecimalField" { + mathContext: MathContext +} +"ColumnarData" <|--- XYColumnarData +"XYColumnarData" <|--- XYErrorColumnarData +"XYColumnarData" <|--- XYZColumnarData +"Domain" <|--- DoubleDomain +"DoubleDomain" <|--- HyperSquareDomain +"DoubleDomain" <|--- UnconstrainedDomain +"DoubleDomain" <|--- UnivariateDomain +"Expression" <|--- DifferentiableExpression +"DifferentiableExpression" <|--- SpecialDifferentiableExpression +"DifferentiableExpression" <|--- FirstDerivativeExpression +"Algebra" <|--- ExpressionAlgebra +"ExpressionAlgebra" <|--- FunctionalExpressionAlgebra +"FunctionalExpressionAlgebra" <|--- FunctionalExpressionGroup +"Group" <|--- FunctionalExpressionGroup +"FunctionalExpressionGroup" <|--- FunctionalExpressionRing +"Ring" <|--- FunctionalExpressionRing +"FunctionalExpressionRing" <|--- FunctionalExpressionField +"Field" <|--- FunctionalExpressionField +"ScaleOperations" <|--- FunctionalExpressionField +"FunctionalExpressionField" <|--- FunctionalExpressionExtendedField +"ExtendedField" <|--- FunctionalExpressionExtendedField +"MST" <|--- Numeric +"MST" <|--- Unary +"MST" <|--- Binary +"NumericAlgebra" <|--- InnerAlgebra +"NumericAlgebra" <|--- MstNumericAlgebra +"Group" <|--- MstGroup +"NumericAlgebra" <|--- MstGroup +"ScaleOperations" <|--- MstGroup +"Ring" <|--- MstRing +"NumbersAddOperations" <|--- MstRing +"ScaleOperations" <|--- MstRing +"Field" <|--- MstField +"NumbersAddOperations" <|--- MstField +"ScaleOperations" <|--- MstField +"ExtendedField" <|--- MstExtendedField +"NumericAlgebra" <|--- MstExtendedField +"LogicAlgebra" <|--- MstLogicAlgebra +"Field" <|--- SimpleAutoDiffField +"ExpressionAlgebra" <|--- SimpleAutoDiffField +"NumbersAddOperations" <|--- SimpleAutoDiffField +"AutoDiffValue" <|--- AutoDiffVariableWithDerivative +"Symbol" <|--- AutoDiffVariableWithDerivative +"FirstDerivativeExpression" <|--- SimpleAutoDiffExpression +"ExtendedField" <|--- SimpleAutoDiffExtendedField +"ScaleOperations" <|--- SimpleAutoDiffExtendedField +'"" <|--- SimpleAutoDiffExtendedField +"SimpleAutoDiffField" <|--- SimpleAutoDiffExtendedField +"MST" <|--- Symbol +"Symbol" <|--- StringSymbol +"SymbolIndexer" <|--- SimpleSymbolIndexer +"LinearSpace" <|--- BufferedLinearSpace +"LupDecompositionFeature" <|--- LupDecomposition +"DeterminantFeature" <|--- LupDecomposition +"MatrixFeature" <|--- SymmetricMatrixFeature +"StructureFeature" <|--- MatrixFeature +"MatrixFeature" <|--- DiagonalFeature +"DiagonalFeature" <|--- ZeroFeature +"DiagonalFeature" <|--- UnitFeature +"MatrixFeature" <|--- InverseMatrixFeature +"MatrixFeature" <|--- DeterminantFeature +"MatrixFeature" <|--- LFeature +"MatrixFeature" <|--- UFeature +"MatrixFeature" <|--- LUDecompositionFeature +"MatrixFeature" <|--- LupDecompositionFeature +"MatrixFeature" <|--- OrthogonalFeature +"MatrixFeature" <|--- QRDecompositionFeature +"MatrixFeature" <|--- CholeskyDecompositionFeature +"MatrixFeature" <|--- SingularValueDecompositionFeature +'"Matrixbyorigin{ +' +' +' @UnstableKMathAPI +' @Suppress +'overridefungetFeature:F? = +'features.getFeature +' +'overridefuntoString" +'}" <|--- MatrixWrapper +"MatrixFeature" <|--- TransposedFeature +"Matrix" <|--- VirtualMatrix +"Featured" <|--- FeatureSet +"RuntimeException" <|--- ShapeMismatchException +"Group" <|--- GroupND +"AlgebraND" <|--- GroupND +"Ring" <|--- RingND +"GroupND" <|--- RingND +"Field" <|--- FieldND +"RingND" <|--- FieldND +"AlgebraND" <|--- BufferAlgebraND +"GroupND" <|--- BufferedGroupND +"BufferAlgebraND" <|--- BufferedGroupND +"BufferedGroupND" <|--- BufferedRingND +"RingND" <|--- BufferedRingND +"BufferedRingND" <|--- BufferedFieldND +"FieldND" <|--- BufferedFieldND +"StructureND" <|--- BufferND +"MutableStructureND" <|--- MutableBufferND +"BufferND" <|--- MutableBufferND +"BufferedFieldND" <|--- DoubleFieldND +'" +'" <|--- DoubleFieldND +'"NumbersAddOperations" <|--- DoubleFieldND +'" +'" <|--- DoubleFieldND +'"ScaleOperations" <|--- DoubleFieldND +'" +'" <|--- DoubleFieldND +"ExtendedField" <|--- DoubleFieldND +"BufferedRingND" <|--- ShortRingND +'" +'" <|--- ShortRingND +"NumbersAddOperations" <|--- ShortRingND +"StructureND" <|--- Structure1D +"Buffer" <|--- Structure1D +"Structure1D" <|--- MutableStructure1D +"MutableStructureND" <|--- MutableStructure1D +"MutableBuffer" <|--- MutableStructure1D +"Structure1D" <|--- Structure1DWrapper +"MutableStructure1D" <|--- MutableStructure1DWrapper +"Structure1D" <|--- Buffer1DWrapper +"MutableStructure1D" <|--- MutableBuffer1DWrapper +"StructureND" <|--- Structure2D +"Structure2D" <|--- MutableStructure2D +"MutableStructureND" <|--- MutableStructure2D +"Structure2D" <|--- Structure2DWrapper +"MutableStructure2D" <|--- MutableStructure2DWrapper +"Feature" <|--- StructureFeature +"Featured" <|--- StructureND +"StructureND" <|--- MutableStructureND +"Strides" <|--- DefaultStrides +"Algebra" <|--- GroupOperations +"GroupOperations" <|--- Group +"GroupOperations" <|--- RingOperations +"Group" <|--- Ring +"RingOperations" <|--- Ring +"RingOperations" <|--- FieldOperations +"Ring" <|--- Field +"FieldOperations" <|--- Field +"ScaleOperations" <|--- Field +"NumericAlgebra" <|--- Field +"AlgebraElement" <|--- GroupElement +"GroupElement" <|--- RingElement +"RingElement" <|--- FieldElement +"Field" <|--- BigIntField +"NumbersAddOperations" <|--- BigIntField +"ScaleOperations" <|--- BigIntField +"Comparable" <|--- BigInt +"Algebra" <|--- BufferAlgebra +"BufferAlgebra" <|--- BufferField +"Field" <|--- BufferField +"Algebra" <|--- LogicAlgebra +"LogicAlgebra" <|--- BooleanAlgebra +"FieldOperations" <|--- ExtendedFieldOperations +'" +'" <|--- ExtendedFieldOperations +'"TrigonometricOperations" <|--- ExtendedFieldOperations +'" +'" <|--- ExtendedFieldOperations +'"PowerOperations" <|--- ExtendedFieldOperations +'" +'" <|--- ExtendedFieldOperations +"ExponentialOperations" <|--- ExtendedFieldOperations +"ExtendedFieldOperations" <|--- ExtendedField +"Field" <|--- ExtendedField +"NumericAlgebra" <|--- ExtendedField +"ScaleOperations" <|--- ExtendedField +"ExtendedField" <|--- DoubleField +"Norm" <|--- DoubleField +"ScaleOperations" <|--- DoubleField +"ExtendedField" <|--- FloatField +"Norm" <|--- FloatField +"Ring" <|--- IntRing +"Norm" <|--- IntRing +"NumericAlgebra" <|--- IntRing +"Ring" <|--- ShortRing +"Norm" <|--- ShortRing +"NumericAlgebra" <|--- ShortRing +"Ring" <|--- ByteRing +"Norm" <|--- ByteRing +"NumericAlgebra" <|--- ByteRing +"Ring" <|--- LongRing +"Norm" <|--- LongRing +"NumericAlgebra" <|--- LongRing +"Algebra" <|--- NumericAlgebra +"Algebra" <|--- ScaleOperations +"Ring" <|--- NumbersAddOperations +"NumericAlgebra" <|--- NumbersAddOperations +"Algebra" <|--- TrigonometricOperations +"Algebra" <|--- PowerOperations +"Algebra" <|--- ExponentialOperations +"Buffer" <|--- MutableBuffer +"Buffer" <|--- ListBuffer +"MutableBuffer" <|--- MutableListBuffer +"MutableBuffer" <|--- ArrayBuffer +"Buffer" <|--- ReadOnlyBuffer +"Buffer" <|--- VirtualBuffer +"MutableBuffer" <|--- Row +"MutableBuffer" <|--- DoubleBuffer +"ExtendedFieldOperations" <|--- DoubleBufferFieldOperations +"Norm" <|--- DoubleL2Norm +"ExtendedField" <|--- DoubleBufferField +"Norm" <|--- DoubleBufferField +"Buffer" <|--- FlaggedBuffer +"FlaggedBuffer" <|--- FlaggedDoubleBuffer +'" +'" <|--- FlaggedDoubleBuffer +"Buffer" <|--- FlaggedDoubleBuffer +"MutableBuffer" <|--- FloatBuffer +"MutableBuffer" <|--- IntBuffer +"MutableBuffer" <|--- LongBuffer +"Buffer" <|--- MemoryBuffer +"MemoryBuffer" <|--- MutableMemoryBuffer +'" +'" <|--- MutableMemoryBuffer +"MutableBuffer" <|--- MutableMemoryBuffer +"MutableBuffer" <|--- ShortBuffer +"Norm" <|--- L2Norm +"RingVerifier" <|--- FieldVerifier +"SpaceVerifier" <|--- RingVerifier +"AlgebraicVerifier" <|--- SpaceVerifier +"Ring" <|--- JBigIntegerField +"NumericAlgebra" <|--- JBigIntegerField +"Field" <|--- JBigDecimalFieldBase +"PowerOperations" <|--- JBigDecimalFieldBase +"NumericAlgebra" <|--- JBigDecimalFieldBase +"ScaleOperations" <|--- JBigDecimalFieldBase +"JBigDecimalFieldBase" <|--- JBigDecimalField +@enduml \ No newline at end of file diff --git a/docs/expressions.md b/docs/expressions.md index 1e05e5340..e6250110c 100644 --- a/docs/expressions.md +++ b/docs/expressions.md @@ -1,26 +1,21 @@ # Expressions -**Experimental: this API is in early stage and could change any time** - -Expressions is an experimental feature which allows to construct lazily or immediately calculated parametric mathematical -expressions. +Expressions is a feature, which allows constructing lazily or immediately calculated parametric mathematical expressions. The potential use-cases for it (so far) are following: -* Lazy evaluation (in general simple lambda is better, but there are some border cases) +* lazy evaluation (in general simple lambda is better, but there are some border cases); +* automatic differentiation in single-dimension and in multiple dimensions; +* generation of mathematical syntax trees with subsequent code generation for other languages; +* symbolic computations, especially differentiation (and some other actions with `kmath-symja` integration with Symja's `IExpr`—integration, simplification, and more); +* visualization with `kmath-jupyter`. -* Automatic differentiation in single-dimension and in multiple dimensions - -* Generation of mathematical syntax trees with subsequent code generation for other languages - -* Maybe symbolic computations (needs additional research) - -The workhorse of this API is `Expression` interface which exposes single `operator fun invoke(arguments: Map): T` -method. `ExpressionContext` is used to generate expressions and introduce variables. +The workhorse of this API is `Expression` interface, which exposes single `operator fun invoke(arguments: Map): T` +method. `ExpressionAlgebra` is used to generate expressions and introduce variables. Currently there are two implementations: * Generic `ExpressionField` in `kmath-core` which allows construction of custom lazy expressions -* Auto-differentiation expression in `kmath-commons` module allows to use full power of `DerivativeStructure` +* Auto-differentiation expression in `kmath-commons` module allows using full power of `DerivativeStructure` from commons-math. **TODO: add example** diff --git a/docs/features.md b/docs/features.md deleted file mode 100644 index 1068a4417..000000000 --- a/docs/features.md +++ /dev/null @@ -1,14 +0,0 @@ -# Features - -* [Algebra](algebra.md) - [Context-based](contexts.md) operations on different primitives and structures. - -* [NDStructures](nd-structure.md) - -* [Linear algebra](linear.md) - Matrices, operations and linear equations solving. To be moved to separate module. Currently supports basic -api and multiple library back-ends. - -* [Histograms](histograms.md) - Multidimensional histogram calculation and operations. - -* [Expressions](expressions.md) - -* Commons math integration diff --git a/docs/images/KM.svg b/docs/images/KM.svg index 83af21f35..f5ec452c7 100644 --- a/docs/images/KM.svg +++ b/docs/images/KM.svg @@ -1,4 +1,9 @@ + + + + + + + + i.toDouble() } + val mat = buildMatrix(10, 10) { i, j -> i.toDouble() + j } -## Vector spaces + // Addition + vec + vec + mat + mat + // Multiplication by scalar + vec * 2.0 + mat * 2.0 -## Matrix operations + // Dot product + mat dot vec + mat dot mat +} +``` -## Back-end overview \ No newline at end of file +## Backends overview + +### EJML +### Commons Math diff --git a/docs/nd-structure.md b/docs/nd-structure.md index ec9b4d521..3e9203ec0 100644 --- a/docs/nd-structure.md +++ b/docs/nd-structure.md @@ -11,16 +11,16 @@ Let us consider following contexts: ```kotlin // automatically build context most suited for given type. val autoField = NDField.auto(DoubleField, dim, dim) - // specialized nd-field for Double. It works as generic Double field as well + // specialized nd-field for Double. It works as generic Double field as well. val specializedField = NDField.real(dim, dim) //A generic boxing field. It should be used for objects, not primitives. val genericField = NDField.buffered(DoubleField, dim, dim) ``` -Now let us perform several tests and see which implementation is best suited for each case: +Now let us perform several tests and see, which implementation is best suited for each case: ## Test case -In order to test performance we will take 2d-structures with `dim = 1000` and add a structure filled with `1.0` +To test performance we will take 2d-structures with `dim = 1000` and add a structure filled with `1.0` to it `n = 1000` times. ## Specialized @@ -35,8 +35,8 @@ The code to run this looks like: ``` The performance of this code is the best of all tests since it inlines all operations and is specialized for operation with doubles. We will measure everything else relative to this one, so time for this test will be `1x` (real time -on my computer is about 4.5 seconds). The only problem with this approach is that it requires to specify type -from the beginning. Everyone do so anyway, so it is the recommended approach. +on my computer is about 4.5 seconds). The only problem with this approach is that it requires specifying type +from the beginning. Everyone does so anyway, so it is the recommended approach. ## Automatic Let's do the same with automatic field inference: @@ -49,7 +49,7 @@ Let's do the same with automatic field inference: } ``` Ths speed of this operation is approximately the same as for specialized case since `NDField.auto` just -returns the same `RealNDField` in this case. Of course it is usually better to use specialized method to be sure. +returns the same `RealNDField` in this case. Of course, it is usually better to use specialized method to be sure. ## Lazy Lazy field does not produce a structure when asked, instead it generates an empty structure and fills it on-demand @@ -63,7 +63,7 @@ When one calls } } ``` -The result will be calculated almost immediately but the result will be empty. In order to get the full result +The result will be calculated almost immediately but the result will be empty. To get the full result structure one needs to call all its elements. In this case computation overhead will be huge. So this field never should be used if one expects to use the full result structure. Though if one wants only small fraction, it could save a lot of time. @@ -94,7 +94,7 @@ The boxing field produced by } } ``` -obviously is the slowest one, because it requires to box and unbox the `double` on each operation. It takes about +is the slowest one, because it requires boxing and unboxing the `double` on each operation. It takes about `15x` time (**TODO: there seems to be a problem here, it should be slow, but not that slow**). This field should never be used for primitives. @@ -115,12 +115,14 @@ via extension function. Usually it is bad idea to compare the direct numerical operation performance in different languages, but it hard to work completely without frame of reference. In this case, simple numpy code: ```python +import numpy as np + res = np.ones((1000,1000)) for i in range(1000): res = res + 1.0 ``` gives the completion time of about `1.1x`, which means that specialized kotlin code in fact is working faster (I think it is because better memory management). Of course if one writes `res += 1.0`, the performance will be different, -but it would be differenc case, because numpy overrides `+=` with in-place operations. In-place operations are +but it would be different case, because numpy overrides `+=` with in-place operations. In-place operations are available in `kmath` with `MutableNDStructure` but there is no field for it (one can still work with mapping functions). \ No newline at end of file diff --git a/docs/readme.md b/docs/readme.md new file mode 100644 index 000000000..2953b7113 --- /dev/null +++ b/docs/readme.md @@ -0,0 +1,14 @@ +# Documentation + +* [Algebra](algebra.md): [context-based](contexts.md) operations on different primitives and structures. + +* [NDStructures](nd-structure.md) + +* [Linear algebra](linear.md): matrices, operations and linear equations solving. To be moved to separate module. + Currently, supports basic API and multiple library back-ends. + +* [Histograms](histograms.md): multidimensional histogram calculation and operations. + +* [Expressions](expressions.md) + +* Commons math integration diff --git a/docs/templates/README-TEMPLATE.md b/docs/templates/README-TEMPLATE.md index bad11a31a..e75d4c5ed 100644 --- a/docs/templates/README-TEMPLATE.md +++ b/docs/templates/README-TEMPLATE.md @@ -6,10 +6,10 @@ # KMath -Could be pronounced as `key-math`. The **K**otlin **Math**ematics library was initially intended as a Kotlin-based analog to -Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior architecture -designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like experience could -be achieved with [kmath-for-real](/kmath-for-real) extension module. +Could be pronounced as `key-math`. The **K**otlin **Math**ematics library was initially intended as a Kotlin-based +analog to Python's NumPy library. Later we found that kotlin is much more flexible language and allows superior +architecture designs. In contrast to `numpy` and `scipy` it is modular and has a lightweight core. The `numpy`-like +experience could be achieved with [kmath-for-real](/kmath-for-real) extension module. [Documentation site (**WIP**)](https://mipt-npm.github.io/kmath/) @@ -21,26 +21,33 @@ be achieved with [kmath-for-real](/kmath-for-real) extension module. # Goal -* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native). +* Provide a flexible and powerful API to work with mathematics abstractions in Kotlin-multiplatform (JVM, JS and Native) + . * Provide basic multiplatform implementations for those abstractions (without significant performance optimization). * Provide bindings and wrappers with those abstractions for popular optimized platform libraries. ## Non-goals -* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in terms of API. +* Be like NumPy. It was the idea at the beginning, but we decided that we can do better in API. * Provide the best performance out of the box. We have specialized libraries for that. Need only API wrappers for them. * Cover all cases as immediately and in one bundle. We will modularize everything and add new features gradually. -* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like -for `Double` in the core. For that we will have specialization modules like `kmath-for-real`, which will give better -experience for those, who want to work with specific types. +* Provide specialized behavior in the core. API is made generic on purpose, so one needs to specialize for types, like + for `Double` in the core. For that we will have specialization modules like `kmath-for-real`, which will give better + experience for those, who want to work with specific types. ## Features and stability -KMath is a modular library. Different modules provide different features with different API stability guarantees. All core modules are released with the same version, but with different API change policy. The features are described in module definitions below. The module stability could have following levels: +KMath is a modular library. Different modules provide different features with different API stability guarantees. All +core modules are released with the same version, but with different API change policy. The features are described in +module definitions below. The module stability could have the following levels: -* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could break any moment. You can still use it, but be sure to fix the specific version. -* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked with `@UnstableKmathAPI` or other stability warning annotations. -* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor versions, but not in patch versions. API is protected with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. +* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could + break any moment. You can still use it, but be sure to fix the specific version. +* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked + with `@UnstableKmathAPI` or other stability warning annotations. +* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor + versions, but not in patch versions. API is protected + with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. * **STABLE**. The API stabilized. Breaking changes are allowed only in major releases. @@ -78,30 +85,33 @@ $modules ## Multi-platform support -KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the -[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features -are delegated to platform-specific implementations even if they could be provided in the common module for performance -reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and +KMath is developed as a multi-platform library, which means that most of the interfaces are declared in the +[common source sets](/kmath-core/src/commonMain) and implemented there wherever it is possible. In some cases, features +are delegated to platform-specific implementations even if they could be provided in the common module for performance +reasons. Currently, the Kotlin/JVM is the primary platform, however Kotlin/Native and Kotlin/JS contributions and feedback are also welcome. ## Performance -Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve -both performance and flexibility. +Calculation performance is one of major goals of KMath in the future, but in some cases it is impossible to achieve both +performance and flexibility. -We expect to focus on creating convenient universal API first and then work on increasing performance for specific -cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized -native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be +We expect to focus on creating convenient universal API first and then work on increasing performance for specific +cases. We expect the worst KMath benchmarks will perform better than native Python, but worse than optimized +native/SciPy (mostly due to boxing operations on primitive numbers). The best performance of optimized parts could be better than SciPy. ## Requirements -KMath currently relies on JDK 11 for compilation and execution of Kotlin-JVM part. We recommend to use GraalVM-CE 11 for execution in order to get better performance. +KMath currently relies on JDK 11 for compilation and execution of Kotlin-JVM part. We recommend to use GraalVM-CE 11 for +execution to get better performance. ### Repositories -Release and development artifacts are accessible from mipt-npm [Space](https://www.jetbrains.com/space/) repository `https://maven.pkg.jetbrains.space/mipt-npm/p/sci/maven` (see documentation of -[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details). The repository could be reached through [repo.kotlin.link](https://repo.kotlin.link) proxy: +Release and development artifacts are accessible from mipt-npm [Space](https://www.jetbrains.com/space/) +repository `https://maven.pkg.jetbrains.space/mipt-npm/p/sci/maven` (see documentation of +[Kotlin Multiplatform](https://kotlinlang.org/docs/reference/multiplatform.html) for more details). The repository could +be reached through [repo.kotlin.link](https://repo.kotlin.link) proxy: ```kotlin repositories { @@ -118,7 +128,7 @@ Gradle `6.0+` is required for multiplatform artifacts. ## Contributing -The project requires a lot of additional work. The most important thing we need is a feedback about what features are -required the most. Feel free to create feature requests. We are also welcome to code contributions, -especially in issues marked with +The project requires a lot of additional work. The most important thing we need is a feedback about what features are +required the most. Feel free to create feature requests. We are also welcome to code contributions, especially in issues +marked with [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero) label. diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 406b8f470..7b1bce26a 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -20,6 +20,7 @@ dependencies { implementation(project(":kmath-coroutines")) implementation(project(":kmath-commons")) implementation(project(":kmath-complex")) + implementation(project(":kmath-optimization")) implementation(project(":kmath-stat")) implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) @@ -28,6 +29,11 @@ dependencies { implementation(project(":kmath-tensors")) implementation(project(":kmath-symja")) implementation(project(":kmath-for-real")) + //jafama + implementation(project(":kmath-jafama")) + //multik + implementation(projects.kmathMultik) + implementation("org.nd4j:nd4j-native:1.0.0-beta7") @@ -41,11 +47,12 @@ dependencies { // } else implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") - implementation("org.slf4j:slf4j-simple:1.7.30") + // multik implementation + implementation("org.jetbrains.kotlinx:multik-default:0.1.0") + + implementation("org.slf4j:slf4j-simple:1.7.32") // plotting - implementation("space.kscience:plotlykt-server:0.4.0") - //jafama - implementation(project(":kmath-jafama")) + implementation("space.kscience:plotlykt-server:0.5.0") } kotlin.sourceSets.all { diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt index e16769464..0c16d82d1 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -10,7 +10,7 @@ import space.kscience.kmath.ast.rendering.LatexSyntaxRenderer import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer import space.kscience.kmath.ast.rendering.renderWithStringBuilder -public fun main() { +fun main() { val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath() val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst) println("MathSyntax:") diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt index 96c9856cf..887d76c42 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt index 6ceaa962a..4e3528b3e 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -13,8 +13,8 @@ import space.kscience.kmath.kotlingrad.toKotlingradExpression import space.kscience.kmath.operations.DoubleField /** - * In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with - * valid derivative in a certain point. + * In this example, *x2 − 4 x − 44* function is differentiated with Kotlin∇, and the + * derivation result is compared with valid derivative in a certain point. */ fun main() { val actualDerivative = "x^2-4*x-44" diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/symjaSupport.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/symjaSupport.kt index a9eca0500..209523c89 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/symjaSupport.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/symjaSupport.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -13,8 +13,8 @@ import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.symja.toSymjaExpression /** - * In this example, x^2-4*x-44 function is differentiated with Symja, and the autodiff result is compared with - * valid derivative in a certain point. + * In this example, *x2 − 4 x − 44* function is differentiated with Symja, and the + * derivation result is compared with valid derivative in a certain point. */ fun main() { val actualDerivative = "x^2-4*x-44" diff --git a/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt b/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt similarity index 69% rename from examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt rename to examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt index 5e64235e3..dbe0b8454 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt @@ -1,32 +1,34 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ -package space.kscience.kmath.commons.fit +package space.kscience.kmath.fit import kotlinx.html.br import kotlinx.html.h3 -import space.kscience.kmath.commons.optimization.chiSquared -import space.kscience.kmath.commons.optimization.minimize +import space.kscience.kmath.commons.expressions.DSProcessor +import space.kscience.kmath.commons.optimization.CMOptimizer import space.kscience.kmath.distributions.NormalDistribution +import space.kscience.kmath.expressions.chiSquaredExpression import space.kscience.kmath.expressions.symbol -import space.kscience.kmath.optimization.FunctionOptimization -import space.kscience.kmath.optimization.OptimizationResult +import space.kscience.kmath.operations.asIterable +import space.kscience.kmath.operations.toList +import space.kscience.kmath.optimization.FunctionOptimizationTarget +import space.kscience.kmath.optimization.optimizeWith +import space.kscience.kmath.optimization.resultPoint +import space.kscience.kmath.optimization.resultValue import space.kscience.kmath.real.DoubleVector import space.kscience.kmath.real.map import space.kscience.kmath.real.step import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.structures.asIterable -import space.kscience.kmath.structures.toList import space.kscience.plotly.* import space.kscience.plotly.models.ScatterMode import space.kscience.plotly.models.TraceValues import kotlin.math.pow import kotlin.math.sqrt -//Forward declaration of symbols that will be used in expressions. -// This declaration is required for +// Forward declaration of symbols that will be used in expressions. private val a by symbol private val b by symbol private val c by symbol @@ -43,7 +45,7 @@ operator fun TraceValues.invoke(vector: DoubleVector) { */ suspend fun main() { //A generator for a normally distributed values - val generator = NormalDistribution(2.0, 7.0) + val generator = NormalDistribution(0.0, 1.0) //A chain/flow of random values with the given seed val chain = generator.sample(RandomGenerator.default(112667)) @@ -54,7 +56,7 @@ suspend fun main() { //Perform an operation on each x value (much more effective, than numpy) - val y = x.map { + val y = x.map { it -> val value = it.pow(2) + it + 1 value + chain.next() * sqrt(value) } @@ -65,17 +67,21 @@ suspend fun main() { val yErr = y.map { sqrt(it) }//RealVector.same(x.size, sigma) // compute differentiable chi^2 sum for given model ax^2 + bx + c - val chi2 = FunctionOptimization.chiSquared(x, y, yErr) { x1 -> + val chi2 = DSProcessor.chiSquaredExpression(x, y, yErr) { arg -> //bind variables to autodiff context val a = bindSymbol(a) val b = bindSymbol(b) //Include default value for c if it is not provided as a parameter val c = bindSymbolOrNull(c) ?: one - a * x1.pow(2) + b * x1 + c + a * arg.pow(2) + b * arg + c } //minimize the chi^2 in given starting point. Derivatives are not required, they are already included. - val result: OptimizationResult = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) + val result = chi2.optimizeWith( + CMOptimizer, + mapOf(a to 1.5, b to 0.9, c to 1.0), + FunctionOptimizationTarget.MINIMIZE + ) //display a page with plot and numerical results val page = Plotly.page { @@ -92,7 +98,7 @@ suspend fun main() { scatter { mode = ScatterMode.lines x(x) - y(x.map { result.point[a]!! * it.pow(2) + result.point[b]!! * it + 1 }) + y(x.map { result.resultPoint[a]!! * it.pow(2) + result.resultPoint[b]!! * it + 1 }) name = "fit" } } @@ -101,7 +107,7 @@ suspend fun main() { +"Fit result: $result" } h3 { - +"Chi2/dof = ${result.value / (x.size - 3)}" + +"Chi2/dof = ${result.resultValue / (x.size - 3)}" } } diff --git a/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt b/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt new file mode 100644 index 000000000..d52976671 --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt @@ -0,0 +1,106 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.fit + +import kotlinx.html.br +import kotlinx.html.h3 +import space.kscience.kmath.commons.expressions.DSProcessor +import space.kscience.kmath.data.XYErrorColumnarData +import space.kscience.kmath.distributions.NormalDistribution +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.expressions.binding +import space.kscience.kmath.expressions.symbol +import space.kscience.kmath.operations.asIterable +import space.kscience.kmath.operations.toList +import space.kscience.kmath.optimization.QowOptimizer +import space.kscience.kmath.optimization.chiSquaredOrNull +import space.kscience.kmath.optimization.fitWith +import space.kscience.kmath.optimization.resultPoint +import space.kscience.kmath.real.map +import space.kscience.kmath.real.step +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.plotly.* +import space.kscience.plotly.models.ScatterMode +import kotlin.math.abs +import kotlin.math.pow +import kotlin.math.sqrt + +// Forward declaration of symbols that will be used in expressions. +private val a by symbol +private val b by symbol +private val c by symbol + + +/** + * Least squares fie with auto-differentiation. Uses `kmath-commons` and `kmath-for-real` modules. + */ +suspend fun main() { + //A generator for a normally distributed values + val generator = NormalDistribution(0.0, 1.0) + + //A chain/flow of random values with the given seed + val chain = generator.sample(RandomGenerator.default(112667)) + + + //Create a uniformly distributed x values like numpy.arrange + val x = 1.0..100.0 step 1.0 + + + //Perform an operation on each x value (much more effective, than numpy) + val y = x.map { it -> + val value = it.pow(2) + it + 1 + value + chain.next() * sqrt(value) + } + // this will also work, but less effective: + // val y = x.pow(2)+ x + 1 + chain.nextDouble() + + // create same errors for all xs + val yErr = y.map { sqrt(abs(it)) } + require(yErr.asIterable().all { it > 0 }) { "All errors must be strictly positive" } + + val result = XYErrorColumnarData.of(x, y, yErr).fitWith( + QowOptimizer, + DSProcessor, + mapOf(a to 0.9, b to 1.2, c to 2.0) + ) { arg -> + //bind variables to autodiff context + val a by binding + val b by binding + //Include default value for c if it is not provided as a parameter + val c = bindSymbolOrNull(c) ?: one + a * arg.pow(2) + b * arg + c + } + + //display a page with plot and numerical results + val page = Plotly.page { + plot { + scatter { + mode = ScatterMode.markers + x(x) + y(y) + error_y { + array = yErr.toList() + } + name = "data" + } + scatter { + mode = ScatterMode.lines + x(x) + y(x.map { result.model(result.resultPoint + (Symbol.x to it)) }) + name = "fit" + } + } + br() + h3 { + +"Fit result: ${result.resultPoint}" + } + h3 { + +"Chi2/dof = ${result.chiSquaredOrNull!! / (x.size - 3)}" + } + } + + page.makeFile() +} diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt index f60b1ab45..c77d1d70c 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/interpolate.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/interpolate.kt index 8dbc7b7a4..a98467ced 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/interpolate.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/interpolate.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/interpolateSquare.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/interpolateSquare.kt index 33973c880..3f958b3b0 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/interpolateSquare.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/interpolateSquare.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions @@ -8,8 +8,8 @@ package space.kscience.kmath.functions import space.kscience.kmath.interpolation.SplineInterpolator import space.kscience.kmath.interpolation.interpolatePolynomials import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.real.map import space.kscience.kmath.real.step -import space.kscience.kmath.structures.map import space.kscience.plotly.Plotly import space.kscience.plotly.UnstablePlotlyAPI import space.kscience.plotly.makeFile diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt index 2619d3d74..4b6ac475c 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions @@ -9,20 +9,21 @@ import space.kscience.kmath.integration.gaussIntegrator import space.kscience.kmath.integration.integrate import space.kscience.kmath.integration.value import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.nd -import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.nd.structureND +import space.kscience.kmath.nd.withNdAlgebra +import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.invoke -fun main(): Unit = DoubleField { - nd(2, 2) { +fun main(): Unit = Double.algebra { + withNdAlgebra(2, 2) { //Produce a diagonal StructureND - fun diagonal(v: Double) = produce { (i, j) -> + fun diagonal(v: Double) = structureND { (i, j) -> if (i == j) v else 0.0 } //Define a function in a nd space - val function: (Double) -> StructureND = { x: Double -> 3 * number(x).pow(2) + 2 * diagonal(x) + 1 } + val function: (Double) -> StructureND = { x: Double -> 3 * x.pow(2) + 2 * diagonal(x) + 1 } //get the result of the integration val result = gaussIntegrator.integrate(0.0..10.0, function = function) diff --git a/examples/src/main/kotlin/space/kscience/kmath/jafama/JafamaDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/jafama/JafamaDemo.kt new file mode 100644 index 000000000..10ed30728 --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/jafama/JafamaDemo.kt @@ -0,0 +1,15 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.jafama + +import space.kscience.kmath.operations.invoke + +fun main() { + val a = 2.0 + val b = StrictJafamaDoubleField { exp(a) } + println(JafamaDoubleField { b + a }) + println(StrictJafamaDoubleField { ln(b) }) +} diff --git a/examples/src/main/kotlin/space/kscience/kmath/jafama/KMathaJafamaDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/jafama/KMathaJafamaDemo.kt deleted file mode 100644 index 879aab08f..000000000 --- a/examples/src/main/kotlin/space/kscience/kmath/jafama/KMathaJafamaDemo.kt +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.jafama - -import net.jafama.FastMath - - -fun main(){ - val a = JafamaDoubleField.number(2.0) - val b = StrictJafamaDoubleField.power(FastMath.E,a) - - println(JafamaDoubleField.add(b,a)) - println(StrictJafamaDoubleField.ln(b)) -} diff --git a/examples/src/main/kotlin/space/kscience/kmath/linear/dotPerformance.kt b/examples/src/main/kotlin/space/kscience/kmath/linear/dotPerformance.kt new file mode 100644 index 000000000..6e8767a5b --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/linear/dotPerformance.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.linear + +import space.kscience.kmath.operations.algebra +import kotlin.random.Random +import kotlin.system.measureTimeMillis + +fun main() { + val random = Random(12224) + val dim = 1000 + + //creating invertible matrix + val matrix1 = Double.algebra.linearSpace.buildMatrix(dim, dim) { i, j -> + if (i <= j) random.nextDouble() else 0.0 + } + val matrix2 = Double.algebra.linearSpace.buildMatrix(dim, dim) { i, j -> + if (i <= j) random.nextDouble() else 0.0 + } + + val time = measureTimeMillis { + with(Double.algebra.linearSpace) { + repeat(10) { + matrix1 dot matrix2 + } + } + } + + println(time) + +} \ No newline at end of file diff --git a/examples/src/main/kotlin/space/kscience/kmath/linear/gradient.kt b/examples/src/main/kotlin/space/kscience/kmath/linear/gradient.kt index a01ea7fe2..afc42ea26 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/linear/gradient.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/linear/gradient.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear diff --git a/examples/src/main/kotlin/space/kscience/kmath/operations/BigIntDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/operations/BigIntDemo.kt index 51f439612..2039953b5 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/operations/BigIntDemo.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/operations/BigIntDemo.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/examples/src/main/kotlin/space/kscience/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/operations/ComplexDemo.kt deleted file mode 100644 index f99dd8c0e..000000000 --- a/examples/src/main/kotlin/space/kscience/kmath/operations/ComplexDemo.kt +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.operations - -import space.kscience.kmath.complex.Complex -import space.kscience.kmath.complex.complex -import space.kscience.kmath.nd.AlgebraND - -fun main() { - // 2d element - val element = AlgebraND.complex(2, 2).produce { (i, j) -> - Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble()) - } - println(element) - - // 1d element operation - val result = with(AlgebraND.complex(8)) { - val a = produce { (it) -> i * it - it.toDouble() } - val b = 3 - val c = Complex(1.0, 1.0) - - (a pow b) + c - } - - println(result) -} diff --git a/examples/src/main/kotlin/space/kscience/kmath/operations/complexDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/operations/complexDemo.kt new file mode 100644 index 000000000..3b9c32f4b --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/operations/complexDemo.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.operations + +import space.kscience.kmath.complex.Complex +import space.kscience.kmath.complex.algebra +import space.kscience.kmath.complex.bufferAlgebra +import space.kscience.kmath.complex.ndAlgebra +import space.kscience.kmath.nd.BufferND +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.structureND + +fun main() = Complex.algebra { + val complex = 2 + 2 * i + println(complex * 8 - 5 * i) + + //flat buffer + val buffer = with(bufferAlgebra){ + buffer(8) { Complex(it, -it) }.map { Complex(it.im, it.re) } + } + println(buffer) + + // 2d element + val element: BufferND = ndAlgebra.structureND(2, 2) { (i, j) -> + Complex(i - j, i + j) + } + println(element) + + // 1d element operation + val result: StructureND = ndAlgebra{ + val a = structureND(8) { (it) -> i * it - it.toDouble() } + val b = 3 + val c = Complex(1.0, 1.0) + + (a pow b) + c + } + println(result) +} diff --git a/examples/src/main/kotlin/space/kscience/kmath/operations/mixedNDOperations.kt b/examples/src/main/kotlin/space/kscience/kmath/operations/mixedNDOperations.kt new file mode 100644 index 000000000..f517046ee --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/operations/mixedNDOperations.kt @@ -0,0 +1,24 @@ +package space.kscience.kmath.operations + +import space.kscience.kmath.commons.linear.CMLinearSpace +import space.kscience.kmath.linear.matrix +import space.kscience.kmath.nd.DoubleBufferND +import space.kscience.kmath.nd.Shape +import space.kscience.kmath.nd.Structure2D +import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.viktor.ViktorStructureND +import space.kscience.kmath.viktor.viktorAlgebra + +fun main() { + val viktorStructure: ViktorStructureND = DoubleField.viktorAlgebra.structureND(Shape(2, 2)) { (i, j) -> + if (i == j) 2.0 else 0.0 + } + + val cmMatrix: Structure2D = CMLinearSpace.matrix(2, 2)(0.0, 1.0, 0.0, 3.0) + + val res: DoubleBufferND = DoubleField.ndAlgebra { + exp(viktorStructure) + 2.0 * cmMatrix + } + + println(res) +} \ No newline at end of file diff --git a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt index 8e3cdf86f..732c9a8e3 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat diff --git a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt index b319766e3..685214c39 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -10,9 +10,6 @@ import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.collectWithState import space.kscience.kmath.distributions.NormalDistribution -/** - * The state of distribution averager. - */ private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) /** diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt index b30165f71..61df3d065 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/ComplexND.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:Suppress("unused") @@ -9,10 +9,11 @@ package space.kscience.kmath.structures import space.kscience.kmath.complex.* import space.kscience.kmath.linear.transpose -import space.kscience.kmath.nd.AlgebraND import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.as2D -import space.kscience.kmath.nd.real +import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.structureND +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke import kotlin.system.measureTimeMillis @@ -20,8 +21,8 @@ fun main() { val dim = 1000 val n = 1000 - val realField = AlgebraND.real(dim, dim) - val complexField: ComplexFieldND = AlgebraND.complex(dim, dim) + val realField = DoubleField.ndAlgebra(dim, dim) + val complexField: ComplexFieldND = ComplexField.ndAlgebra(dim, dim) val realTime = measureTimeMillis { realField { @@ -49,12 +50,12 @@ fun main() { fun complexExample() { //Create a context for 2-d structure with complex values ComplexField { - nd(4, 8) { + withNdAlgebra(4, 8) { //a constant real-valued structure val x = one * 2.5 operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im) //a structure generator specific to this context - val matrix = produce { (k, l) -> k + l * i } + val matrix = structureND { (k, l) -> k + l * i } //Perform sum val sum = matrix + x + 1.0 diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt index 501bf98db..cf0721ce7 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/NDField.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -9,10 +9,10 @@ import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.GlobalScope import org.nd4j.linalg.factory.Nd4j import space.kscience.kmath.nd.* -import space.kscience.kmath.nd4j.Nd4jArrayField +import space.kscience.kmath.nd4j.nd4j import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke -import space.kscience.kmath.viktor.ViktorNDField +import space.kscience.kmath.viktor.ViktorFieldND import kotlin.contracts.InvocationKind import kotlin.contracts.contract import kotlin.system.measureTimeMillis @@ -29,37 +29,39 @@ fun main() { Nd4j.zeros(0) val dim = 1000 val n = 1000 + val shape = Shape(dim, dim) + // automatically build context most suited for given type. - val autoField = AlgebraND.auto(DoubleField, dim, dim) - // specialized nd-field for Double. It works as generic Double field as well - val realField = AlgebraND.real(dim, dim) + val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto) + // specialized nd-field for Double. It works as generic Double field as well. + val realField = DoubleField.ndAlgebra //A generic boxing field. It should be used for objects, not primitives. - val boxingField = AlgebraND.field(DoubleField, Buffer.Companion::boxing, dim, dim) + val boxingField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing) // Nd4j specialized field. - val nd4jField = Nd4jArrayField.real(dim, dim) + val nd4jField = DoubleField.nd4j //viktor field - val viktorField = ViktorNDField(dim, dim) + val viktorField = ViktorFieldND(dim, dim) //parallel processing based on Java Streams - val parallelField = AlgebraND.realWithStream(dim, dim) + val parallelField = DoubleField.ndStreaming(dim, dim) measureAndPrint("Boxing addition") { boxingField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } measureAndPrint("Specialized addition") { realField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } measureAndPrint("Nd4j specialized addition") { nd4jField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } @@ -80,13 +82,13 @@ fun main() { measureAndPrint("Automatic field addition") { autoField { - var res: StructureND = one + var res: StructureND = one(shape) repeat(n) { res += 1.0 } } } measureAndPrint("Lazy addition") { - val res = realField.one.mapAsync(GlobalScope) { + val res = realField.one(shape).mapAsync(GlobalScope) { var c = 0.0 repeat(n) { c += 1.0 diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt index d53cfa9b9..05a13f5d2 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -8,7 +8,7 @@ package space.kscience.kmath.structures import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations +import space.kscience.kmath.operations.NumbersAddOps import java.util.* import java.util.stream.IntStream @@ -17,17 +17,17 @@ import java.util.stream.IntStream * execution. */ class StreamDoubleFieldND(override val shape: IntArray) : FieldND, - NumbersAddOperations>, + NumbersAddOps>, ExtendedField> { private val strides = DefaultStrides(shape) - override val elementContext: DoubleField get() = DoubleField - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } + override val elementAlgebra: DoubleField get() = DoubleField + override val zero: BufferND by lazy { structureND(shape) { zero } } + override val one: BufferND by lazy { structureND(shape) { one } } override fun number(value: Number): BufferND { val d = value.toDouble() // minimize conversions - return produce { d } + return structureND(shape) { d } } private val StructureND.buffer: DoubleBuffer @@ -36,11 +36,11 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND this.buffer as DoubleBuffer + this is BufferND && this.indices == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } } - override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND { + override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND { val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val index = strides.index(offset) DoubleField.initializer(index) @@ -69,13 +69,13 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND, - b: StructureND, + override fun zip( + left: StructureND, + right: StructureND, transform: DoubleField.(Double, Double) -> Double, ): BufferND { val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> - DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset]) + DoubleField.transform(left.buffer.array[offset], right.buffer.array[offset]) }.toArray() return BufferND(strides, array.asBuffer()) } @@ -105,4 +105,4 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND): BufferND = arg.map { atanh(it) } } -fun AlgebraND.Companion.realWithStream(vararg shape: Int): StreamDoubleFieldND = StreamDoubleFieldND(shape) +fun DoubleField.ndStreaming(vararg shape: Int): StreamDoubleFieldND = StreamDoubleFieldND(shape) diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StructureReadBenchmark.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StructureReadBenchmark.kt index 0d5358354..db77129a2 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StructureReadBenchmark.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StructureReadBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -19,24 +19,24 @@ fun main() { measureTimeMillis { var res = 0.0 - strides.indices().forEach { res = structure[it] } + strides.asSequence().forEach { res = structure[it] } } // warmup val time1 = measureTimeMillis { var res = 0.0 - strides.indices().forEach { res = structure[it] } + strides.asSequence().forEach { res = structure[it] } } println("Structure reading finished in $time1 millis") val time2 = measureTimeMillis { var res = 0.0 - strides.indices().forEach { res = buffer[strides.offset(it)] } + strides.asSequence().forEach { res = buffer[strides.offset(it)] } } println("Buffer reading finished in $time2 millis") val time3 = measureTimeMillis { var res = 0.0 - strides.indices().forEach { res = array[strides.offset(it)] } + strides.asSequence().forEach { res = array[strides.offset(it)] } } println("Array reading finished in $time3 millis") } diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StructureWriteBenchmark.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StructureWriteBenchmark.kt index dea7095a8..84644ddd9 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StructureWriteBenchmark.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StructureWriteBenchmark.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt new file mode 100644 index 000000000..889ea99bd --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/buffers.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.structures + +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.buffer +import space.kscience.kmath.operations.bufferAlgebra +import space.kscience.kmath.operations.withSize + +inline fun MutableBuffer.Companion.same( + n: Int, + value: R +): MutableBuffer = auto(n) { value } + + +fun main() { + with(DoubleField.bufferAlgebra.withSize(5)) { + println(number(2.0) + buffer(1, 2, 3, 4, 5)) + } +} diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/typeSafeDimensions.kt index 955f86fa9..853ebad32 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/typeSafeDimensions.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -19,7 +19,7 @@ private fun DMatrixContext.simple() { } private object D5 : Dimension { - override val dim: UInt = 5u + override val dim: Int = 5 } private fun DMatrixContext.custom() { diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt index b42602988..a266d4849 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/OLSWithSVD.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt index 411e048d7..d83d47805 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/PCA.kt @@ -1,23 +1,23 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors -import space.kscience.kmath.operations.invoke -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra +import space.kscience.kmath.tensors.core.tensorAlgebra +import space.kscience.kmath.tensors.core.withBroadcast // simple PCA -fun main(): Unit = BroadcastDoubleTensorAlgebra { // work in context with broadcast methods +fun main(): Unit = Double.tensorAlgebra.withBroadcast { // work in context with broadcast methods val seed = 100500L // assume x is range from 0 until 10 val x = fromArray( intArrayOf(10), - (0 until 10).toList().map { it.toDouble() }.toDoubleArray() + DoubleArray(10) { it.toDouble() } ) // take y dependent on x with noise @@ -62,11 +62,11 @@ fun main(): Unit = BroadcastDoubleTensorAlgebra { // work in context with broad println("Eigenvector:\n$v") // reduce dimension of dataset - val datasetReduced = v dot stack(listOf(xScaled, yScaled)) + val datasetReduced = v dot stack(listOf(xScaled, yScaled)) println("Reduced data:\n$datasetReduced") - // we can restore original data from reduced data. - // for example, find 7th element of dataset + // we can restore original data from reduced data; + // for example, find 7th element of dataset. val n = 7 val restored = (datasetReduced[n] dot v.view(intArrayOf(1, 2))) * std + mean println("Original value:\n${dataset[n]}") diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/DataSetNormalization.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/dataSetNormalization.kt similarity index 71% rename from examples/src/main/kotlin/space/kscience/kmath/tensors/DataSetNormalization.kt rename to examples/src/main/kotlin/space/kscience/kmath/tensors/dataSetNormalization.kt index 74795cc68..9d5b8c2a5 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/DataSetNormalization.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/dataSetNormalization.kt @@ -1,23 +1,23 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors -import space.kscience.kmath.operations.invoke -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra +import space.kscience.kmath.tensors.core.tensorAlgebra +import space.kscience.kmath.tensors.core.withBroadcast // Dataset normalization -fun main() = BroadcastDoubleTensorAlgebra { // work in context with broadcast methods +fun main() = Double.tensorAlgebra.withBroadcast { // work in context with broadcast methods // take dataset of 5-element vectors from normal distribution val dataset = randomNormal(intArrayOf(100, 5)) * 1.5 // all elements from N(0, 1.5) dataset += fromArray( intArrayOf(5), - doubleArrayOf(0.0, 1.0, 1.5, 3.0, 5.0) // rows means + doubleArrayOf(0.0, 1.0, 1.5, 3.0, 5.0) // row means ) @@ -28,7 +28,7 @@ fun main() = BroadcastDoubleTensorAlgebra { // work in context with broadcast m println("Mean:\n$mean") println("Standard deviation:\n$std") - // also we can calculate other statistic as minimum and maximum of rows + // also, we can calculate other statistic as minimum and maximum of rows println("Minimum:\n${dataset.min(0, false)}") println("Maximum:\n${dataset.max(0, false)}") diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/LinearSystemSolvingWithLUP.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/linearSystemSolvingWithLUP.kt similarity index 88% rename from examples/src/main/kotlin/space/kscience/kmath/tensors/LinearSystemSolvingWithLUP.kt rename to examples/src/main/kotlin/space/kscience/kmath/tensors/linearSystemSolvingWithLUP.kt index 6453ca44e..846e338da 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/LinearSystemSolvingWithLUP.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/linearSystemSolvingWithLUP.kt @@ -1,17 +1,17 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors -import space.kscience.kmath.operations.invoke -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensor +import space.kscience.kmath.tensors.core.tensorAlgebra +import space.kscience.kmath.tensors.core.withBroadcast // solving linear system with LUP decomposition -fun main() = BroadcastDoubleTensorAlgebra {// work in context with linear operations +fun main() = Double.tensorAlgebra.withBroadcast {// work in context with linear operations // set true value of x val trueX = fromArray( @@ -42,7 +42,7 @@ fun main() = BroadcastDoubleTensorAlgebra {// work in context with linear operat // get P, L, U such that PA = LU val (p, l, u) = a.lu() - // check that P is permutation matrix + // check P is permutation matrix println("P:\n$p") // L is lower triangular matrix and U is upper triangular matrix println("L:\n$l") diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt new file mode 100644 index 000000000..fad68fa96 --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt @@ -0,0 +1,18 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.tensors + +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.ndarray +import space.kscience.kmath.multik.multikAlgebra +import space.kscience.kmath.nd.one +import space.kscience.kmath.operations.DoubleField + +fun main(): Unit = with(DoubleField.multikAlgebra) { + val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType().wrap() + val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap() + one(a.shape) - a + b * 3.0 +} diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/NeuralNetwork.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt similarity index 96% rename from examples/src/main/kotlin/space/kscience/kmath/tensors/NeuralNetwork.kt rename to examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt index 1e961fc7b..1a2a94534 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/NeuralNetwork.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors @@ -9,7 +9,7 @@ import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensorAlgebra -import space.kscience.kmath.tensors.core.toDoubleArray +import space.kscience.kmath.tensors.core.copyArray import kotlin.math.sqrt const val seed = 100500L @@ -111,7 +111,7 @@ class NeuralNetwork(private val layers: List) { private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra { val onesForAnswers = yPred.zeroesLike() - yTrue.toDoubleArray().forEachIndexed { index, labelDouble -> + yTrue.copyArray().forEachIndexed { index, labelDouble -> val label = labelDouble.toInt() onesForAnswers[intArrayOf(index, label)] = 1.0 } @@ -186,7 +186,7 @@ fun main() = BroadcastDoubleTensorAlgebra { x += fromArray( intArrayOf(5), - doubleArrayOf(0.0, -1.0, -2.5, -3.0, 5.5) // rows means + doubleArrayOf(0.0, -1.0, -2.5, -3.0, 5.5) // row means ) diff --git a/gradle.properties b/gradle.properties index 3aaade368..959511c68 100644 --- a/gradle.properties +++ b/gradle.properties @@ -4,10 +4,13 @@ # kotlin.code.style=official -kotlin.mpp.enableGranularSourceSetsMetadata=true kotlin.mpp.stability.nowarn=true -kotlin.native.enableDependencyPropagation=false -kotlin.parallel.tasks.in.project=true + +#kotlin.mpp.enableGranularSourceSetsMetadata=true +#kotlin.native.enableDependencyPropagation=false + +kotlin.jupyter.add.scanner=false + org.gradle.configureondemand=true org.gradle.jvmargs=-XX:MaxMetaspaceSize=2G org.gradle.parallel=true diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index e708b1c02..7454180f2 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index f371643ee..ffed3a254 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.0-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 4f906e0c8..744e882ed 100755 --- a/gradlew +++ b/gradlew @@ -72,7 +72,7 @@ case "`uname`" in Darwin* ) darwin=true ;; - MINGW* ) + MSYS* | MINGW* ) msys=true ;; NONSTOP* ) diff --git a/kmath-ast/README.md b/kmath-ast/README.md index 026c5a625..686506f6f 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -10,7 +10,7 @@ Performance and visualization extensions to MST API. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-14`. **Gradle:** ```gradle @@ -20,7 +20,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-ast:0.3.0-dev-13' + implementation 'space.kscience:kmath-ast:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -31,7 +31,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-ast:0.3.0-dev-13") + implementation("space.kscience:kmath-ast:0.3.0-dev-14") } ``` @@ -106,7 +106,7 @@ var executable = function (constants, arguments) { }; ``` -JS also supports very experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation. +JS also supports experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation. Currently, only expressions inside `DoubleField` and `IntRing` are supported. ```kotlin @@ -161,7 +161,10 @@ public fun main() { Result LaTeX: +
+ ![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3}) +
Result MathML (can be used with MathJax or other renderers): diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 8209a0dad..9de7e9980 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -45,8 +45,7 @@ kotlin.sourceSets { jvmMain { dependencies { - implementation("org.ow2.asm:asm:9.1") - implementation("org.ow2.asm:asm-commons:9.1") + implementation("org.ow2.asm:asm-commons:9.2") } } } diff --git a/kmath-ast/docs/README-TEMPLATE.md b/kmath-ast/docs/README-TEMPLATE.md index b90f8ff08..9494af63a 100644 --- a/kmath-ast/docs/README-TEMPLATE.md +++ b/kmath-ast/docs/README-TEMPLATE.md @@ -77,7 +77,7 @@ var executable = function (constants, arguments) { }; ``` -JS also supports very experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation. +JS also supports experimental expression optimization with [WebAssembly](https://webassembly.org/) IR generation. Currently, only expressions inside `DoubleField` and `IntRing` are supported. ```kotlin @@ -132,7 +132,10 @@ public fun main() { Result LaTeX: +
+ ![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3}) +
Result MathML (can be used with MathJax or other renderers): diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt index 5201fec38..7f2780548 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/parser.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -17,11 +17,11 @@ import com.github.h0tk3y.betterParse.lexer.regexToken import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.Parser import space.kscience.kmath.expressions.MST -import space.kscience.kmath.expressions.StringSymbol -import space.kscience.kmath.operations.FieldOperations -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.operations.FieldOps +import space.kscience.kmath.operations.GroupOps import space.kscience.kmath.operations.PowerOperations -import space.kscience.kmath.operations.RingOperations +import space.kscience.kmath.operations.RingOps /** * better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4. @@ -43,7 +43,7 @@ public object ArithmeticsEvaluator : Grammar() { private val ws: Token by regexToken("\\s+".toRegex(), ignore = true) private val number: Parser by num use { MST.Numeric(text.toDouble()) } - private val singular: Parser by id use { StringSymbol(text) } + private val singular: Parser by id use { Symbol(text) } private val unaryFunction: Parser by (id and -lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar) .map { (id, term) -> MST.Unary(id.text, term) } @@ -60,7 +60,7 @@ public object ArithmeticsEvaluator : Grammar() { .or(binaryFunction) .or(unaryFunction) .or(singular) - .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOperations.MINUS_OPERATION, it) }) + .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOps.MINUS_OPERATION, it) }) .or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar) private val powChain: Parser by leftAssociative(term = term, operator = pow) { a, _, b -> @@ -72,9 +72,9 @@ public object ArithmeticsEvaluator : Grammar() { operator = div or mul use TokenMatch::type ) { a, op, b -> if (op == div) - MST.Binary(FieldOperations.DIV_OPERATION, a, b) + MST.Binary(FieldOps.DIV_OPERATION, a, b) else - MST.Binary(RingOperations.TIMES_OPERATION, a, b) + MST.Binary(RingOps.TIMES_OPERATION, a, b) } private val subSumChain: Parser by leftAssociative( @@ -82,9 +82,9 @@ public object ArithmeticsEvaluator : Grammar() { operator = plus or minus use TokenMatch::type ) { a, op, b -> if (op == plus) - MST.Binary(GroupOperations.PLUS_OPERATION, a, b) + MST.Binary(GroupOps.PLUS_OPERATION, a, b) else - MST.Binary(GroupOperations.MINUS_OPERATION, a, b) + MST.Binary(GroupOps.MINUS_OPERATION, a, b) } override val rootParser: Parser by subSumChain diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt index 01717b0f9..bf5916fa5 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering @@ -27,7 +27,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI */ @UnstableKMathAPI public object LatexSyntaxRenderer : SyntaxRenderer { - public override fun render(node: MathSyntax, output: Appendable): Unit = output.run { + override fun render(node: MathSyntax, output: Appendable): Unit = output.run { fun render(syntax: MathSyntax) = render(syntax, output) when (node) { diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt index cda8e2322..5439c42fa 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering @@ -16,7 +16,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI */ @UnstableKMathAPI public object MathMLSyntaxRenderer : SyntaxRenderer { - public override fun render(node: MathSyntax, output: Appendable) { + override fun render(node: MathSyntax, output: Appendable) { output.append("") renderPart(node, output) output.append("") diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt index 68d829724..24bac425a 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering @@ -29,7 +29,7 @@ public fun interface MathRenderer { */ @UnstableKMathAPI public open class FeaturedMathRenderer(public val features: List) : MathRenderer { - public override fun render(mst: MST): MathSyntax { + override fun render(mst: MST): MathSyntax { for (feature in features) feature.render(this, mst)?.let { return it } throw UnsupportedOperationException("Renderer $this has no appropriate feature to render node $mst.") } @@ -56,7 +56,7 @@ public open class FeaturedMathRendererWithPostProcess( features: List, public val stages: List, ) : FeaturedMathRenderer(features) { - public override fun render(mst: MST): MathSyntax { + override fun render(mst: MST): MathSyntax { val res = super.render(mst) for (stage in stages) stage.perform(res) return res diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt index a71985fbc..81b7d2afb 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering @@ -8,7 +8,7 @@ package space.kscience.kmath.ast.rendering import space.kscience.kmath.misc.UnstableKMathAPI /** - * Mathematical typography syntax node. + * Syntax node for mathematical typography. * * @author Iaroslav Postovalov */ @@ -150,9 +150,9 @@ public data class OperandSyntax( */ @UnstableKMathAPI public data class UnaryOperatorSyntax( - public override val operation: String, + override val operation: String, public var prefix: MathSyntax, - public override val operand: OperandSyntax, + override val operand: OperandSyntax, ) : UnarySyntax() { init { operand.parent = this @@ -166,8 +166,8 @@ public data class UnaryOperatorSyntax( */ @UnstableKMathAPI public data class UnaryPlusSyntax( - public override val operation: String, - public override val operand: OperandSyntax, + override val operation: String, + override val operand: OperandSyntax, ) : UnarySyntax() { init { operand.parent = this @@ -181,8 +181,8 @@ public data class UnaryPlusSyntax( */ @UnstableKMathAPI public data class UnaryMinusSyntax( - public override val operation: String, - public override val operand: OperandSyntax, + override val operation: String, + override val operand: OperandSyntax, ) : UnarySyntax() { init { operand.parent = this @@ -197,8 +197,8 @@ public data class UnaryMinusSyntax( */ @UnstableKMathAPI public data class RadicalSyntax( - public override val operation: String, - public override val operand: MathSyntax, + override val operation: String, + override val operand: MathSyntax, ) : UnarySyntax() { init { operand.parent = this @@ -215,8 +215,8 @@ public data class RadicalSyntax( */ @UnstableKMathAPI public data class ExponentSyntax( - public override val operation: String, - public override val operand: OperandSyntax, + override val operation: String, + override val operand: OperandSyntax, public var useOperatorForm: Boolean, ) : UnarySyntax() { init { @@ -233,9 +233,9 @@ public data class ExponentSyntax( */ @UnstableKMathAPI public data class SuperscriptSyntax( - public override val operation: String, - public override val left: MathSyntax, - public override val right: MathSyntax, + override val operation: String, + override val left: MathSyntax, + override val right: MathSyntax, ) : BinarySyntax() { init { left.parent = this @@ -252,9 +252,9 @@ public data class SuperscriptSyntax( */ @UnstableKMathAPI public data class SubscriptSyntax( - public override val operation: String, - public override val left: MathSyntax, - public override val right: MathSyntax, + override val operation: String, + override val left: MathSyntax, + override val right: MathSyntax, ) : BinarySyntax() { init { left.parent = this @@ -270,10 +270,10 @@ public data class SubscriptSyntax( */ @UnstableKMathAPI public data class BinaryOperatorSyntax( - public override val operation: String, + override val operation: String, public var prefix: MathSyntax, - public override val left: MathSyntax, - public override val right: MathSyntax, + override val left: MathSyntax, + override val right: MathSyntax, ) : BinarySyntax() { init { left.parent = this @@ -290,9 +290,9 @@ public data class BinaryOperatorSyntax( */ @UnstableKMathAPI public data class BinaryPlusSyntax( - public override val operation: String, - public override val left: OperandSyntax, - public override val right: OperandSyntax, + override val operation: String, + override val left: OperandSyntax, + override val right: OperandSyntax, ) : BinarySyntax() { init { left.parent = this @@ -301,7 +301,7 @@ public data class BinaryPlusSyntax( } /** - * Represents binary, infix subtraction (*42 - 42*). + * Represents binary, infix subtraction (*42 − 42*). * * @param left The minuend. * @param right The subtrahend. @@ -309,9 +309,9 @@ public data class BinaryPlusSyntax( */ @UnstableKMathAPI public data class BinaryMinusSyntax( - public override val operation: String, - public override val left: OperandSyntax, - public override val right: OperandSyntax, + override val operation: String, + override val left: OperandSyntax, + override val right: OperandSyntax, ) : BinarySyntax() { init { left.parent = this @@ -329,9 +329,9 @@ public data class BinaryMinusSyntax( */ @UnstableKMathAPI public data class FractionSyntax( - public override val operation: String, - public override val left: OperandSyntax, - public override val right: OperandSyntax, + override val operation: String, + override val left: OperandSyntax, + override val right: OperandSyntax, public var infix: Boolean, ) : BinarySyntax() { init { @@ -349,9 +349,9 @@ public data class FractionSyntax( */ @UnstableKMathAPI public data class RadicalWithIndexSyntax( - public override val operation: String, - public override val left: MathSyntax, - public override val right: MathSyntax, + override val operation: String, + override val left: MathSyntax, + override val right: MathSyntax, ) : BinarySyntax() { init { left.parent = this @@ -369,9 +369,9 @@ public data class RadicalWithIndexSyntax( */ @UnstableKMathAPI public data class MultiplicationSyntax( - public override val operation: String, - public override val left: OperandSyntax, - public override val right: OperandSyntax, + override val operation: String, + override val left: OperandSyntax, + override val right: OperandSyntax, public var times: Boolean, ) : BinarySyntax() { init { diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt index fb2b3b66f..2f285c600 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering @@ -9,7 +9,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI /** * Abstraction of writing [MathSyntax] as a string of an actual markup language. Typical implementation should - * involve traversal of MathSyntax with handling each its subtype. + * involve traversal of MathSyntax with handling each subtype. * * @author Iaroslav Postovalov */ diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt index a2f42d1bf..8b76b6f19 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering @@ -13,7 +13,7 @@ import space.kscience.kmath.operations.* import kotlin.reflect.KClass /** - * Prints any [Symbol] as a [SymbolSyntax] containing the [Symbol.value] of it. + * Prints any [Symbol] as a [SymbolSyntax] containing the [Symbol.identity] of it. * * @author Iaroslav Postovalov */ @@ -39,7 +39,7 @@ public val PrintNumeric: RenderFeature = RenderFeature { _, node -> @UnstableKMathAPI private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-')) UnaryMinusSyntax( - operation = GroupOperations.MINUS_OPERATION, + operation = GroupOps.MINUS_OPERATION, operand = OperandSyntax( operand = NumberSyntax(string = s.removePrefix("-")), parentheses = true, @@ -49,7 +49,7 @@ else NumberSyntax(string = s) /** - * Special printing for numeric types which are printed in form of + * Special printing for numeric types that are printed in form of * *('-'? (DIGIT+ ('.' DIGIT+)? ('E' '-'? DIGIT+)? | 'Infinity')) | 'NaN'*. * * @property types The suitable types. @@ -57,7 +57,7 @@ else */ @UnstableKMathAPI public class PrettyPrintFloats(public val types: Set>) : RenderFeature { - public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { if (node !is MST.Numeric || node.value::class !in types) return null val toString = when (val v = node.value) { @@ -72,7 +72,7 @@ public class PrettyPrintFloats(public val types: Set>) : Rend val exponent = afterE.toDouble().toString().removeSuffix(".0") return MultiplicationSyntax( - operation = RingOperations.TIMES_OPERATION, + operation = RingOps.TIMES_OPERATION, left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true), right = OperandSyntax( operand = SuperscriptSyntax( @@ -91,7 +91,7 @@ public class PrettyPrintFloats(public val types: Set>) : Rend if (toString.startsWith('-')) return UnaryMinusSyntax( - operation = GroupOperations.MINUS_OPERATION, + operation = GroupOps.MINUS_OPERATION, operand = OperandSyntax(operand = infty, parentheses = true), ) @@ -110,14 +110,14 @@ public class PrettyPrintFloats(public val types: Set>) : Rend } /** - * Special printing for numeric types which are printed in form of *'-'? DIGIT+*. + * Special printing for numeric types that are printed in form of *'-'? DIGIT+*. * * @property types The suitable types. * @author Iaroslav Postovalov */ @UnstableKMathAPI public class PrettyPrintIntegers(public val types: Set>) : RenderFeature { - public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? = + override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? = if (node !is MST.Numeric || node.value::class !in types) null else @@ -140,7 +140,7 @@ public class PrettyPrintIntegers(public val types: Set>) : Re */ @UnstableKMathAPI public class PrettyPrintPi(public val symbols: Set) : RenderFeature { - public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? = + override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? = if (node !is Symbol || node.identity !in symbols) null else @@ -155,7 +155,7 @@ public class PrettyPrintPi(public val symbols: Set) : RenderFeature { } /** - * Abstract printing of unary operations which discards [MST] if their operation is not in [operations] or its type is + * Abstract printing of unary operations that discards [MST] if their operation is not in [operations] or its type is * not [MST.Unary]. * * @param operations the allowed operations. If `null`, any operation is accepted. @@ -176,7 +176,7 @@ public abstract class Unary(public val operations: Collection?) : Render } /** - * Abstract printing of unary operations which discards [MST] if their operation is not in [operations] or its type is + * Abstract printing of unary operations that discards [MST] if their operation is not in [operations] or its type is * not [MST.Binary]. * * @property operations the allowed operations. If `null`, any operation is accepted. @@ -202,7 +202,7 @@ public abstract class Binary(public val operations: Collection?) : Rende */ @UnstableKMathAPI public class BinaryPlus(operations: Collection?) : Binary(operations) { - public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = + override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryPlusSyntax( operation = node.operation, left = OperandSyntax(parent.render(node.left), true), @@ -211,9 +211,9 @@ public class BinaryPlus(operations: Collection?) : Binary(operations) { public companion object { /** - * The default instance configured with [GroupOperations.PLUS_OPERATION]. + * The default instance configured with [GroupOps.PLUS_OPERATION]. */ - public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION)) + public val Default: BinaryPlus = BinaryPlus(setOf(GroupOps.PLUS_OPERATION)) } } @@ -224,7 +224,7 @@ public class BinaryPlus(operations: Collection?) : Binary(operations) { */ @UnstableKMathAPI public class BinaryMinus(operations: Collection?) : Binary(operations) { - public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = + override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryMinusSyntax( operation = node.operation, left = OperandSyntax(operand = parent.render(node.left), parentheses = true), @@ -233,9 +233,9 @@ public class BinaryMinus(operations: Collection?) : Binary(operations) { public companion object { /** - * The default instance configured with [GroupOperations.MINUS_OPERATION]. + * The default instance configured with [GroupOps.MINUS_OPERATION]. */ - public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION)) + public val Default: BinaryMinus = BinaryMinus(setOf(GroupOps.MINUS_OPERATION)) } } @@ -246,16 +246,16 @@ public class BinaryMinus(operations: Collection?) : Binary(operations) { */ @UnstableKMathAPI public class UnaryPlus(operations: Collection?) : Unary(operations) { - public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryPlusSyntax( + override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryPlusSyntax( operation = node.operation, operand = OperandSyntax(operand = parent.render(node.value), parentheses = true), ) public companion object { /** - * The default instance configured with [GroupOperations.PLUS_OPERATION]. + * The default instance configured with [GroupOps.PLUS_OPERATION]. */ - public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION)) + public val Default: UnaryPlus = UnaryPlus(setOf(GroupOps.PLUS_OPERATION)) } } @@ -266,16 +266,16 @@ public class UnaryPlus(operations: Collection?) : Unary(operations) { */ @UnstableKMathAPI public class UnaryMinus(operations: Collection?) : Unary(operations) { - public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryMinusSyntax( + override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryMinusSyntax( operation = node.operation, operand = OperandSyntax(operand = parent.render(node.value), parentheses = true), ) public companion object { /** - * The default instance configured with [GroupOperations.MINUS_OPERATION]. + * The default instance configured with [GroupOps.MINUS_OPERATION]. */ - public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION)) + public val Default: UnaryMinus = UnaryMinus(setOf(GroupOps.MINUS_OPERATION)) } } @@ -286,7 +286,7 @@ public class UnaryMinus(operations: Collection?) : Unary(operations) { */ @UnstableKMathAPI public class Fraction(operations: Collection?) : Binary(operations) { - public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = FractionSyntax( + override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = FractionSyntax( operation = node.operation, left = OperandSyntax(operand = parent.render(node.left), parentheses = true), right = OperandSyntax(operand = parent.render(node.right), parentheses = true), @@ -295,9 +295,9 @@ public class Fraction(operations: Collection?) : Binary(operations) { public companion object { /** - * The default instance configured with [FieldOperations.DIV_OPERATION]. + * The default instance configured with [FieldOps.DIV_OPERATION]. */ - public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION)) + public val Default: Fraction = Fraction(setOf(FieldOps.DIV_OPERATION)) } } @@ -308,7 +308,7 @@ public class Fraction(operations: Collection?) : Binary(operations) { */ @UnstableKMathAPI public class BinaryOperator(operations: Collection?) : Binary(operations) { - public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = + override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryOperatorSyntax( operation = node.operation, prefix = OperatorNameSyntax(name = node.operation), @@ -331,7 +331,7 @@ public class BinaryOperator(operations: Collection?) : Binary(operations */ @UnstableKMathAPI public class UnaryOperator(operations: Collection?) : Unary(operations) { - public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = + override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax( operation = node.operation, prefix = OperatorNameSyntax(node.operation), @@ -353,7 +353,7 @@ public class UnaryOperator(operations: Collection?) : Unary(operations) */ @UnstableKMathAPI public class Power(operations: Collection?) : Binary(operations) { - public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = + override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = SuperscriptSyntax( operation = node.operation, left = OperandSyntax(parent.render(node.left), true), @@ -373,7 +373,7 @@ public class Power(operations: Collection?) : Binary(operations) { */ @UnstableKMathAPI public class SquareRoot(operations: Collection?) : Unary(operations) { - public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = + override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = RadicalSyntax(operation = node.operation, operand = parent.render(node.value)) public companion object { @@ -391,7 +391,7 @@ public class SquareRoot(operations: Collection?) : Unary(operations) { */ @UnstableKMathAPI public class Exponent(operations: Collection?) : Unary(operations) { - public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = ExponentSyntax( + override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = ExponentSyntax( operation = node.operation, operand = OperandSyntax(operand = parent.render(node.value), parentheses = true), useOperatorForm = true, @@ -412,7 +412,7 @@ public class Exponent(operations: Collection?) : Unary(operations) { */ @UnstableKMathAPI public class Multiplication(operations: Collection?) : Binary(operations) { - public override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = + override fun renderBinary(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = MultiplicationSyntax( operation = node.operation, left = OperandSyntax(operand = parent.render(node.left), parentheses = true), @@ -422,9 +422,9 @@ public class Multiplication(operations: Collection?) : Binary(operations public companion object { /** - * The default instance configured with [RingOperations.TIMES_OPERATION]. + * The default instance configured with [RingOps.TIMES_OPERATION]. */ - public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION)) + public val Default: Multiplication = Multiplication(setOf(RingOps.TIMES_OPERATION)) } } @@ -435,7 +435,7 @@ public class Multiplication(operations: Collection?) : Binary(operations */ @UnstableKMathAPI public class InverseTrigonometricOperations(operations: Collection?) : Unary(operations) { - public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = + override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax( operation = node.operation, prefix = OperatorNameSyntax(name = node.operation.replaceFirst("a", "arc")), @@ -462,7 +462,7 @@ public class InverseTrigonometricOperations(operations: Collection?) : U */ @UnstableKMathAPI public class InverseHyperbolicOperations(operations: Collection?) : Unary(operations) { - public override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = + override fun renderUnary(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax( operation = node.operation, prefix = OperatorNameSyntax(name = node.operation.replaceFirst("a", "ar")), diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt index 291399cee..3e33d6415 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt index 6da4994a6..ecea2d104 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/phases.kt @@ -1,16 +1,16 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.FieldOperations -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.FieldOps +import space.kscience.kmath.operations.GroupOps import space.kscience.kmath.operations.PowerOperations -import space.kscience.kmath.operations.RingOperations +import space.kscience.kmath.operations.RingOps /** * Removes unnecessary times (×) symbols from [MultiplicationSyntax]. @@ -205,7 +205,7 @@ public val BetterExponent: PostProcessPhase = PostProcessPhase { node -> @UnstableKMathAPI public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) -> Int) : PostProcessPhase { - public override fun perform(node: MathSyntax): Unit = when (node) { + override fun perform(node: MathSyntax): Unit = when (node) { is NumberSyntax -> Unit is SymbolSyntax -> Unit is OperatorNameSyntax -> Unit @@ -306,10 +306,10 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) -> is BinarySyntax -> when (it.operation) { PowerOperations.POW_OPERATION -> 1 - RingOperations.TIMES_OPERATION -> 3 - FieldOperations.DIV_OPERATION -> 3 - GroupOperations.MINUS_OPERATION -> 4 - GroupOperations.PLUS_OPERATION -> 4 + RingOps.TIMES_OPERATION -> 3 + FieldOps.DIV_OPERATION -> 3 + GroupOps.MINUS_OPERATION -> 4 + GroupOps.PLUS_OPERATION -> 4 else -> 0 } diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt index 3116466e6..802d4c10e 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerConsistencyWithInterpreter.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -11,7 +11,6 @@ import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.interpret import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.IntRing -import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt index 929d17775..f5b1e2842 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerOperations.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt index bed5bc7fa..8d9a2301f 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -9,7 +9,6 @@ import space.kscience.kmath.expressions.MstRing import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.IntRing -import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt index b838245e1..4c834a9ca 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParser.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt index bb6bb3ce1..9776da45c 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestParserPrecedence.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt index a40c785b9..ae429d97e 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt index 6322df25d..aba713c43 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt @@ -1,13 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering import space.kscience.kmath.ast.rendering.TestUtils.testLatex import space.kscience.kmath.expressions.MST -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.GroupOps import kotlin.test.Test internal class TestLatex { @@ -36,7 +36,7 @@ internal class TestLatex { fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") @Test - fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") + fun unaryPlus() = testLatex(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1") @Test fun unaryMinus() = testLatex("-x", "-x") diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt index 2d7bfad19..658ecd47a 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt @@ -1,13 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering import space.kscience.kmath.ast.rendering.TestUtils.testMathML import space.kscience.kmath.expressions.MST -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.GroupOps import kotlin.test.Test internal class TestMathML { @@ -47,7 +47,7 @@ internal class TestMathML { @Test fun unaryPlus() = - testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") + testMathML(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1") @Test fun unaryMinus() = testMathML("-x", "-x") diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt index 09ec127c7..4485605a6 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt index bf87b6fd0..6b418821b 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/utils.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/utils.kt index ec7436188..ef9f3145a 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/utils.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/utils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt index 521907d2c..2e69a536f 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt index 0c15e994c..316fdeeff 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.estree diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt index 4907d8225..850f20be7 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.estree.internal diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/astring/astring.typealises.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/astring/astring.typealises.kt index eb5c1e3dd..c7faf73e0 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/astring/astring.typealises.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/astring/astring.typealises.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.estree.internal.astring diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.kt index cca2d83af..c36860654 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:JsModule("astring") diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.typealises.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.typealises.kt index 93b4f6ce6..0a5b059ba 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.typealises.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/astring/astring.typealises.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal.astring diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/base64/base64.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/base64/base64.kt index 86e0cede7..26186c453 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/base64/base64.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/base64/base64.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:Suppress( diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.kt index 42b6ac7d8..13e3a49e2 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:Suppress( diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.typealiases.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.typealiases.kt index 523b13b40..8e449627c 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.typealiases.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/binaryen/index.binaryen.typealiases.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:Suppress("PackageDirectoryMismatch", "NO_EXPLICIT_VISIBILITY_IN_API_MODE_WARNING", "KDocMissingDocumentation") diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/emitter/emitter.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/emitter/emitter.kt index 1f7b09af8..d85857de8 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/emitter/emitter.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/emitter/emitter.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal.emitter diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.extensions.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.extensions.kt index 3aa31f921..122a3a397 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.extensions.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.extensions.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal.estree diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.kt index e5254013e..ad079dbd0 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/estree/estree.kt @@ -1,8 +1,10 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ +@file:Suppress("ClassName") + package space.kscience.kmath.internal.estree import kotlin.js.RegExp diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/stream/stream.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/stream/stream.kt index 52be5530f..caab91731 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/stream/stream.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/stream/stream.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal.stream diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es2015.iterable.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es2015.iterable.kt index 9c012e3a3..5c091e3a1 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es2015.iterable.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es2015.iterable.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal.tsstdlib diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es5.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es5.kt index 0cd395f2c..bb7fd44ca 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es5.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/tsstdlib/lib.es5.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:Suppress("UNUSED_TYPEALIAS_PARAMETER", "DEPRECATION") diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt index 3754c3eff..52dd64a5e 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:JsQualifier("WebAssembly") diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/nonDeclarations.WebAssembly.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/nonDeclarations.WebAssembly.kt index c5023c384..d59a52701 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/nonDeclarations.WebAssembly.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/nonDeclarations.WebAssembly.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:Suppress( diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt index c89ad83c4..b04c4d48f 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.wasm.internal @@ -108,8 +108,8 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleF override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value) override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { - GroupOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) - GroupOperations.PLUS_OPERATION -> visit(mst.value) + GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) + GroupOps.PLUS_OPERATION -> visit(mst.value) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value)) TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64) TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64) @@ -129,10 +129,10 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleF } override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { - GroupOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) - GroupOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) - RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) - FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right)) + GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) + GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) + RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) + FieldOps.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right)) PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64) else -> super.visitBinary(mst) } @@ -142,15 +142,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder(i32, IntRing, targ override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value) override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { - GroupOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) - GroupOperations.PLUS_OPERATION -> visit(mst.value) + GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) + GroupOps.PLUS_OPERATION -> visit(mst.value) else -> super.visitUnary(mst) } override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { - GroupOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) - GroupOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) - RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) + GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) + GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) + RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) else -> super.visitBinary(mst) } } diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/f64StandardFunctions.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/f64StandardFunctions.kt index 21a88b5d0..fe9c22c18 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/f64StandardFunctions.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/f64StandardFunctions.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.wasm.internal diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt index 6ea8f26c1..5b28b8782 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.wasm diff --git a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/TestExecutionTime.kt b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/TestExecutionTime.kt index d0e8128b4..f8c429d5a 100644 --- a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/TestExecutionTime.kt +++ b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/TestExecutionTime.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -11,19 +11,21 @@ import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.invoke import kotlin.math.sin import kotlin.random.Random +import kotlin.test.Ignore import kotlin.test.Test import kotlin.time.measureTime import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression import space.kscience.kmath.wasm.compileToExpression as wasmCompileToExpression // TODO move to benchmarks when https://github.com/Kotlin/kotlinx-benchmark/pull/38 or similar feature is merged +@Ignore internal class TestExecutionTime { private companion object { private const val times = 1_000_000 private val x by symbol private val algebra = DoubleField - private val functional = DoubleField.expressionInExtendedField { + private val functional = algebra.expressionInExtendedField { bindSymbol(x) * const(2.0) + const(2.0) / bindSymbol(x) - const(16.0) / sin(bindSymbol(x)) } @@ -31,9 +33,9 @@ internal class TestExecutionTime { x * number(2.0) + number(2.0) / x - number(16.0) / sin(x) } - private val mst = node.toExpression(DoubleField) - private val wasm = node.wasmCompileToExpression(DoubleField) - private val estree = node.estreeCompileToExpression(DoubleField) + private val mst = node.toExpression(algebra) + private val wasm = node.wasmCompileToExpression(algebra) + private val estree = node.estreeCompileToExpression(algebra) // In JavaScript, the expression below is implemented like // _no_name_provided__125.prototype.invoke_178 = function (args) { @@ -44,7 +46,7 @@ internal class TestExecutionTime { private val raw = Expression { args -> val x = args[x]!! - x * 2.0 + 2.0 / x - 16.0 / sin(x) + algebra { x * 2.0 + 2.0 / x - 16.0 / sin(x) } } private val justCalculate = { args: dynamic -> diff --git a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/utils.kt b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/utils.kt index 93b7e9449..3c2a9bd13 100644 --- a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/utils.kt +++ b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/ast/utils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -10,6 +10,8 @@ import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.IntRing +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import space.kscience.kmath.estree.compile as estreeCompile import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression import space.kscience.kmath.wasm.compile as wasmCompile @@ -34,6 +36,7 @@ private object ESTreeCompilerTestContext : CompilerTestContext { } internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) { + contract { callsInPlace(action, InvocationKind.AT_LEAST_ONCE) } action(WasmCompilerTestContext) action(ESTreeCompilerTestContext) } diff --git a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/wasm/TestWasmSpecific.kt b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/wasm/TestWasmSpecific.kt index 45776c191..6c91df866 100644 --- a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/wasm/TestWasmSpecific.kt +++ b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/wasm/TestWasmSpecific.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.wasm @@ -11,7 +11,6 @@ import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.symbol import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.IntRing -import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 4147324ee..2426d6ee4 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.asm diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt index a796ae2a5..418d6141b 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.asm.internal @@ -14,9 +14,11 @@ import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.MST import java.lang.invoke.MethodHandles import java.lang.invoke.MethodType +import java.nio.file.Paths import java.util.stream.Collectors.toMap import kotlin.contracts.InvocationKind import kotlin.contracts.contract +import kotlin.io.path.writeBytes /** * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. @@ -194,15 +196,18 @@ internal class AsmBuilder( visitEnd() } - val cls = classLoader.defineClass(className, classWriter.toByteArray()) - // java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + val binary = classWriter.toByteArray() + val cls = classLoader.defineClass(className, binary) + + if (System.getProperty("space.kscience.communicator.prettyapi.dump.generated.classes") == "1") + Paths.get("$className.class").writeBytes(binary) + val l = MethodHandles.publicLookup() - if (hasConstants) - l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array::class.java)) - .invoke(constants.toTypedArray()) as Expression + (if (hasConstants) + l.findConstructor(cls, MethodType.methodType(Void.TYPE, Array::class.java))(constants.toTypedArray()) else - l.findConstructor(cls, MethodType.methodType(Void.TYPE)).invoke() as Expression + l.findConstructor(cls, MethodType.methodType(Void.TYPE))()) as Expression } /** diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt index a84248f63..5e2e7d8c6 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/codegenUtils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.asm.internal @@ -57,13 +57,13 @@ internal fun MethodVisitor.label(): Label = Label().also(::visitLabel) /** * Creates a class name for [Expression] subclassed to implement [mst] provided. * - * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * These methods help to avoid collisions of class name to prevent loading several classes with the same name. If there * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. * * @author Iaroslav Postovalov */ internal tailrec fun buildName(mst: MST, collision: Int = 0): String { - val name = "space.kscience.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + val name = "space.kscience.kmath.asm.generated.CompiledExpression_${mst.hashCode()}_$collision" try { Class.forName(name) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt index 8f4daecf9..40d9d8fe6 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/mapIntrinsics.kt @@ -1,13 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:JvmName("MapIntrinsics") package space.kscience.kmath.asm.internal -import space.kscience.kmath.expressions.StringSymbol import space.kscience.kmath.expressions.Symbol /** @@ -15,4 +14,4 @@ import space.kscience.kmath.expressions.Symbol * * @author Iaroslav Postovalov */ -internal fun Map.getOrFail(key: String): V = getValue(StringSymbol(key)) +internal fun Map.getOrFail(key: String): V = getValue(Symbol(key)) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt index 556adbe7d..3e5253084 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/rendering/multiplatformToString.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast.rendering diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt index d3b554efd..a0bdd68a0 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/utils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ast @@ -10,6 +10,8 @@ import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.IntRing +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import space.kscience.kmath.asm.compile as asmCompile import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression @@ -22,4 +24,7 @@ private object AsmCompilerTestContext : CompilerTestContext { asmCompile(algebra, arguments) } -internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) = action(AsmCompilerTestContext) +internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) { + contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } + action(AsmCompilerTestContext) +} diff --git a/kmath-commons/build.gradle.kts b/kmath-commons/build.gradle.kts index a208c956c..96c17a215 100644 --- a/kmath-commons/build.gradle.kts +++ b/kmath-commons/build.gradle.kts @@ -9,6 +9,7 @@ dependencies { api(project(":kmath-core")) api(project(":kmath-complex")) api(project(":kmath-coroutines")) + api(project(":kmath-optimization")) api(project(":kmath-stat")) api(project(":kmath-functions")) api("org.apache.commons:commons-math3:3.6.1") diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 361027968..4d2bd6237 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.expressions @@ -9,29 +9,29 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import space.kscience.kmath.expressions.* import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations +import space.kscience.kmath.operations.NumbersAddOps /** * A field over commons-math [DerivativeStructure]. * * @property order The derivation order. - * @property bindings The map of bindings values. All bindings are considered free parameters + * @param bindings The map of bindings values. All bindings are considered free parameters */ @OptIn(UnstableKMathAPI::class) public class DerivativeStructureField( public val order: Int, bindings: Map, ) : ExtendedField, ExpressionAlgebra, - NumbersAddOperations { + NumbersAddOps { public val numberOfVariables: Int = bindings.size - public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } - public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } + override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } + override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } - public override fun number(value: Number): DerivativeStructure = const(value.toDouble()) + override fun number(value: Number): DerivativeStructure = const(value.toDouble()) /** - * A class that implements both [DerivativeStructure] and a [Symbol] + * A class implementing both [DerivativeStructure] and [Symbol]. */ public inner class DerivativeStructureSymbol( size: Int, @@ -39,10 +39,10 @@ public class DerivativeStructureField( symbol: Symbol, value: Double, ) : DerivativeStructure(size, order, index, value), Symbol { - public override val identity: String = symbol.identity - public override fun toString(): String = identity - public override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity - public override fun hashCode(): Int = identity.hashCode() + override val identity: String = symbol.identity + override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() } /** @@ -52,10 +52,10 @@ public class DerivativeStructureField( key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value) }.toMap() - public override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) + override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) - public override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[value] - public override fun bindSymbol(value: String): DerivativeStructureSymbol = variables.getValue(value) + override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[value] + override fun bindSymbol(value: String): DerivativeStructureSymbol = variables.getValue(value) public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) @@ -68,47 +68,50 @@ public class DerivativeStructureField( public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList()) - public override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() + override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() - public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) + override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.add(right) - public override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value) + override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value) - public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b) - public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) - public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() - public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() - public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() - public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() - public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() - public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() - public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh() - public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh() - public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh() - public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh() - public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh() - public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh() + override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right) + override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right) + override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() + override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() + override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() + override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() + override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() + override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh() + override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh() + override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh() + override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh() + override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh() + override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh() - public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { + override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { is Double -> arg.pow(pow) is Int -> arg.pow(pow) else -> arg.pow(pow.toDouble()) } public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) - public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() - public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() + override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() + override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() - public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) - public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) - public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this - public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this + override operator fun DerivativeStructure.plus(other: Number): DerivativeStructure = add(other.toDouble()) + override operator fun DerivativeStructure.minus(other: Number): DerivativeStructure = subtract(other.toDouble()) + override operator fun Number.plus(other: DerivativeStructure): DerivativeStructure = other + this + override operator fun Number.minus(other: DerivativeStructure): DerivativeStructure = other - this +} - public companion object : - AutoDiffProcessor> { - public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression = - DerivativeStructureExpression(function) - } +/** + * Auto-diff processor based on Commons-math [DerivativeStructure] + */ +public object DSProcessor : AutoDiffProcessor { + override fun differentiate( + function: DerivativeStructureField.() -> DerivativeStructure, + ): DerivativeStructureExpression = DerivativeStructureExpression(function) } /** @@ -117,13 +120,13 @@ public class DerivativeStructureField( public class DerivativeStructureExpression( public val function: DerivativeStructureField.() -> DerivativeStructure, ) : DifferentiableExpression { - public override operator fun invoke(arguments: Map): Double = + override operator fun invoke(arguments: Map): Double = DerivativeStructureField(0, arguments).function().value /** * Get the derivative expression with given orders */ - public override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) } } } diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt index 4e174723d..5152b04f9 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.integration @@ -16,7 +16,7 @@ public class CMGaussRuleIntegrator( private var type: GaussRule = GaussRule.LEGANDRE, ) : UnivariateIntegrator { - override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand { + override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand { val range = integrand.getFeature()?.range ?: error("Integration range is not provided") val integrator: GaussIntegrator = getIntegrator(range) @@ -76,8 +76,8 @@ public class CMGaussRuleIntegrator( numPoints: Int = 100, type: GaussRule = GaussRule.LEGANDRE, function: (Double) -> Double, - ): Double = CMGaussRuleIntegrator(numPoints, type).integrate( + ): Double = CMGaussRuleIntegrator(numPoints, type).process( UnivariateIntegrand(function, IntegrationRange(range)) - ).valueOrNull!! + ).value } } \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt index bcddccdc4..76a2f297c 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.integration @@ -18,7 +18,7 @@ public class CMIntegrator( public val integratorBuilder: (Integrand) -> org.apache.commons.math3.analysis.integration.UnivariateIntegrator, ) : UnivariateIntegrator { - override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand { + override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand { val integrator = integratorBuilder(integrand) val maxCalls = integrand.getFeature()?.maxCalls ?: defaultMaxCalls val remainingCalls = maxCalls - integrand.calls diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt index 11b097831..14e7fc365 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.linear @@ -10,23 +10,27 @@ import space.kscience.kmath.linear.* import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer import kotlin.reflect.KClass import kotlin.reflect.cast public class CMMatrix(public val origin: RealMatrix) : Matrix { - public override val rowNum: Int get() = origin.rowDimension - public override val colNum: Int get() = origin.columnDimension + override val rowNum: Int get() = origin.rowDimension + override val colNum: Int get() = origin.columnDimension - public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) + override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) } -public class CMVector(public val origin: RealVector) : Point { - public override val size: Int get() = origin.dimension +@JvmInline +public value class CMVector(public val origin: RealVector) : Point { + override val size: Int get() = origin.dimension - public override operator fun get(index: Int): Double = origin.getEntry(index) + override operator fun get(index: Int): Double = origin.getEntry(index) - public override operator fun iterator(): Iterator = origin.toArray().iterator() + override operator fun iterator(): Iterator = origin.toArray().iterator() + + override fun toString(): String = Buffer.toString(this) } public fun RealVector.toPoint(): CMVector = CMVector(this) @@ -34,7 +38,7 @@ public fun RealVector.toPoint(): CMVector = CMVector(this) public object CMLinearSpace : LinearSpace { override val elementAlgebra: DoubleField get() = DoubleField - public override fun buildMatrix( + override fun buildMatrix( rows: Int, columns: Int, initializer: DoubleField.(i: Int, j: Int) -> Double, @@ -73,16 +77,16 @@ public object CMLinearSpace : LinearSpace { override fun Point.minus(other: Point): CMVector = toCM().origin.subtract(other.toCM().origin).wrap() - public override fun Matrix.dot(other: Matrix): CMMatrix = + override fun Matrix.dot(other: Matrix): CMMatrix = toCM().origin.multiply(other.toCM().origin).wrap() - public override fun Matrix.dot(vector: Point): CMVector = + override fun Matrix.dot(vector: Point): CMVector = toCM().origin.preMultiply(vector.toCM().origin).wrap() - public override operator fun Matrix.minus(other: Matrix): CMMatrix = + override operator fun Matrix.minus(other: Matrix): CMMatrix = toCM().origin.subtract(other.toCM().origin).wrap() - public override operator fun Matrix.times(value: Double): CMMatrix = + override operator fun Matrix.times(value: Double): CMMatrix = toCM().origin.scalarMultiply(value).wrap() override fun Double.times(m: Matrix): CMMatrix = @@ -95,7 +99,7 @@ public object CMLinearSpace : LinearSpace { v * this @UnstableKMathAPI - override fun getFeature(structure: Matrix, type: KClass): F? { + override fun computeFeature(structure: Matrix, type: KClass): F? { //Return the feature if it is intrinsic to the structure structure.getFeature(type)?.let { return it } @@ -109,22 +113,22 @@ public object CMLinearSpace : LinearSpace { LupDecompositionFeature { private val lup by lazy { LUDecomposition(origin) } override val determinant: Double by lazy { lup.determinant } - override val l: Matrix by lazy { CMMatrix(lup.l) + LFeature } - override val u: Matrix by lazy { CMMatrix(lup.u) + UFeature } + override val l: Matrix by lazy> { CMMatrix(lup.l).withFeature(LFeature) } + override val u: Matrix by lazy> { CMMatrix(lup.u).withFeature(UFeature) } override val p: Matrix by lazy { CMMatrix(lup.p) } } CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { - override val l: Matrix by lazy { + override val l: Matrix by lazy> { val cholesky = CholeskyDecomposition(origin) - CMMatrix(cholesky.l) + LFeature + CMMatrix(cholesky.l).withFeature(LFeature) } } QRDecompositionFeature::class -> object : QRDecompositionFeature { private val qr by lazy { QRDecomposition(origin) } - override val q: Matrix by lazy { CMMatrix(qr.q) + OrthogonalFeature } - override val r: Matrix by lazy { CMMatrix(qr.r) + UFeature } + override val q: Matrix by lazy> { CMMatrix(qr.q).withFeature(OrthogonalFeature) } + override val r: Matrix by lazy> { CMMatrix(qr.r).withFeature(UFeature) } } SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMSolver.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMSolver.kt index ee602ca06..d1fb441b0 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMSolver.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/linear/CMSolver.kt @@ -1,11 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.linear import org.apache.commons.math3.linear.* +import space.kscience.kmath.linear.LinearSolver import space.kscience.kmath.linear.Matrix import space.kscience.kmath.linear.Point @@ -17,7 +18,7 @@ public enum class CMDecomposition { CHOLESKY } -public fun CMLinearSpace.solver( +private fun CMLinearSpace.solver( a: Matrix, decomposition: CMDecomposition = CMDecomposition.LUP, ): DecompositionSolver = when (decomposition) { @@ -44,3 +45,14 @@ public fun CMLinearSpace.inverse( a: Matrix, decomposition: CMDecomposition = CMDecomposition.LUP, ): CMMatrix = solver(a, decomposition).inverse.wrap() + + +public fun CMLinearSpace.solver(decomposition: CMDecomposition): LinearSolver = object : LinearSolver { + override fun solve(a: Matrix, b: Matrix): Matrix = solver(a, decomposition).solve(b.toCM().origin).wrap() + + override fun solve(a: Matrix, b: Point): Point = solver(a, decomposition).solve(b.toCM().origin).toPoint() + + override fun inverse(matrix: Matrix): Matrix = solver(matrix, decomposition).inverse.wrap() +} + +public fun CMLinearSpace.lupSolver(): LinearSolver = solver((CMDecomposition.LUP)) \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimization.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimization.kt deleted file mode 100644 index 400ee0310..000000000 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimization.kt +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.commons.optimization - -import org.apache.commons.math3.optim.* -import org.apache.commons.math3.optim.nonlinear.scalar.GoalType -import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer -import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction -import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient -import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer -import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex -import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex -import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer -import space.kscience.kmath.expressions.* -import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.optimization.* -import kotlin.collections.set -import kotlin.reflect.KClass - -public operator fun PointValuePair.component1(): DoubleArray = point -public operator fun PointValuePair.component2(): Double = value - -@OptIn(UnstableKMathAPI::class) -public class CMOptimization( - override val symbols: List, -) : FunctionOptimization, NoDerivFunctionOptimization, SymbolIndexer, OptimizationFeature { - - private val optimizationData: HashMap, OptimizationData> = HashMap() - private var optimizerBuilder: (() -> MultivariateOptimizer)? = null - public var convergenceChecker: ConvergenceChecker = SimpleValueChecker( - DEFAULT_RELATIVE_TOLERANCE, - DEFAULT_ABSOLUTE_TOLERANCE, - DEFAULT_MAX_ITER - ) - - override var maximize: Boolean - get() = optimizationData[GoalType::class] == GoalType.MAXIMIZE - set(value) { - optimizationData[GoalType::class] = if (value) GoalType.MAXIMIZE else GoalType.MINIMIZE - } - - public fun addOptimizationData(data: OptimizationData) { - optimizationData[data::class] = data - } - - init { - addOptimizationData(MaxEval.unlimited()) - } - - public fun exportOptimizationData(): List = optimizationData.values.toList() - - public override fun initialGuess(map: Map): Unit { - addOptimizationData(InitialGuess(map.toDoubleArray())) - } - - public override fun function(expression: Expression): Unit { - val objectiveFunction = ObjectiveFunction { - val args = it.toMap() - expression(args) - } - addOptimizationData(objectiveFunction) - } - - public override fun diffFunction(expression: DifferentiableExpression) { - function(expression) - val gradientFunction = ObjectiveFunctionGradient { - val args = it.toMap() - DoubleArray(symbols.size) { index -> - expression.derivative(symbols[index])(args) - } - } - addOptimizationData(gradientFunction) - if (optimizerBuilder == null) { - optimizerBuilder = { - NonLinearConjugateGradientOptimizer( - NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, - convergenceChecker - ) - } - } - } - - public fun simplex(simplex: AbstractSimplex) { - addOptimizationData(simplex) - //Set optimization builder to simplex if it is not present - if (optimizerBuilder == null) { - optimizerBuilder = { SimplexOptimizer(convergenceChecker) } - } - } - - public fun simplexSteps(steps: Map) { - simplex(NelderMeadSimplex(steps.toDoubleArray())) - } - - public fun goal(goalType: GoalType) { - addOptimizationData(goalType) - } - - public fun optimizer(block: () -> MultivariateOptimizer) { - optimizerBuilder = block - } - - override fun update(result: OptimizationResult) { - initialGuess(result.point) - } - - override fun optimize(): OptimizationResult { - val optimizer = optimizerBuilder?.invoke() ?: error("Optimizer not defined") - val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray()) - return OptimizationResult(point.toMap(), value, setOf(this)) - } - - public companion object : OptimizationProblemFactory { - public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 - public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 - public const val DEFAULT_MAX_ITER: Int = 1000 - - override fun build(symbols: List): CMOptimization = CMOptimization(symbols) - } -} - -public fun CMOptimization.initialGuess(vararg pairs: Pair): Unit = initialGuess(pairs.toMap()) -public fun CMOptimization.simplexSteps(vararg pairs: Pair): Unit = simplexSteps(pairs.toMap()) diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimizer.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimizer.kt new file mode 100644 index 000000000..11eb6fba8 --- /dev/null +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimizer.kt @@ -0,0 +1,145 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ +@file:OptIn(UnstableKMathAPI::class) +package space.kscience.kmath.commons.optimization + +import org.apache.commons.math3.optim.* +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType +import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient +import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.expressions.SymbolIndexer +import space.kscience.kmath.expressions.derivative +import space.kscience.kmath.expressions.withSymbols +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.misc.log +import space.kscience.kmath.optimization.* +import kotlin.collections.set +import kotlin.reflect.KClass + +public operator fun PointValuePair.component1(): DoubleArray = point +public operator fun PointValuePair.component2(): Double = value + +public class CMOptimizerEngine(public val optimizerBuilder: () -> MultivariateOptimizer) : OptimizationFeature { + override fun toString(): String = "CMOptimizer($optimizerBuilder)" +} + +/** + * Specify a Commons-maths optimization engine + */ +public fun FunctionOptimizationBuilder.cmEngine(optimizerBuilder: () -> MultivariateOptimizer) { + addFeature(CMOptimizerEngine(optimizerBuilder)) +} + +public class CMOptimizerData(public val data: List OptimizationData>) : OptimizationFeature { + public constructor(vararg data: (SymbolIndexer.() -> OptimizationData)) : this(data.toList()) + + override fun toString(): String = "CMOptimizerData($data)" +} + +/** + * Specify Commons-maths optimization data. + */ +public fun FunctionOptimizationBuilder.cmOptimizationData(data: SymbolIndexer.() -> OptimizationData) { + updateFeature { + val newData = (it?.data ?: emptyList()) + data + CMOptimizerData(newData) + } +} + +public fun FunctionOptimizationBuilder.simplexSteps(vararg steps: Pair) { + //TODO use convergence checker from features + cmEngine { SimplexOptimizer(CMOptimizer.defaultConvergenceChecker) } + cmOptimizationData { NelderMeadSimplex(mapOf(*steps).toDoubleArray()) } +} + +@OptIn(UnstableKMathAPI::class) +public object CMOptimizer : Optimizer> { + + public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 + public const val DEFAULT_MAX_ITER: Int = 1000 + + public val defaultConvergenceChecker: SimpleValueChecker = SimpleValueChecker( + DEFAULT_RELATIVE_TOLERANCE, + DEFAULT_ABSOLUTE_TOLERANCE, + DEFAULT_MAX_ITER + ) + + + override suspend fun optimize( + problem: FunctionOptimization, + ): FunctionOptimization { + val startPoint = problem.startPoint + + val parameters = problem.getFeature()?.symbols + ?: problem.getFeature>()?.point?.keys + ?: startPoint.keys + + + withSymbols(parameters) { + val convergenceChecker: ConvergenceChecker = SimpleValueChecker( + DEFAULT_RELATIVE_TOLERANCE, + DEFAULT_ABSOLUTE_TOLERANCE, + DEFAULT_MAX_ITER + ) + + val cmOptimizer: MultivariateOptimizer = problem.getFeature()?.optimizerBuilder?.invoke() + ?: NonLinearConjugateGradientOptimizer( + NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, + convergenceChecker + ) + + val optimizationData: HashMap, OptimizationData> = HashMap() + + fun addOptimizationData(data: OptimizationData) { + optimizationData[data::class] = data + } + + addOptimizationData(MaxEval.unlimited()) + addOptimizationData(InitialGuess(startPoint.toDoubleArray())) + + //fun exportOptimizationData(): List = optimizationData.values.toList() + + val objectiveFunction = ObjectiveFunction { + val args = startPoint + it.toMap() + val res = problem.expression(args) + res + } + addOptimizationData(objectiveFunction) + + val gradientFunction = ObjectiveFunctionGradient { + val args = startPoint + it.toMap() + val res = DoubleArray(symbols.size) { index -> + problem.expression.derivative(symbols[index])(args) + } + res + } + addOptimizationData(gradientFunction) + + val logger = problem.getFeature() + + for (feature in problem.features) { + when (feature) { + is CMOptimizerData -> feature.data.forEach { dataBuilder -> + addOptimizationData(dataBuilder()) + } + is FunctionOptimizationTarget -> when (feature) { + FunctionOptimizationTarget.MAXIMIZE -> addOptimizationData(GoalType.MAXIMIZE) + FunctionOptimizationTarget.MINIMIZE -> addOptimizationData(GoalType.MINIMIZE) + } + else -> logger?.log { "The feature $feature is unused in optimization" } + } + } + + val (point, value) = cmOptimizer.optimize(*optimizationData.values.toTypedArray()) + return problem.withFeatures(OptimizationResult(point.toMap()), OptimizationValue(value)) + } + } +} diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/cmFit.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/cmFit.kt deleted file mode 100644 index 645c41291..000000000 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/cmFit.kt +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.commons.optimization - -import org.apache.commons.math3.analysis.differentiation.DerivativeStructure -import space.kscience.kmath.commons.expressions.DerivativeStructureField -import space.kscience.kmath.expressions.DifferentiableExpression -import space.kscience.kmath.expressions.Expression -import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.optimization.FunctionOptimization -import space.kscience.kmath.optimization.OptimizationResult -import space.kscience.kmath.optimization.noDerivOptimizeWith -import space.kscience.kmath.optimization.optimizeWith -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.asBuffer - -/** - * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation - */ -public fun FunctionOptimization.Companion.chiSquared( - x: Buffer, - y: Buffer, - yErr: Buffer, - model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, -): DifferentiableExpression = chiSquared(DerivativeStructureField, x, y, yErr, model) - -/** - * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation - */ -public fun FunctionOptimization.Companion.chiSquared( - x: Iterable, - y: Iterable, - yErr: Iterable, - model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, -): DifferentiableExpression = chiSquared( - DerivativeStructureField, - x.toList().asBuffer(), - y.toList().asBuffer(), - yErr.toList().asBuffer(), - model -) - -/** - * Optimize expression without derivatives - */ -public fun Expression.optimize( - vararg symbols: Symbol, - configuration: CMOptimization.() -> Unit, -): OptimizationResult = noDerivOptimizeWith(CMOptimization, symbols = symbols, configuration) - -/** - * Optimize differentiable expression - */ -public fun DifferentiableExpression.optimize( - vararg symbols: Symbol, - configuration: CMOptimization.() -> Unit, -): OptimizationResult = optimizeWith(CMOptimization, symbols = symbols, configuration) - -public fun DifferentiableExpression.minimize( - vararg startPoint: Pair, - configuration: CMOptimization.() -> Unit = {}, -): OptimizationResult { - val symbols = startPoint.map { it.first }.toTypedArray() - return optimize(*symbols){ - maximize = false - initialGuess(startPoint.toMap()) - diffFunction(this@minimize) - configuration() - } -} \ No newline at end of file diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt index 4e2fbf980..28294cf14 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.random @@ -8,6 +8,7 @@ package space.kscience.kmath.commons.random import kotlinx.coroutines.runBlocking import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.samplers.GaussianSampler +import space.kscience.kmath.misc.toIntExact import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.next @@ -16,31 +17,31 @@ public class CMRandomGeneratorWrapper( ) : org.apache.commons.math3.random.RandomGenerator { private var generator: RandomGenerator = factory(intArrayOf()) - public override fun nextBoolean(): Boolean = generator.nextBoolean() - public override fun nextFloat(): Float = generator.nextDouble().toFloat() + override fun nextBoolean(): Boolean = generator.nextBoolean() + override fun nextFloat(): Float = generator.nextDouble().toFloat() - public override fun setSeed(seed: Int) { + override fun setSeed(seed: Int) { generator = factory(intArrayOf(seed)) } - public override fun setSeed(seed: IntArray) { + override fun setSeed(seed: IntArray) { generator = factory(seed) } - public override fun setSeed(seed: Long) { - setSeed(seed.toInt()) + override fun setSeed(seed: Long) { + setSeed(seed.toIntExact()) } - public override fun nextBytes(bytes: ByteArray) { + override fun nextBytes(bytes: ByteArray) { generator.fillBytes(bytes) } - public override fun nextInt(): Int = generator.nextInt() - public override fun nextInt(n: Int): Int = generator.nextInt(n) + override fun nextInt(): Int = generator.nextInt() + override fun nextInt(n: Int): Int = generator.nextInt(n) @PerformancePitfall - public override fun nextGaussian(): Double = runBlocking { GaussianSampler(0.0, 1.0).next(generator) } + override fun nextGaussian(): Double = runBlocking { GaussianSampler(0.0, 1.0).next(generator) } - public override fun nextDouble(): Double = generator.nextDouble() - public override fun nextLong(): Long = generator.nextLong() + override fun nextDouble(): Double = generator.nextDouble() + override fun nextLong(): Long = generator.nextLong() } diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/transform/Transformations.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/transform/Transformations.kt index d29491d63..73ab91542 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/transform/Transformations.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/transform/Transformations.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.transform @@ -10,10 +10,13 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map import org.apache.commons.math3.transform.* import space.kscience.kmath.complex.Complex +import space.kscience.kmath.operations.SuspendBufferTransform import space.kscience.kmath.streaming.chunked import space.kscience.kmath.streaming.spread -import space.kscience.kmath.structures.* - +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.VirtualBuffer +import space.kscience.kmath.structures.asBuffer /** @@ -32,7 +35,7 @@ public object Transformations { /** * Create a virtual buffer on top of array */ - private fun Array.asBuffer() = VirtualBuffer(size) { + private fun Array.asBuffer() = VirtualBuffer(size) { val value = get(it) Complex(value.real, value.imaginary) } diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt index 966675062..eaebc84dc 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.expressions @@ -16,7 +16,7 @@ internal inline fun diff( order: Int, vararg parameters: Pair, block: DerivativeStructureField.() -> Unit, -): Unit { +) { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } DerivativeStructureField(order, mapOf(*parameters)).run(block) } @@ -34,7 +34,7 @@ internal class AutoDiffTest { println(z.derivative(x)) println(z.derivative(y, x)) assertEquals(z.derivative(x, y), z.derivative(y, x)) - //check that improper order cause failure + // check improper order cause failure assertFails { z.derivative(x, x, y) } } } @@ -42,8 +42,8 @@ internal class AutoDiffTest { @Test fun autoDifTest() { val f = DerivativeStructureExpression { - val x by binding() - val y by binding() + val x by binding + val y by binding x.pow(2) + 2 * x * y + y.pow(2) + 1 } diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/integration/IntegrationTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/integration/IntegrationTest.kt index c5573fef1..bab3aecb6 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/integration/IntegrationTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/integration/IntegrationTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.integration diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt index 15c9120ec..c670ceead 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -1,48 +1,47 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.commons.optimization import kotlinx.coroutines.runBlocking +import space.kscience.kmath.commons.expressions.DSProcessor import space.kscience.kmath.commons.expressions.DerivativeStructureExpression import space.kscience.kmath.distributions.NormalDistribution +import space.kscience.kmath.expressions.Symbol.Companion.x +import space.kscience.kmath.expressions.Symbol.Companion.y +import space.kscience.kmath.expressions.chiSquaredExpression import space.kscience.kmath.expressions.symbol -import space.kscience.kmath.optimization.FunctionOptimization +import space.kscience.kmath.operations.map +import space.kscience.kmath.optimization.* import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.asBuffer import kotlin.math.pow import kotlin.test.Test internal class OptimizeTest { - val x by symbol - val y by symbol - val normal = DerivativeStructureExpression { - exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y) - .pow(2) / 2) + exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y).pow(2) / 2) } @Test - fun testGradientOptimization() { - val result = normal.optimize(x, y) { - initialGuess(x to 1.0, y to 1.0) - //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function - } - println(result.point) - println(result.value) + fun testGradientOptimization() = runBlocking { + val result = normal.optimizeWith(CMOptimizer, x to 1.0, y to 1.0) + println(result.resultPoint) + println(result.resultValue) } @Test - fun testSimplexOptimization() { - val result = normal.optimize(x, y) { - initialGuess(x to 1.0, y to 1.0) + fun testSimplexOptimization() = runBlocking { + val result = normal.optimizeWith(CMOptimizer, x to 1.0, y to 1.0) { simplexSteps(x to 2.0, y to 0.5) //this sets simplex optimizer } - println(result.point) - println(result.value) + println(result.resultPoint) + println(result.resultValue) } @Test @@ -54,21 +53,27 @@ internal class OptimizeTest { val sigma = 1.0 val generator = NormalDistribution(0.0, sigma) val chain = generator.sample(RandomGenerator.default(112667)) - val x = (1..100).map(Int::toDouble) + val x = (1..100).map(Int::toDouble).asBuffer() val y = x.map { it.pow(2) + it + 1 + chain.next() } - val yErr = List(x.size) { sigma } + val yErr = DoubleBuffer(x.size) { sigma } - val chi2 = FunctionOptimization.chiSquared(x, y, yErr) { x1 -> + val chi2 = DSProcessor.chiSquaredExpression( + x, y, yErr + ) { arg -> val cWithDefault = bindSymbolOrNull(c) ?: one - bindSymbol(a) * x1.pow(2) + bindSymbol(b) * x1 + cWithDefault + bindSymbol(a) * arg.pow(2) + bindSymbol(b) * arg + cWithDefault } - val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) + val result: FunctionOptimization = chi2.optimizeWith( + CMOptimizer, + mapOf(a to 1.5, b to 0.9, c to 1.0), + FunctionOptimizationTarget.MINIMIZE + ) println(result) - println("Chi2/dof = ${result.value / (x.size - 3)}") + println("Chi2/dof = ${result.resultValue / (x.size - 3)}") } } diff --git a/kmath-complex/README.md b/kmath-complex/README.md index 18a83756d..110529b72 100644 --- a/kmath-complex/README.md +++ b/kmath-complex/README.md @@ -8,7 +8,7 @@ Complex and hypercomplex number systems in KMath. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-14`. **Gradle:** ```gradle @@ -18,7 +18,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-complex:0.3.0-dev-13' + implementation 'space.kscience:kmath-complex:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -29,6 +29,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-complex:0.3.0-dev-13") + implementation("space.kscience:kmath-complex:0.3.0-dev-14") } ``` diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt index a96d046c9..879cfe94e 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex @@ -9,10 +9,7 @@ import space.kscience.kmath.memory.MemoryReader import space.kscience.kmath.memory.MemorySpec import space.kscience.kmath.memory.MemoryWriter import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.Norm -import space.kscience.kmath.operations.NumbersAddOperations -import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MemoryBuffer import space.kscience.kmath.structures.MutableBuffer @@ -52,10 +49,22 @@ private val PI_DIV_2 = Complex(PI / 2, 0) * A field of [Complex]. */ @OptIn(UnstableKMathAPI::class) -public object ComplexField : ExtendedField, Norm, NumbersAddOperations, +public object ComplexField : + ExtendedField, + Norm, + NumbersAddOps, ScaleOperations { - public override val zero: Complex = 0.0.toComplex() - public override val one: Complex = 1.0.toComplex() + + override val zero: Complex = 0.0.toComplex() + override val one: Complex = 1.0.toComplex() + + override fun bindSymbolOrNull(value: String): Complex? = if (value == "i") i else null + + override fun binaryOperationFunction(operation: String): (left: Complex, right: Complex) -> Complex = + when (operation) { + PowerOperations.POW_OPERATION -> ComplexField::power + else -> super.binaryOperationFunction(operation) + } /** * The imaginary unit. @@ -68,63 +77,67 @@ public object ComplexField : ExtendedField, Norm, Num override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value) - public override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) -// public override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) + override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im) +// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) - public override fun multiply(a: Complex, b: Complex): Complex = - Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) + override fun multiply(left: Complex, right: Complex): Complex = + Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re) - public override fun divide(a: Complex, b: Complex): Complex = when { - abs(b.im) < abs(b.re) -> { - val wr = b.im / b.re - val wd = b.re + wr * b.im + override fun divide(left: Complex, right: Complex): Complex = when { + abs(right.im) < abs(right.re) -> { + val wr = right.im / right.re + val wd = right.re + wr * right.im if (wd.isNaN() || wd == 0.0) throw ArithmeticException("Division by zero or infinity") else - Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd) + Complex((left.re + left.im * wr) / wd, (left.im - left.re * wr) / wd) } - b.im == 0.0 -> throw ArithmeticException("Division by zero") + right.im == 0.0 -> throw ArithmeticException("Division by zero") else -> { - val wr = b.re / b.im - val wd = b.im + wr * b.re + val wr = right.re / right.im + val wd = right.im + wr * right.re if (wd.isNaN() || wd == 0.0) throw ArithmeticException("Division by zero or infinity") else - Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd) + Complex((left.re * wr + left.im) / wd, (left.im * wr - left.re) / wd) } } override operator fun Complex.div(k: Number): Complex = Complex(re / k.toDouble(), im / k.toDouble()) - public override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2.0 - public override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2.0 + override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2.0 + override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2.0 - public override fun tan(arg: Complex): Complex { + override fun tan(arg: Complex): Complex { val e1 = exp(-i * arg) val e2 = exp(i * arg) return i * (e1 - e2) / (e1 + e2) } - public override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg) - public override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg) + override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg) + override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg) - public override fun atan(arg: Complex): Complex { + override fun atan(arg: Complex): Complex { val iArg = i * arg return i * (ln(1 - iArg) - ln(1 + iArg)) / 2 } - public override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0) + override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0) { arg.re.pow(pow.toDouble()).toComplex() - else + } else { exp(pow * ln(arg)) + } - public override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) + public fun power(arg: Complex, pow: Complex): Complex = exp(pow * ln(arg)) - public override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) + + override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) + + override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) /** * Adds complex number to real one. @@ -171,9 +184,7 @@ public object ComplexField : ExtendedField, Norm, Num */ public operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) - public override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) - - public override fun bindSymbolOrNull(value: String): Complex? = if (value == "i") i else null + override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) } /** @@ -187,22 +198,23 @@ public data class Complex(val re: Double, val im: Double) { public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) public constructor(re: Number) : this(re.toDouble(), 0.0) - public override fun toString(): String = "($re + i * $im)" + override fun toString(): String = "($re + i * $im)" public companion object : MemorySpec { - public override val objectSize: Int + override val objectSize: Int get() = 16 - public override fun MemoryReader.read(offset: Int): Complex = + override fun MemoryReader.read(offset: Int): Complex = Complex(readDouble(offset), readDouble(offset + 8)) - public override fun MemoryWriter.write(offset: Int, value: Complex) { + override fun MemoryWriter.write(offset: Int, value: Complex) { writeDouble(offset, value.re) writeDouble(offset + 8, value.im) } } } +public val Complex.Companion.algebra: ComplexField get() = ComplexField /** * Creates a complex number with real part equal to this real. diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt index 2c783eda0..9d5b1cddd 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt @@ -1,17 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.AlgebraND -import space.kscience.kmath.nd.BufferND -import space.kscience.kmath.nd.BufferedFieldND -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations +import space.kscience.kmath.nd.* +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -21,104 +17,68 @@ import kotlin.contracts.contract * An optimized nd-field for complex numbers */ @OptIn(UnstableKMathAPI::class) -public class ComplexFieldND( - shape: IntArray, -) : BufferedFieldND(shape, ComplexField, Buffer.Companion::complex), - NumbersAddOperations>, - ExtendedField> { +public sealed class ComplexFieldOpsND : BufferedFieldOpsND(ComplexField.bufferAlgebra), + ScaleOperations>, ExtendedFieldOps> { - public override val zero: BufferND by lazy { produce { zero } } - public override val one: BufferND by lazy { produce { one } } - - public override fun number(value: Number): BufferND { - val d = value.toComplex() // minimize conversions - return produce { d } + override fun StructureND.toBufferND(): BufferND = when (this) { + is BufferND -> this + else -> { + val indexer = indexerBuilder(shape) + BufferND(indexer, Buffer.complex(indexer.linearSize) { offset -> get(indexer.index(offset)) }) + } } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun map( -// arg: AbstractNDBuffer, -// transform: DoubleField.(Double) -> Double, -// ): RealNDElement { -// check(arg) -// val array = RealBuffer(arg.strides.linearSize) { offset -> DoubleField.transform(arg.buffer[offset]) } -// return BufferedNDFieldElement(this, array) -// } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement { -// val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } -// return BufferedNDFieldElement(this, array) -// } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun mapIndexed( -// arg: AbstractNDBuffer, -// transform: DoubleField.(index: IntArray, Double) -> Double, -// ): RealNDElement { -// check(arg) -// return BufferedNDFieldElement( -// this, -// RealBuffer(arg.strides.linearSize) { offset -> -// elementContext.transform( -// arg.strides.index(offset), -// arg.buffer[offset] -// ) -// }) -// } -// -// @Suppress("OVERRIDE_BY_INLINE") -// override inline fun combine( -// a: AbstractNDBuffer, -// b: AbstractNDBuffer, -// transform: DoubleField.(Double, Double) -> Double, -// ): RealNDElement { -// check(a, b) -// val buffer = RealBuffer(strides.linearSize) { offset -> -// elementContext.transform(a.buffer[offset], b.buffer[offset]) -// } -// return BufferedNDFieldElement(this, buffer) -// } + //TODO do specialization - public override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } + override fun scale(a: StructureND, value: Double): BufferND = + mapInline(a.toBufferND()) { it * value } - public override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } + override fun power(arg: StructureND, pow: Number): BufferND = + mapInline(arg.toBufferND()) { power(it, pow) } - public override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } + override fun exp(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { exp(it) } + override fun ln(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { ln(it) } - public override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } - public override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } - public override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } - public override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } - public override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } - public override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } + override fun sin(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { sin(it) } + override fun cos(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { cos(it) } + override fun tan(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { tan(it) } + override fun asin(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { asin(it) } + override fun acos(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { acos(it) } + override fun atan(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { atan(it) } - public override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } - public override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } - public override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } - public override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } - public override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } - public override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } + override fun sinh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { sinh(it) } + override fun cosh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { cosh(it) } + override fun tanh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { tanh(it) } + override fun asinh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { asinh(it) } + override fun acosh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { acosh(it) } + override fun atanh(arg: StructureND): BufferND = mapInline(arg.toBufferND()) { atanh(it) } + + public companion object : ComplexFieldOpsND() } +@UnstableKMathAPI +public val ComplexField.bufferAlgebra: BufferFieldOps + get() = bufferAlgebra(Buffer.Companion::complex) -/** - * Fast element production using function inlining - */ -public inline fun BufferedFieldND.produceInline(initializer: ComplexField.(Int) -> Complex): BufferND { - contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) } - val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) } - return BufferND(strides, buffer) + +@OptIn(UnstableKMathAPI::class) +public class ComplexFieldND(override val shape: Shape) : + ComplexFieldOpsND(), FieldND, NumbersAddOps> { + + override fun number(value: Number): BufferND { + val d = value.toDouble() // minimize conversions + return structureND(shape) { d.toComplex() } + } } +public val ComplexField.ndAlgebra: ComplexFieldOpsND get() = ComplexFieldOpsND -public fun AlgebraND.Companion.complex(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape) +public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape) /** * Produce a context for n-dimensional operations inside this real field */ -public inline fun ComplexField.nd(vararg shape: Int, action: ComplexFieldND.() -> R): R { +public inline fun ComplexField.withNdAlgebra(vararg shape: Int, action: ComplexFieldND.() -> R): R { contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } return ComplexFieldND(shape).action() } diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt index c59aabdcb..ff9a8302a 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex @@ -44,7 +44,7 @@ public val Quaternion.r: Double */ @OptIn(UnstableKMathAPI::class) public object QuaternionField : Field, Norm, PowerOperations, - ExponentialOperations, NumbersAddOperations, ScaleOperations { + ExponentialOperations, NumbersAddOps, ScaleOperations { override val zero: Quaternion = 0.toQuaternion() override val one: Quaternion = 1.toQuaternion() @@ -63,31 +63,31 @@ public object QuaternionField : Field, Norm, */ public val k: Quaternion = Quaternion(0, 0, 0, 1) - public override fun add(a: Quaternion, b: Quaternion): Quaternion = - Quaternion(a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z) + override fun add(left: Quaternion, right: Quaternion): Quaternion = + Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z) - public override fun scale(a: Quaternion, value: Double): Quaternion = + override fun scale(a: Quaternion, value: Double): Quaternion = Quaternion(a.w * value, a.x * value, a.y * value, a.z * value) - public override fun multiply(a: Quaternion, b: Quaternion): Quaternion = Quaternion( - a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z, - a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y, - a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x, - a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w, + override fun multiply(left: Quaternion, right: Quaternion): Quaternion = Quaternion( + left.w * right.w - left.x * right.x - left.y * right.y - left.z * right.z, + left.w * right.x + left.x * right.w + left.y * right.z - left.z * right.y, + left.w * right.y - left.x * right.z + left.y * right.w + left.z * right.x, + left.w * right.z + left.x * right.y - left.y * right.x + left.z * right.w, ) - public override fun divide(a: Quaternion, b: Quaternion): Quaternion { - val s = b.w * b.w + b.x * b.x + b.y * b.y + b.z * b.z + override fun divide(left: Quaternion, right: Quaternion): Quaternion { + val s = right.w * right.w + right.x * right.x + right.y * right.y + right.z * right.z return Quaternion( - (b.w * a.w + b.x * a.x + b.y * a.y + b.z * a.z) / s, - (b.w * a.x - b.x * a.w - b.y * a.z + b.z * a.y) / s, - (b.w * a.y + b.x * a.z - b.y * a.w - b.z * a.x) / s, - (b.w * a.z - b.x * a.y + b.y * a.x - b.z * a.w) / s, + (right.w * left.w + right.x * left.x + right.y * left.y + right.z * left.z) / s, + (right.w * left.x - right.x * left.w - right.y * left.z + right.z * left.y) / s, + (right.w * left.y + right.x * left.z - right.y * left.w - right.z * left.x) / s, + (right.w * left.z - right.x * left.y + right.y * left.x - right.z * left.w) / s, ) } - public override fun power(arg: Quaternion, pow: Number): Quaternion { + override fun power(arg: Quaternion, pow: Number): Quaternion { if (pow is Int) return pwr(arg, pow) if (floor(pow.toDouble()) == pow.toDouble()) return pwr(arg, pow.toInt()) return exp(pow * ln(arg)) @@ -131,7 +131,7 @@ public object QuaternionField : Field, Norm, return Quaternion(a2 * a2 - 6 * a2 * n1 + n1 * n1, x.x * n2, x.y * n2, x.z * n2) } - public override fun exp(arg: Quaternion): Quaternion { + override fun exp(arg: Quaternion): Quaternion { val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z if (un == 0.0) return exp(arg.w).toQuaternion() val n1 = sqrt(un) @@ -140,14 +140,14 @@ public object QuaternionField : Field, Norm, return Quaternion(ea * cos(n1), n2 * arg.x, n2 * arg.y, n2 * arg.z) } - public override fun ln(arg: Quaternion): Quaternion { + override fun ln(arg: Quaternion): Quaternion { val nu2 = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z if (nu2 == 0.0) return if (arg.w > 0) Quaternion(ln(arg.w), 0, 0, 0) else { - val l = ComplexField { ComplexField.ln(arg.w.toComplex()) } + val l = ComplexField { ln(arg.w.toComplex()) } Quaternion(l.re, l.im, 0, 0) } @@ -158,21 +158,21 @@ public object QuaternionField : Field, Norm, return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z) } - public override operator fun Number.plus(b: Quaternion): Quaternion = Quaternion(toDouble() + b.w, b.x, b.y, b.z) + override operator fun Number.plus(other: Quaternion): Quaternion = Quaternion(toDouble() + other.w, other.x, other.y, other.z) - public override operator fun Number.minus(b: Quaternion): Quaternion = - Quaternion(toDouble() - b.w, -b.x, -b.y, -b.z) + override operator fun Number.minus(other: Quaternion): Quaternion = + Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z) - public override operator fun Quaternion.plus(b: Number): Quaternion = Quaternion(w + b.toDouble(), x, y, z) - public override operator fun Quaternion.minus(b: Number): Quaternion = Quaternion(w - b.toDouble(), x, y, z) + override operator fun Quaternion.plus(other: Number): Quaternion = Quaternion(w + other.toDouble(), x, y, z) + override operator fun Quaternion.minus(other: Number): Quaternion = Quaternion(w - other.toDouble(), x, y, z) - public override operator fun Number.times(b: Quaternion): Quaternion = - Quaternion(toDouble() * b.w, toDouble() * b.x, toDouble() * b.y, toDouble() * b.z) + override operator fun Number.times(arg: Quaternion): Quaternion = + Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z) - public override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) - public override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) + override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) + override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) - public override fun bindSymbolOrNull(value: String): Quaternion? = when (value) { + override fun bindSymbolOrNull(value: String): Quaternion? = when (value) { "i" -> i "j" -> j "k" -> k @@ -181,12 +181,12 @@ public object QuaternionField : Field, Norm, override fun number(value: Number): Quaternion = value.toQuaternion() - public override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0 - public override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0 - public override fun tanh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) - public override fun asinh(arg: Quaternion): Quaternion = ln(sqrt(arg * arg + one) + arg) - public override fun acosh(arg: Quaternion): Quaternion = ln(arg + sqrt((arg - one) * (arg + one))) - public override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0 + override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0 + override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0 + override fun tanh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) + override fun asinh(arg: Quaternion): Quaternion = ln(sqrt(arg * arg + one) + arg) + override fun acosh(arg: Quaternion): Quaternion = ln(arg + sqrt((arg - one) * (arg + one))) + override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0 } /** @@ -224,16 +224,16 @@ public data class Quaternion( /** * Returns a string representation of this quaternion. */ - public override fun toString(): String = "($w + $x * i + $y * j + $z * k)" + override fun toString(): String = "($w + $x * i + $y * j + $z * k)" public companion object : MemorySpec { - public override val objectSize: Int + override val objectSize: Int get() = 32 - public override fun MemoryReader.read(offset: Int): Quaternion = + override fun MemoryReader.read(offset: Int): Quaternion = Quaternion(readDouble(offset), readDouble(offset + 8), readDouble(offset + 16), readDouble(offset + 24)) - public override fun MemoryWriter.write(offset: Int, value: Quaternion) { + override fun MemoryWriter.write(offset: Int, value: Quaternion) { writeDouble(offset, value.w) writeDouble(offset + 8, value.x) writeDouble(offset + 16, value.y) diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexBufferSpecTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexBufferSpecTest.kt index 17a077ea7..87239654d 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexBufferSpecTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexBufferSpecTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt index cbaaa815b..90e624343 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexFieldTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt index 7ad7f883d..a37006f75 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ComplexTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt index 4279471d4..00ae5ede1 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt index 6784f3516..319460c74 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/QuaternionFieldTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.complex diff --git a/kmath-core/README.md b/kmath-core/README.md index 6ca8c8ef8..4ea493f44 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -10,12 +10,12 @@ The core interfaces of KMath. objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high performance calculations to code generation. - [domains](src/commonMain/kotlin/space/kscience/kmath/domains) : Domains - - [autodif](src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation + - [autodiff](src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-14`. **Gradle:** ```gradle @@ -25,7 +25,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-core:0.3.0-dev-13' + implementation 'space.kscience:kmath-core:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -36,6 +36,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-core:0.3.0-dev-13") + implementation("space.kscience:kmath-core:0.3.0-dev-14") } ``` diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 92a5f419d..e4436c1df 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -2,6 +2,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") id("ru.mipt.npm.gradle.native") +// id("com.xcporter.metaview") version "0.0.5" } kotlin.sourceSets { @@ -12,6 +13,12 @@ kotlin.sourceSets { } } +//generateUml { +// classTree { +// +// } +//} + readme { description = "Core classes, algebra definitions, basic linear algebra" maturity = ru.mipt.npm.gradle.Maturity.DEVELOPMENT @@ -19,51 +26,42 @@ readme { feature( id = "algebras", - description = """ - Algebraic structures like rings, spaces and fields. - """.trimIndent(), - ref = "src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt" - ) + ref = "src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt", + ) { "Algebraic structures like rings, spaces and fields." } feature( id = "nd", - description = "Many-dimensional structures and operations on them.", - ref = "src/commonMain/kotlin/space/kscience/kmath/structures/StructureND.kt" - ) + ref = "src/commonMain/kotlin/space/kscience/kmath/structures/StructureND.kt", + ) { "Many-dimensional structures and operations on them." } feature( id = "linear", - description = """ - Basic linear algebra operations (sums, products, etc.), backed by the `Space` API. Advanced linear algebra operations like matrix inversion and LU decomposition. - """.trimIndent(), - ref = "src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt" - ) + ref = "src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt", + ) { "Basic linear algebra operations (sums, products, etc.), backed by the `Space` API. Advanced linear algebra operations like matrix inversion and LU decomposition." } feature( id = "buffers", - description = "One-dimensional structure", - ref = "src/commonMain/kotlin/space/kscience/kmath/structures/Buffers.kt" - ) + ref = "src/commonMain/kotlin/space/kscience/kmath/structures/Buffers.kt", + ) { "One-dimensional structure" } feature( id = "expressions", - description = """ + ref = "src/commonMain/kotlin/space/kscience/kmath/expressions" + ) { + """ By writing a single mathematical expression once, users will be able to apply different types of objects to the expression by providing a context. Expressions can be used for a wide variety of purposes from high performance calculations to code generation. - """.trimIndent(), - ref = "src/commonMain/kotlin/space/kscience/kmath/expressions" - ) + """.trimIndent() + } feature( id = "domains", - description = "Domains", - ref = "src/commonMain/kotlin/space/kscience/kmath/domains" - ) + ref = "src/commonMain/kotlin/space/kscience/kmath/domains", + ) { "Domains" } feature( - id = "autodif", - description = "Automatic differentiation", + id = "autodiff", ref = "src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt" - ) + ) { "Automatic differentiation" } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/ColumnarData.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/ColumnarData.kt index 88c14d311..53c4b4d1e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/ColumnarData.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/ColumnarData.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.data @@ -25,6 +25,9 @@ public interface ColumnarData { public operator fun get(symbol: Symbol): Buffer? } +@UnstableKMathAPI +public val ColumnarData<*>.indices: IntRange get() = 0 until size + /** * A zero-copy method to represent a [Structure2D] as a two-column x-y data. * There could more than two columns in the structure. diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt index 08bfd3ca3..ffec339bf 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYColumnarData.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.data @@ -16,7 +16,7 @@ import kotlin.math.max * The buffer of X values. */ @UnstableKMathAPI -public interface XYColumnarData : ColumnarData { +public interface XYColumnarData : ColumnarData { /** * The buffer of X values */ @@ -32,20 +32,43 @@ public interface XYColumnarData : ColumnarData { Symbol.y -> y else -> null } -} -@Suppress("FunctionName") -@UnstableKMathAPI -public fun XYColumnarData(x: Buffer, y: Buffer): XYColumnarData { - require(x.size == y.size) { "Buffer size mismatch. x buffer size is ${x.size}, y buffer size is ${y.size}" } - return object : XYColumnarData { - override val size: Int = x.size - override val x: Buffer = x - override val y: Buffer = y + public companion object{ + @UnstableKMathAPI + public fun of(x: Buffer, y: Buffer): XYColumnarData { + require(x.size == y.size) { "Buffer size mismatch. x buffer size is ${x.size}, y buffer size is ${y.size}" } + return object : XYColumnarData { + override val size: Int = x.size + override val x: Buffer = x + override val y: Buffer = y + } + } } } +/** + * Represent a [ColumnarData] as an [XYColumnarData]. The presence or respective columns is checked on creation. + */ +@UnstableKMathAPI +public fun ColumnarData.asXYData( + xSymbol: Symbol, + ySymbol: Symbol, +): XYColumnarData = object : XYColumnarData { + init { + requireNotNull(this@asXYData[xSymbol]){"The column with name $xSymbol is not present in $this"} + requireNotNull(this@asXYData[ySymbol]){"The column with name $ySymbol is not present in $this"} + } + override val size: Int get() = this@asXYData.size + override val x: Buffer get() = this@asXYData[xSymbol]!! + override val y: Buffer get() = this@asXYData[ySymbol]!! + override fun get(symbol: Symbol): Buffer? = when (symbol) { + Symbol.x -> x + Symbol.y -> y + else -> this@asXYData.get(symbol) + } +} + /** * A zero-copy method to represent a [Structure2D] as a two-column x-y data. * There could more than two columns in the structure. diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYErrorColumnarData.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYErrorColumnarData.kt new file mode 100644 index 000000000..8ddd6406f --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYErrorColumnarData.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.data + +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.structures.Buffer + + +/** + * A [ColumnarData] with additional [Symbol.yError] column for an [Symbol.y] error + * Inherits [XYColumnarData]. + */ +@UnstableKMathAPI +public interface XYErrorColumnarData : XYColumnarData { + public val yErr: Buffer + + override fun get(symbol: Symbol): Buffer = when (symbol) { + Symbol.x -> x + Symbol.y -> y + Symbol.yError -> yErr + else -> error("A column for symbol $symbol not found") + } + + public companion object { + public fun of( + x: Buffer, y: Buffer, yErr: Buffer + ): XYErrorColumnarData { + require(x.size == y.size) { "Buffer size mismatch. x buffer size is ${x.size}, y buffer size is ${y.size}" } + require(y.size == yErr.size) { "Buffer size mismatch. y buffer size is ${x.size}, yErr buffer size is ${y.size}" } + + return object : XYErrorColumnarData { + override val size: Int = x.size + override val x: Buffer = x + override val y: Buffer = y + override val yErr: Buffer = yErr + } + } + } +} + diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYZColumnarData.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYZColumnarData.kt index 39a6b858c..a4a08f626 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYZColumnarData.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/data/XYZColumnarData.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.data @@ -10,11 +10,11 @@ import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer /** - * A [XYColumnarData] with guaranteed [x], [y] and [z] columns designated by corresponding symbols. + * A [ColumnarData] with guaranteed [x], [y] and [z] columns designated by corresponding symbols. * Inherits [XYColumnarData]. */ @UnstableKMathAPI -public interface XYZColumnarData : XYColumnarData { +public interface XYZColumnarData : XYColumnarData { public val z: Buffer override fun get(symbol: Symbol): Buffer? = when (symbol) { @@ -23,4 +23,4 @@ public interface XYZColumnarData : XYColumna Symbol.z -> z else -> null } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/Domain.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/Domain.kt index e6e703cbf..0c4d2307b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/Domain.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/Domain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.domains @@ -12,7 +12,7 @@ import space.kscience.kmath.linear.Point * * @param T the type of element of this domain. */ -public interface Domain { +public interface Domain { /** * Checks if the specified point is contained in this domain. */ diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/DoubleDomain.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/DoubleDomain.kt index ee1bebde0..aee1d52c5 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/DoubleDomain.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/DoubleDomain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.domains diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/HyperSquareDomain.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/HyperSquareDomain.kt index f5560d935..7ea3e22c4 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/HyperSquareDomain.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/HyperSquareDomain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.domains @@ -17,17 +17,17 @@ import space.kscience.kmath.structures.indices */ @UnstableKMathAPI public class HyperSquareDomain(private val lower: Buffer, private val upper: Buffer) : DoubleDomain { - public override val dimension: Int get() = lower.size + override val dimension: Int get() = lower.size - public override operator fun contains(point: Point): Boolean = point.indices.all { i -> + override operator fun contains(point: Point): Boolean = point.indices.all { i -> point[i] in lower[i]..upper[i] } - public override fun getLowerBound(num: Int): Double = lower[num] + override fun getLowerBound(num: Int): Double = lower[num] - public override fun getUpperBound(num: Int): Double = upper[num] + override fun getUpperBound(num: Int): Double = upper[num] - public override fun volume(): Double { + override fun volume(): Double { var res = 1.0 for (i in 0 until dimension) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnconstrainedDomain.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnconstrainedDomain.kt index 7ffc0659d..040bb80b0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnconstrainedDomain.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnconstrainedDomain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.domains @@ -8,12 +8,12 @@ import space.kscience.kmath.linear.Point import space.kscience.kmath.misc.UnstableKMathAPI @UnstableKMathAPI -public class UnconstrainedDomain(public override val dimension: Int) : DoubleDomain { - public override operator fun contains(point: Point): Boolean = true +public class UnconstrainedDomain(override val dimension: Int) : DoubleDomain { + override operator fun contains(point: Point): Boolean = true - public override fun getLowerBound(num: Int): Double = Double.NEGATIVE_INFINITY + override fun getLowerBound(num: Int): Double = Double.NEGATIVE_INFINITY - public override fun getUpperBound(num: Int): Double = Double.POSITIVE_INFINITY + override fun getUpperBound(num: Int): Double = Double.POSITIVE_INFINITY - public override fun volume(): Double = Double.POSITIVE_INFINITY + override fun volume(): Double = Double.POSITIVE_INFINITY } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnivariateDomain.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnivariateDomain.kt index e7acada85..a5add6a0b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnivariateDomain.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/domains/UnivariateDomain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.domains @@ -10,24 +10,24 @@ import space.kscience.kmath.misc.UnstableKMathAPI @UnstableKMathAPI public class UnivariateDomain(public val range: ClosedFloatingPointRange) : DoubleDomain { - public override val dimension: Int get() = 1 + override val dimension: Int get() = 1 public operator fun contains(d: Double): Boolean = range.contains(d) - public override operator fun contains(point: Point): Boolean { + override operator fun contains(point: Point): Boolean { require(point.size == 0) return contains(point[0]) } - public override fun getLowerBound(num: Int): Double { + override fun getLowerBound(num: Int): Double { require(num == 0) return range.start } - public override fun getUpperBound(num: Int): Double { + override fun getUpperBound(num: Int): Double { require(num == 0) return range.endInclusive } - public override fun volume(): Double = range.endInclusive - range.start + override fun volume(): Double = range.endInclusive - range.start } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt index 1dcada6d3..758b992a9 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt @@ -1,15 +1,16 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions +import space.kscience.kmath.operations.Algebra + /** - * Represents expression which structure can be differentiated. + * Represents expression, which structure can be differentiated. * * @param T the type this expression takes as argument and returns. - * @param R the type of expression this expression can be differentiated to. */ public interface DifferentiableExpression : Expression { /** @@ -24,16 +25,18 @@ public interface DifferentiableExpression : Expression { public fun DifferentiableExpression.derivative(symbols: List): Expression = derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") -public fun DifferentiableExpression.derivative(vararg symbols: Symbol): Expression = +public fun DifferentiableExpression.derivative(vararg symbols: Symbol): Expression = derivative(symbols.toList()) -public fun DifferentiableExpression.derivative(name: String): Expression = +public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name)) /** - * A special type of [DifferentiableExpression] which returns typed expressions as derivatives + * A special type of [DifferentiableExpression] which returns typed expressions as derivatives. + * + * @param R the type of expression this expression can be differentiated to. */ -public interface SpecialDifferentiableExpression>: DifferentiableExpression { +public interface SpecialDifferentiableExpression> : DifferentiableExpression { override fun derivativeOrNull(symbols: List): R? } @@ -53,9 +56,9 @@ public abstract class FirstDerivativeExpression : DifferentiableExpression /** * Returns first derivative of this expression by given [symbol]. */ - public abstract fun derivativeOrNull(symbol: Symbol): Expression? + public abstract fun derivativeOrNull(symbol: Symbol): Expression? - public final override fun derivativeOrNull(symbols: List): Expression? { + public final override fun derivativeOrNull(symbols: List): Expression? { val dSymbol = symbols.firstOrNull() ?: return null return derivativeOrNull(dSymbol) } @@ -63,7 +66,10 @@ public abstract class FirstDerivativeExpression : DifferentiableExpression /** * A factory that converts an expression in autodiff variables to a [DifferentiableExpression] + * @param T type of the constants for the expression + * @param I type of the actual expression state + * @param A type of expression algebra */ -public fun interface AutoDiffProcessor, out R : Expression> { - public fun process(function: A.() -> I): DifferentiableExpression +public fun interface AutoDiffProcessor> { + public fun differentiate(function: A.() -> I): DifferentiableExpression } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt index 84e66918f..edd020c9a 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions @@ -43,7 +43,7 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = /** * Calls this expression from arguments. * - * @param pairs the pairs of arguments' names to values. + * @param pairs the pairs of arguments' names to value. * @return a value. */ @JvmName("callByString") @@ -60,7 +60,7 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = public interface ExpressionAlgebra : Algebra { /** - * A constant expression which does not depend on arguments + * A constant expression that does not depend on arguments. */ public fun const(value: T): E } @@ -68,6 +68,7 @@ public interface ExpressionAlgebra : Algebra { /** * Bind a symbol by name inside the [ExpressionAlgebra] */ -public fun ExpressionAlgebra.binding(): ReadOnlyProperty = ReadOnlyProperty { _, property -> - bindSymbol(property.name) ?: error("A variable with name ${property.name} does not exist") -} +public val ExpressionAlgebra.binding: ReadOnlyProperty + get() = ReadOnlyProperty { _, property -> + bindSymbol(property.name) ?: error("A variable with name ${property.name} does not exist") + } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 951ec9474..661680565 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -1,48 +1,44 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions import space.kscience.kmath.operations.* +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * A context class for [Expression] construction. * * @param algebra The algebra to provide for Expressions built. */ -public abstract class FunctionalExpressionAlgebra>( +public abstract class FunctionalExpressionAlgebra>( public val algebra: A, ) : ExpressionAlgebra> { /** - * Builds an Expression of constant expression which does not depend on arguments. + * Builds an Expression of constant expression that does not depend on arguments. */ - public override fun const(value: T): Expression = Expression { value } + override fun const(value: T): Expression = Expression { value } /** * Builds an Expression to access a variable. */ - public override fun bindSymbolOrNull(value: String): Expression? = Expression { arguments -> + override fun bindSymbolOrNull(value: String): Expression? = Expression { arguments -> algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)] ?: error("Symbol '$value' is not supported in $this") } - /** - * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. - */ - public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = { left, right -> Expression { arguments -> algebra.binaryOperationFunction(operation)(left.invoke(arguments), right.invoke(arguments)) } } - /** - * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. - */ - public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg -> + override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg -> Expression { arguments -> algebra.unaryOperationFunction(operation)(arg.invoke(arguments)) } } } @@ -50,24 +46,24 @@ public abstract class FunctionalExpressionAlgebra>( /** * A context class for [Expression] construction for [Ring] algebras. */ -public open class FunctionalExpressionGroup>( +public open class FunctionalExpressionGroup>( algebra: A, ) : FunctionalExpressionAlgebra(algebra), Group> { - public override val zero: Expression get() = const(algebra.zero) + override val zero: Expression get() = const(algebra.zero) - public override fun Expression.unaryMinus(): Expression = - unaryOperation(GroupOperations.MINUS_OPERATION, this) + override fun Expression.unaryMinus(): Expression = + unaryOperation(GroupOps.MINUS_OPERATION, this) /** * Builds an Expression of addition of two another expressions. */ - public override fun add(a: Expression, b: Expression): Expression = - binaryOperation(GroupOperations.PLUS_OPERATION, a, b) + override fun add(left: Expression, right: Expression): Expression = + binaryOperation(GroupOps.PLUS_OPERATION, left, right) // /** // * Builds an Expression of multiplication of expression by number. // */ -// public override fun multiply(a: Expression, k: Number): Expression = Expression { arguments -> +// override fun multiply(a: Expression, k: Number): Expression = Expression { arguments -> // algebra.multiply(a.invoke(arguments), k) // } @@ -76,111 +72,127 @@ public open class FunctionalExpressionGroup>( public operator fun T.plus(arg: Expression): Expression = arg + this public operator fun T.minus(arg: Expression): Expression = arg - this - public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) - public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = super.binaryOperationFunction(operation) } -public open class FunctionalExpressionRing>( +public open class FunctionalExpressionRing>( algebra: A, ) : FunctionalExpressionGroup(algebra), Ring> { - public override val one: Expression get() = const(algebra.one) + override val one: Expression get() = const(algebra.one) /** * Builds an Expression of multiplication of two expressions. */ - public override fun multiply(a: Expression, b: Expression): Expression = - binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + override fun multiply(left: Expression, right: Expression): Expression = + binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right) public operator fun Expression.times(arg: T): Expression = this * const(arg) public operator fun T.times(arg: Expression): Expression = arg * this - public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) - public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = super.binaryOperationFunction(operation) } -public open class FunctionalExpressionField>( +public open class FunctionalExpressionField>( algebra: A, ) : FunctionalExpressionRing(algebra), Field>, ScaleOperations> { /** * Builds an Expression of division an expression by another one. */ - public override fun divide(a: Expression, b: Expression): Expression = - binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + override fun divide(left: Expression, right: Expression): Expression = + binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right) public operator fun Expression.div(arg: T): Expression = this / const(arg) public operator fun T.div(arg: Expression): Expression = arg / this - public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) - public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = super.binaryOperationFunction(operation) - public override fun scale(a: Expression, value: Double): Expression = algebra { + override fun scale(a: Expression, value: Double): Expression = algebra { Expression { args -> a(args) * value } } - public override fun bindSymbolOrNull(value: String): Expression? = + override fun bindSymbolOrNull(value: String): Expression? = super.bindSymbolOrNull(value) } -public open class FunctionalExpressionExtendedField>( +public open class FunctionalExpressionExtendedField>( algebra: A, ) : FunctionalExpressionField(algebra), ExtendedField> { - public override fun number(value: Number): Expression = const(algebra.number(value)) + override fun number(value: Number): Expression = const(algebra.number(value)) - public override fun sqrt(arg: Expression): Expression = + override fun sqrt(arg: Expression): Expression = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) - public override fun sin(arg: Expression): Expression = + override fun sin(arg: Expression): Expression = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) - public override fun cos(arg: Expression): Expression = + override fun cos(arg: Expression): Expression = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) - public override fun asin(arg: Expression): Expression = + override fun asin(arg: Expression): Expression = unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg) - public override fun acos(arg: Expression): Expression = + override fun acos(arg: Expression): Expression = unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg) - public override fun atan(arg: Expression): Expression = + override fun atan(arg: Expression): Expression = unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg) - public override fun power(arg: Expression, pow: Number): Expression = + override fun power(arg: Expression, pow: Number): Expression = binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) - public override fun exp(arg: Expression): Expression = + override fun exp(arg: Expression): Expression = unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: Expression): Expression = + override fun ln(arg: Expression): Expression = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) - public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = + override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) - public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = + override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = super.binaryOperationFunction(operation) - public override fun bindSymbol(value: String): Expression = super.bindSymbol(value) + override fun bindSymbol(value: String): Expression = super.bindSymbol(value) } -public inline fun > A.expressionInSpace(block: FunctionalExpressionGroup.() -> Expression): Expression = - FunctionalExpressionGroup(this).block() +public inline fun > A.expressionInGroup( + block: FunctionalExpressionGroup.() -> Expression, +): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionGroup(this).block() +} -public inline fun > A.expressionInRing(block: FunctionalExpressionRing.() -> Expression): Expression = - FunctionalExpressionRing(this).block() +public inline fun > A.expressionInRing( + block: FunctionalExpressionRing.() -> Expression, +): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionRing(this).block() +} -public inline fun > A.expressionInField(block: FunctionalExpressionField.() -> Expression): Expression = - FunctionalExpressionField(this).block() +public inline fun > A.expressionInField( + block: FunctionalExpressionField.() -> Expression, +): Expression { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return FunctionalExpressionField(this).block() +} public inline fun > A.expressionInExtendedField( block: FunctionalExpressionExtendedField.() -> Expression, ): Expression = FunctionalExpressionExtendedField(this).block() + +public inline fun DoubleField.expression( + block: FunctionalExpressionExtendedField.() -> Expression, +): Expression = FunctionalExpressionExtendedField(this).block() diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt index 7533024a1..fe50902b1 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt index 4729f19ea..ca0671ccb 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MstAlgebra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions @@ -12,14 +12,14 @@ import space.kscience.kmath.operations.* * [Algebra] over [MST] nodes. */ public object MstNumericAlgebra : NumericAlgebra { - public override fun number(value: Number): MST.Numeric = MST.Numeric(value) - public override fun bindSymbolOrNull(value: String): Symbol = StringSymbol(value) + override fun number(value: Number): MST.Numeric = MST.Numeric(value) + override fun bindSymbolOrNull(value: String): Symbol = StringSymbol(value) override fun bindSymbol(value: String): Symbol = bindSymbolOrNull(value) - public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) } - public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = { left, right -> MST.Binary(operation, left, right) } } @@ -27,27 +27,27 @@ public object MstNumericAlgebra : NumericAlgebra { * [Group] over [MST] nodes. */ public object MstGroup : Group, NumericAlgebra, ScaleOperations { - public override val zero: MST.Numeric = number(0.0) + override val zero: MST.Numeric = number(0.0) - public override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) - public override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) - public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b) - public override operator fun MST.unaryPlus(): MST.Unary = - unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this) + override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) + override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) + override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right) + override operator fun MST.unaryPlus(): MST.Unary = + unaryOperationFunction(GroupOps.PLUS_OPERATION)(this) - public override operator fun MST.unaryMinus(): MST.Unary = - unaryOperationFunction(GroupOperations.MINUS_OPERATION)(this) + override operator fun MST.unaryMinus(): MST.Unary = + unaryOperationFunction(GroupOps.MINUS_OPERATION)(this) - public override operator fun MST.minus(b: MST): MST.Binary = - binaryOperationFunction(GroupOperations.MINUS_OPERATION)(this, b) + override operator fun MST.minus(arg: MST): MST.Binary = + binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, arg) - public override fun scale(a: MST, value: Double): MST.Binary = - binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(value)) + override fun scale(a: MST, value: Double): MST.Binary = + binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value)) - public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstNumericAlgebra.binaryOperationFunction(operation) - public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstNumericAlgebra.unaryOperationFunction(operation) } @@ -56,28 +56,28 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { */ @Suppress("OVERRIDE_BY_INLINE") @OptIn(UnstableKMathAPI::class) -public object MstRing : Ring, NumbersAddOperations, ScaleOperations { - public override inline val zero: MST.Numeric get() = MstGroup.zero - public override val one: MST.Numeric = number(1.0) +public object MstRing : Ring, NumbersAddOps, ScaleOperations { + override inline val zero: MST.Numeric get() = MstGroup.zero + override val one: MST.Numeric = number(1.0) - public override fun number(value: Number): MST.Numeric = MstGroup.number(value) - public override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) - public override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b) + override fun number(value: Number): MST.Numeric = MstGroup.number(value) + override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) + override fun add(left: MST, right: MST): MST.Binary = MstGroup.add(left, right) - public override fun scale(a: MST, value: Double): MST.Binary = - MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) + override fun scale(a: MST, value: Double): MST.Binary = + MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value)) - public override fun multiply(a: MST, b: MST): MST.Binary = - binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + override fun multiply(left: MST, right: MST): MST.Binary = + binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right) - public override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } - public override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } - public override operator fun MST.minus(b: MST): MST.Binary = MstGroup { this@minus - b } + override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } + override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } + override operator fun MST.minus(arg: MST): MST.Binary = MstGroup { this@minus - arg } - public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstGroup.binaryOperationFunction(operation) - public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstNumericAlgebra.unaryOperationFunction(operation) } @@ -86,29 +86,29 @@ public object MstRing : Ring, NumbersAddOperations, ScaleOperations, NumbersAddOperations, ScaleOperations { - public override inline val zero: MST.Numeric get() = MstRing.zero - public override inline val one: MST.Numeric get() = MstRing.one +public object MstField : Field, NumbersAddOps, ScaleOperations { + override inline val zero: MST.Numeric get() = MstRing.zero + override inline val one: MST.Numeric get() = MstRing.one - public override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) - public override fun number(value: Number): MST.Numeric = MstRing.number(value) - public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) + override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) + override fun number(value: Number): MST.Numeric = MstRing.number(value) + override fun add(left: MST, right: MST): MST.Binary = MstRing.add(left, right) - public override fun scale(a: MST, value: Double): MST.Binary = - MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) + override fun scale(a: MST, value: Double): MST.Binary = + MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value)) - public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST.Binary = - binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + override fun multiply(left: MST, right: MST): MST.Binary = MstRing.multiply(left, right) + override fun divide(left: MST, right: MST): MST.Binary = + binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right) - public override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } - public override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } - public override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b } + override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } + override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } + override operator fun MST.minus(arg: MST): MST.Binary = MstRing { this@minus - arg } - public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperationFunction(operation) - public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstRing.unaryOperationFunction(operation) } @@ -117,45 +117,45 @@ public object MstField : Field, NumbersAddOperations, ScaleOperations< */ @Suppress("OVERRIDE_BY_INLINE") public object MstExtendedField : ExtendedField, NumericAlgebra { - public override inline val zero: MST.Numeric get() = MstField.zero - public override inline val one: MST.Numeric get() = MstField.one + override inline val zero: MST.Numeric get() = MstField.zero + override inline val one: MST.Numeric get() = MstField.one - public override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) - public override fun number(value: Number): MST.Numeric = MstRing.number(value) - public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) - public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) - public override fun tan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.TAN_OPERATION)(arg) - public override fun asin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg) - public override fun acos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg) - public override fun atan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg) - public override fun sinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.SINH_OPERATION)(arg) - public override fun cosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.COSH_OPERATION)(arg) - public override fun tanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.TANH_OPERATION)(arg) - public override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg) - public override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg) - public override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg) - public override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) - public override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) + override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) + override fun number(value: Number): MST.Numeric = MstRing.number(value) + override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) + override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) + override fun tan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.TAN_OPERATION)(arg) + override fun asin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ASIN_OPERATION)(arg) + override fun acos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ACOS_OPERATION)(arg) + override fun atan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.ATAN_OPERATION)(arg) + override fun sinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.SINH_OPERATION)(arg) + override fun cosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.COSH_OPERATION)(arg) + override fun tanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.TANH_OPERATION)(arg) + override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg) + override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg) + override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg) + override fun add(left: MST, right: MST): MST.Binary = MstField.add(left, right) + override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) - public override fun scale(a: MST, value: Double): MST = - binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value)) + override fun scale(a: MST, value: Double): MST = + binaryOperation(GroupOps.PLUS_OPERATION, a, number(value)) - public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) - public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) - public override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } - public override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } - public override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b } + override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right) + override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right) + override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } + override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } + override operator fun MST.minus(arg: MST): MST.Binary = MstField { this@minus - arg } - public override fun power(arg: MST, pow: Number): MST.Binary = + override fun power(arg: MST, pow: Number): MST.Binary = binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) - public override fun exp(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) + override fun exp(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) + override fun ln(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) - public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = + override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstField.binaryOperationFunction(operation) - public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstField.unaryOperationFunction(operation) } @@ -164,7 +164,7 @@ public object MstExtendedField : ExtendedField, NumericAlgebra { */ @UnstableKMathAPI public object MstLogicAlgebra : LogicAlgebra { - public override fun bindSymbolOrNull(value: String): MST = super.bindSymbolOrNull(value) ?: StringSymbol(value) + override fun bindSymbolOrNull(value: String): MST = super.bindSymbolOrNull(value) ?: StringSymbol(value) override fun const(boolean: Boolean): Symbol = if (boolean) { LogicAlgebra.TRUE diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index 478b85620..704c4edd8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions @@ -59,9 +59,9 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point>( public val context: F, bindings: Map, -) : Field>, ExpressionAlgebra>, NumbersAddOperations> { - public override val zero: AutoDiffValue get() = const(context.zero) - public override val one: AutoDiffValue get() = const(context.one) +) : Field>, ExpressionAlgebra>, NumbersAddOps> { + override val zero: AutoDiffValue get() = const(context.zero) + override val one: AutoDiffValue get() = const(context.one) // this stack contains pairs of blocks and values to apply them to private var stack: Array = arrayOfNulls(8) @@ -119,8 +119,6 @@ public open class SimpleAutoDiffField>( get() = getDerivative(this) set(value) = setDerivative(this, value) - public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) - /** * Performs update of derivative after the rest of the formula in the back-pass. * @@ -151,17 +149,17 @@ public open class SimpleAutoDiffField>( // // Overloads for Double constants // -// public override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue = +// override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue = // derive(const { this@plus.toDouble() * one + b.value }) { z -> // b.d += z.d // } // -// public override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this) +// override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this) // -// public override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue = +// override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue = // derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } // -// public override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = +// override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = // derive(const { this@minus.value - one * b.toDouble() }) { z -> d += z.d } @@ -170,30 +168,35 @@ public open class SimpleAutoDiffField>( // Basic math (+, -, *, /) - public override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = - derive(const { a.value + b.value }) { z -> - a.d += z.d - b.d += z.d + override fun add(left: AutoDiffValue, right: AutoDiffValue): AutoDiffValue = + derive(const { left.value + right.value }) { z -> + left.d += z.d + right.d += z.d } - public override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = - derive(const { a.value * b.value }) { z -> - a.d += z.d * b.value - b.d += z.d * a.value + override fun multiply(left: AutoDiffValue, right: AutoDiffValue): AutoDiffValue = + derive(const { left.value * right.value }) { z -> + left.d += z.d * right.value + right.d += z.d * left.value } - public override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = - derive(const { a.value / b.value }) { z -> - a.d += z.d / b.value - b.d -= z.d * a.value / (b.value * b.value) + override fun divide(left: AutoDiffValue, right: AutoDiffValue): AutoDiffValue = + derive(const { left.value / right.value }) { z -> + left.d += z.d / right.value + right.d -= z.d * left.value / (right.value * right.value) } - public override fun scale(a: AutoDiffValue, value: Double): AutoDiffValue = + override fun scale(a: AutoDiffValue, value: Double): AutoDiffValue = derive(const { value * a.value }) { z -> a.d += z.d * value } } +public inline fun > SimpleAutoDiffField.const(block: F.() -> T): AutoDiffValue { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return const(context.block()) +} + /** * Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code. @@ -208,7 +211,7 @@ public open class SimpleAutoDiffField>( * assertEquals(9.0, x.d) // dy/dx * ``` * - * @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to. + * @param body the action in [SimpleAutoDiffField] context returning [AutoDiffValue] to differentiate with respect to. * @return the result of differentiation. */ public fun > F.simpleAutoDiff( @@ -233,12 +236,12 @@ public class SimpleAutoDiffExpression>( public val field: F, public val function: SimpleAutoDiffField.() -> AutoDiffValue, ) : FirstDerivativeExpression() { - public override operator fun invoke(arguments: Map): T { + override operator fun invoke(arguments: Map): T { //val bindings = arguments.entries.map { it.key.bind(it.value) } return SimpleAutoDiffField(field, arguments).function().value } - public override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> + override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> //val bindings = arguments.entries.map { it.key.bind(it.value) } val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function) derivationResult.derivative(symbol) @@ -248,7 +251,9 @@ public class SimpleAutoDiffExpression>( /** * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] */ -public fun > simpleAutoDiff(field: F): AutoDiffProcessor, SimpleAutoDiffField, Expression> = +public fun > simpleAutoDiff( + field: F +): AutoDiffProcessor, SimpleAutoDiffField> = AutoDiffProcessor { function -> SimpleAutoDiffExpression(field, function) } @@ -343,28 +348,28 @@ public class SimpleAutoDiffExtendedField>( override fun bindSymbol(value: String): AutoDiffValue = super.bindSymbol(value) - public override fun number(value: Number): AutoDiffValue = const { number(value) } + override fun number(value: Number): AutoDiffValue = const { number(value) } - public override fun scale(a: AutoDiffValue, value: Double): AutoDiffValue = a * number(value) + override fun scale(a: AutoDiffValue, value: Double): AutoDiffValue = a * number(value) // x ^ 2 public fun sqr(x: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).sqr(x) // x ^ 1/2 - public override fun sqrt(arg: AutoDiffValue): AutoDiffValue = + override fun sqrt(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).sqrt(arg) // x ^ y (const) - public override fun power(arg: AutoDiffValue, pow: Number): AutoDiffValue = + override fun power(arg: AutoDiffValue, pow: Number): AutoDiffValue = (this as SimpleAutoDiffField).pow(arg, pow.toDouble()) // exp(x) - public override fun exp(arg: AutoDiffValue): AutoDiffValue = + override fun exp(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).exp(arg) // ln(x) - public override fun ln(arg: AutoDiffValue): AutoDiffValue = + override fun ln(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).ln(arg) // x ^ y (any) @@ -374,40 +379,40 @@ public class SimpleAutoDiffExtendedField>( ): AutoDiffValue = exp(y * ln(x)) // sin(x) - public override fun sin(arg: AutoDiffValue): AutoDiffValue = + override fun sin(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).sin(arg) // cos(x) - public override fun cos(arg: AutoDiffValue): AutoDiffValue = + override fun cos(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).cos(arg) - public override fun tan(arg: AutoDiffValue): AutoDiffValue = + override fun tan(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).tan(arg) - public override fun asin(arg: AutoDiffValue): AutoDiffValue = + override fun asin(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).asin(arg) - public override fun acos(arg: AutoDiffValue): AutoDiffValue = + override fun acos(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).acos(arg) - public override fun atan(arg: AutoDiffValue): AutoDiffValue = + override fun atan(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).atan(arg) - public override fun sinh(arg: AutoDiffValue): AutoDiffValue = + override fun sinh(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).sinh(arg) - public override fun cosh(arg: AutoDiffValue): AutoDiffValue = + override fun cosh(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).cosh(arg) - public override fun tanh(arg: AutoDiffValue): AutoDiffValue = + override fun tanh(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).tanh(arg) - public override fun asinh(arg: AutoDiffValue): AutoDiffValue = + override fun asinh(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).asinh(arg) - public override fun acosh(arg: AutoDiffValue): AutoDiffValue = + override fun acosh(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).acosh(arg) - public override fun atanh(arg: AutoDiffValue): AutoDiffValue = + override fun atanh(arg: AutoDiffValue): AutoDiffValue = (this as SimpleAutoDiffField).atanh(arg) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt index 74dc7aedc..cd49e4519 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Symbol.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions @@ -19,9 +19,12 @@ public interface Symbol : MST { public val identity: String public companion object { - public val x: StringSymbol = StringSymbol("x") - public val y: StringSymbol = StringSymbol("y") - public val z: StringSymbol = StringSymbol("z") + public val x: Symbol = Symbol("x") + public val xError: Symbol = Symbol("x.error") + public val y: Symbol = Symbol("y") + public val yError: Symbol = Symbol("y.error") + public val z: Symbol = Symbol("z") + public val zError: Symbol = Symbol("z.error") } } @@ -29,10 +32,15 @@ public interface Symbol : MST { * A [Symbol] with a [String] identity */ @JvmInline -public value class StringSymbol(override val identity: String) : Symbol { +internal value class StringSymbol(override val identity: String) : Symbol { override fun toString(): String = identity } +/** + * Create s Symbols with a string identity + */ +public fun Symbol(identity: String): Symbol = StringSymbol(identity) + /** * A delegate to create a symbol with a string identity in this scope */ diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SymbolIndexer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SymbolIndexer.kt index 06634704c..e8005096c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SymbolIndexer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SymbolIndexer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions @@ -9,6 +9,9 @@ import space.kscience.kmath.linear.Point import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.structures.BufferFactory +import space.kscience.kmath.structures.DoubleBuffer +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.jvm.JvmInline /** @@ -45,6 +48,11 @@ public interface SymbolIndexer { return symbols.indices.associate { symbols[it] to get(it) } } + public fun Point.toMap(): Map { + require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" } + return symbols.indices.associate { symbols[it] to get(it) } + } + public operator fun Structure2D.get(rowSymbol: Symbol, columnSymbol: Symbol): T = get(indexOf(rowSymbol), indexOf(columnSymbol)) @@ -54,6 +62,10 @@ public interface SymbolIndexer { public fun Map.toPoint(bufferFactory: BufferFactory): Point = bufferFactory(symbols.size) { getValue(symbols[it]) } + public fun Map.toPoint(): DoubleBuffer = + DoubleBuffer(symbols.size) { getValue(symbols[it]) } + + public fun Map.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) } } @@ -65,9 +77,13 @@ public value class SimpleSymbolIndexer(override val symbols: List) : Sym * Execute the block with symbol indexer based on given symbol order */ @UnstableKMathAPI -public inline fun withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R = - with(SimpleSymbolIndexer(symbols.toList()), block) +public inline fun withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return with(SimpleSymbolIndexer(symbols.toList()), block) +} @UnstableKMathAPI -public inline fun withSymbols(symbols: Collection, block: SymbolIndexer.() -> R): R = - with(SimpleSymbolIndexer(symbols.toList()), block) \ No newline at end of file +public inline fun withSymbols(symbols: Collection, block: SymbolIndexer.() -> R): R { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return with(SimpleSymbolIndexer(symbols.toList()), block) +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/specialExpressions.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/specialExpressions.kt new file mode 100644 index 000000000..907ce4004 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/specialExpressions.kt @@ -0,0 +1,53 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.expressions + +import space.kscience.kmath.operations.ExtendedField +import space.kscience.kmath.operations.asIterable +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.indices +import kotlin.jvm.JvmName + +/** + * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic + * differentiation. + * + * **WARNING** All elements of [yErr] must be positive. + */ +@JvmName("genericChiSquaredExpression") +public fun , I : Any, A> AutoDiffProcessor.chiSquaredExpression( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: A.(I) -> I, +): DifferentiableExpression where A : ExtendedField, A : ExpressionAlgebra { + require(x.size == y.size) { "X and y buffers should be of the same size" } + require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } + + return differentiate { + var sum = zero + + x.indices.forEach { + val xValue = const(x[it]) + val yValue = const(y[it]) + val yErrValue = const(yErr[it]) + val modelValue = model(xValue) + sum += ((yValue - modelValue) / yErrValue).pow(2) + } + + sum + } +} + +public fun AutoDiffProcessor.chiSquaredExpression( + x: Buffer, + y: Buffer, + yErr: Buffer, + model: A.(I) -> I, +): DifferentiableExpression where A : ExtendedField, A : ExpressionAlgebra { + require(yErr.asIterable().all { it > 0.0 }) { "All errors must be strictly positive" } + return chiSquaredExpression(x, y, yErr, model) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt index 9b4451a62..410fb8505 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/BufferedLinearSpace.kt @@ -1,48 +1,47 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.nd.* -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.nd.BufferedRingOpsND +import space.kscience.kmath.nd.as2D +import space.kscience.kmath.nd.asND +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.VirtualBuffer import space.kscience.kmath.structures.indices -public class BufferedLinearSpace>( - override val elementAlgebra: A, - private val bufferFactory: BufferFactory, +public class BufferedLinearSpace>( + private val bufferAlgebra: BufferAlgebra ) : LinearSpace { + override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra - private fun ndRing( - rows: Int, - cols: Int, - ): BufferedRingND = AlgebraND.ring(elementAlgebra, bufferFactory, rows, cols) + private val ndAlgebra = BufferedRingOpsND(bufferAlgebra) override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix = - ndRing(rows, columns).produce { (i, j) -> elementAlgebra.initializer(i, j) }.as2D() + ndAlgebra.structureND(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D() override fun buildVector(size: Int, initializer: A.(Int) -> T): Point = - bufferFactory(size) { elementAlgebra.initializer(it) } + bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) } - override fun Matrix.unaryMinus(): Matrix = ndRing(rowNum, colNum).run { - unwrap().map { -it }.as2D() + @OptIn(PerformancePitfall::class) + override fun Matrix.unaryMinus(): Matrix = ndAlgebra { + asND().map { -it }.as2D() } - override fun Matrix.plus(other: Matrix): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.plus(other: Matrix): Matrix = ndAlgebra { require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } - unwrap().plus(other.unwrap()).as2D() + asND().plus(other.asND()).as2D() } - override fun Matrix.minus(other: Matrix): Matrix = ndRing(rowNum, colNum).run { + override fun Matrix.minus(other: Matrix): Matrix = ndAlgebra { require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } - unwrap().minus(other.unwrap()).as2D() + asND().minus(other.asND()).as2D() } private fun Buffer.linearize() = if (this is VirtualBuffer) { @@ -85,7 +84,12 @@ public class BufferedLinearSpace>( } } - override fun Matrix.times(value: T): Matrix = ndRing(rowNum, colNum).run { - unwrap().map { it * value }.as2D() + @OptIn(PerformancePitfall::class) + override fun Matrix.times(value: T): Matrix = ndAlgebra { + asND().map { it * value }.as2D() } -} \ No newline at end of file +} + + +public fun > A.linearSpace(bufferFactory: BufferFactory): BufferedLinearSpace = + BufferedLinearSpace(BufferRingOps(this, bufferFactory)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt new file mode 100644 index 000000000..91db33bce --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/DoubleLinearSpace.kt @@ -0,0 +1,110 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.linear + +import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.DoubleFieldOpsND +import space.kscience.kmath.nd.as2D +import space.kscience.kmath.nd.asND +import space.kscience.kmath.operations.DoubleBufferOps +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.DoubleBuffer + +public object DoubleLinearSpace : LinearSpace { + + override val elementAlgebra: DoubleField get() = DoubleField + + override fun buildMatrix( + rows: Int, + columns: Int, + initializer: DoubleField.(i: Int, j: Int) -> Double + ): Matrix = DoubleFieldOpsND.structureND(intArrayOf(rows, columns)) { (i, j) -> + DoubleField.initializer(i, j) + }.as2D() + + override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer = + DoubleBuffer(size) { DoubleField.initializer(it) } + + override fun Matrix.unaryMinus(): Matrix = DoubleFieldOpsND { + asND().map { -it }.as2D() + } + + override fun Matrix.plus(other: Matrix): Matrix = DoubleFieldOpsND { + require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } + asND().plus(other.asND()).as2D() + } + + override fun Matrix.minus(other: Matrix): Matrix = DoubleFieldOpsND { + require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } + asND().minus(other.asND()).as2D() + } + + // Create a continuous in-memory representation of this vector for better memory layout handling + private fun Buffer.linearize() = if (this is DoubleBuffer) { + this.array + } else { + DoubleArray(size) { get(it) } + } + + @OptIn(PerformancePitfall::class) + override fun Matrix.dot(other: Matrix): Matrix { + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + val rows = this@dot.rows.map { it.linearize() } + val columns = other.columns.map { it.linearize() } + return buildMatrix(rowNum, other.colNum) { i, j -> + val r = rows[i] + val c = columns[j] + var res = 0.0 + for (l in r.indices) { + res += r[l] * c[l] + } + res + } + } + + @OptIn(PerformancePitfall::class) + override fun Matrix.dot(vector: Point): DoubleBuffer { + require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } + val rows = this@dot.rows.map { it.linearize() } + return DoubleBuffer(rowNum) { i -> + val r = rows[i] + var res = 0.0 + for (j in r.indices) { + res += r[j] * vector[j] + } + res + } + + } + + override fun Matrix.times(value: Double): Matrix = DoubleFieldOpsND { + asND().map { it * value }.as2D() + } + + public override fun Point.plus(other: Point): DoubleBuffer = DoubleBufferOps.run { + this@plus + other + } + + public override fun Point.minus(other: Point): DoubleBuffer = DoubleBufferOps.run { + this@minus - other + } + + public override fun Point.times(value: Double): DoubleBuffer = DoubleBufferOps.run { + scale(this@times, value) + } + + public operator fun Point.div(value: Double): DoubleBuffer = DoubleBufferOps.run { + scale(this@div, 1.0 / value) + } + + public override fun Double.times(v: Point): DoubleBuffer = v * this + + +} + +public val DoubleField.linearSpace: DoubleLinearSpace get() = DoubleLinearSpace diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt index 9c3ffd819..54d90baa8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt @@ -1,15 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear -import space.kscience.kmath.nd.as1D - /** - * A group of methods to solve for *X* in equation *X = A -1 · B*, where *A* and *B* are matrices or - * vectors. + * A group of methods to solve for *X* in equation *X = A−1 · B*, where *A* and *B* are + * matrices or vectors. * * @param T the type of items. */ @@ -30,20 +28,3 @@ public interface LinearSolver { public fun inverse(matrix: Matrix): Matrix } -/** - * Convert matrix to vector if it is possible. - */ -public fun Matrix.asVector(): Point = - if (this.colNum == 1) - as1D() - else - error("Can't convert matrix with more than one column to vector") - -/** - * Creates an n × 1 [VirtualMatrix], where n is the size of the given buffer. - * - * @param T the type of elements contained in the buffer. - * @receiver a buffer. - * @return the new matrix. - */ -public fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(size, 1) { i, _ -> get(i) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index ec073ac48..5349ad864 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -1,13 +1,19 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.* -import space.kscience.kmath.operations.* +import space.kscience.kmath.nd.MutableStructure2D +import space.kscience.kmath.nd.Structure2D +import space.kscience.kmath.nd.StructureFeature +import space.kscience.kmath.nd.as1D +import space.kscience.kmath.operations.BufferRingOps +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer @@ -29,12 +35,12 @@ public typealias MutableMatrix = MutableStructure2D public typealias Point = Buffer /** - * Basic operations on matrices and vectors. Operates on [Matrix]. + * Basic operations on matrices and vectors. * * @param T the type of items in the matrices. - * @param M the type of operated matrices. + * @param A the type of ring over [T]. */ -public interface LinearSpace> { +public interface LinearSpace> { public val elementAlgebra: A /** @@ -164,7 +170,7 @@ public interface LinearSpace> { public operator fun T.times(v: Point): Point = v * this /** - * Get a feature of the structure in this scope. Structure features take precedence other context features + * Compute a feature of the structure in this scope. Structure features take precedence other context features. * * @param F the type of feature. * @param structure the structure. @@ -172,7 +178,8 @@ public interface LinearSpace> { * @return a feature object or `null` if it isn't present. */ @UnstableKMathAPI - public fun getFeature(structure: Matrix, type: KClass): F? = structure.getFeature(type) + public fun computeFeature(structure: Matrix, type: KClass): F? = + structure.getFeature(type) public companion object { @@ -182,9 +189,10 @@ public interface LinearSpace> { public fun > buffered( algebra: A, bufferFactory: BufferFactory = Buffer.Companion::boxing, - ): LinearSpace = BufferedLinearSpace(algebra, bufferFactory) + ): LinearSpace = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory)) - public val real: LinearSpace = buffered(DoubleField, ::DoubleBuffer) + @Deprecated("use DoubleField.linearSpace") + public val double: LinearSpace = buffered(DoubleField, ::DoubleBuffer) /** * Automatic buffered matrix, unboxed if it is possible @@ -195,16 +203,32 @@ public interface LinearSpace> { } /** - * Get a feature of the structure in this scope. Structure features take precedence other context features + * Get a feature of the structure in this scope. Structure features take precedence other context features. * * @param T the type of items in the matrices. * @param F the type of feature. * @return a feature object or `null` if it isn't present. */ @UnstableKMathAPI -public inline fun LinearSpace.getFeature(structure: Matrix): F? = - getFeature(structure, F::class) +public inline fun LinearSpace.computeFeature(structure: Matrix): F? = + computeFeature(structure, F::class) -public operator fun , R> LS.invoke(block: LS.() -> R): R = run(block) +public inline operator fun , R> LS.invoke(block: LS.() -> R): R = run(block) + +/** + * Convert matrix to vector if it is possible. + */ +public fun Matrix.asVector(): Point = + if (this.colNum == 1) as1D() + else error("Can't convert matrix with more than one column to vector") + +/** + * Creates an n × 1 [VirtualMatrix], where n is the size of the given buffer. + * + * @param T the type of elements contained in the buffer. + * @receiver a buffer. + * @return the new matrix. + */ +public fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(size, 1) { i, _ -> get(i) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt index f3653d394..95dd6d45c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt @@ -1,12 +1,11 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.getFeature import space.kscience.kmath.operations.* import space.kscience.kmath.structures.BufferAccessor2D import space.kscience.kmath.structures.DoubleBuffer @@ -34,7 +33,7 @@ public class LupDecomposition( j == i -> elementContext.one else -> elementContext.zero } - } + LFeature + }.withFeature(LFeature) /** @@ -44,7 +43,7 @@ public class LupDecomposition( */ override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j >= i) lu[i, j] else elementContext.zero - } + UFeature + }.withFeature(UFeature) /** * Returns the P rows permutation matrix. @@ -82,7 +81,7 @@ public fun > LinearSpace>.lup( val m = matrix.colNum val pivot = IntArray(matrix.rowNum) - //TODO just waits for KEEP-176 + //TODO just waits for multi-receivers BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run { elementAlgebra { val lu = create(matrix) @@ -114,7 +113,7 @@ public fun > LinearSpace>.lup( for (i in 0 until col) sum -= luRow[i] * lu[i, col] luRow[col] = sum - // maintain best permutation choice + // maintain the best permutation choice if (abs(sum) > largest) { largest = abs(sum) max = row @@ -156,10 +155,13 @@ public inline fun > LinearSpace>.lup( noinline checkSingular: (T) -> Boolean, ): LupDecomposition = lup(MutableBuffer.Companion::auto, matrix, checkSingular) -public fun LinearSpace.lup(matrix: Matrix): LupDecomposition = - lup(::DoubleBuffer, matrix) { it < 1e-11 } +public fun LinearSpace.lup( + matrix: Matrix, + singularityThreshold: Double = 1e-11, +): LupDecomposition = + lup(::DoubleBuffer, matrix) { it < singularityThreshold } -public fun LupDecomposition.solveWithLup( +internal fun LupDecomposition.solve( factory: MutableBufferFactory, matrix: Matrix, ): Matrix { @@ -207,41 +209,22 @@ public fun LupDecomposition.solveWithLup( } } -public inline fun LupDecomposition.solveWithLup(matrix: Matrix): Matrix = - solveWithLup(MutableBuffer.Companion::auto, matrix) - /** - * Solves a system of linear equations *ax = b** using LUP decomposition. + * Produce a generic solver based on LUP decomposition */ @OptIn(UnstableKMathAPI::class) -public inline fun > LinearSpace>.solveWithLup( - a: Matrix, - b: Matrix, - noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, - noinline checkSingular: (T) -> Boolean, -): Matrix { - // Use existing decomposition if it is provided by matrix - val decomposition = a.getFeature() ?: lup(bufferFactory, a, checkSingular) - return decomposition.solveWithLup(bufferFactory, b) +public fun , F : Field> LinearSpace.lupSolver( + bufferFactory: MutableBufferFactory, + singularityCheck: (T) -> Boolean, +): LinearSolver = object : LinearSolver { + override fun solve(a: Matrix, b: Matrix): Matrix { + // Use existing decomposition if it is provided by matrix + val decomposition = computeFeature(a) ?: lup(bufferFactory, a, singularityCheck) + return decomposition.solve(bufferFactory, b) + } + + override fun inverse(matrix: Matrix): Matrix = solve(matrix, one(matrix.rowNum, matrix.colNum)) } -public inline fun > LinearSpace>.inverseWithLup( - matrix: Matrix, - noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, - noinline checkSingular: (T) -> Boolean, -): Matrix = solveWithLup(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) - - -@OptIn(UnstableKMathAPI::class) -public fun LinearSpace.solveWithLup(a: Matrix, b: Matrix): Matrix { - // Use existing decomposition if it is provided by matrix - val bufferFactory: MutableBufferFactory = ::DoubleBuffer - val decomposition: LupDecomposition = a.getFeature() ?: lup(bufferFactory, a) { it < 1e-11 } - return decomposition.solveWithLup(bufferFactory, b) -} - -/** - * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. - */ -public fun LinearSpace.inverseWithLup(matrix: Matrix): Matrix = - solveWithLup(matrix, one(matrix.rowNum, matrix.colNum)) \ No newline at end of file +public fun LinearSpace.lupSolver(singularityThreshold: Double = 1e-11): LinearSolver = + lupSolver(::DoubleBuffer) { it < singularityThreshold } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt index 72d22233a..727b644c3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixBuilder.kt @@ -1,12 +1,14 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.Ring +import space.kscience.kmath.structures.BufferAccessor2D +import space.kscience.kmath.structures.MutableBuffer public class MatrixBuilder>( public val linearSpace: LinearSpace, @@ -45,4 +47,31 @@ public inline fun LinearSpace>.column( crossinline builder: (Int) -> T, ): Matrix = buildMatrix(size, 1) { i, _ -> builder(i) } -public fun LinearSpace>.column(vararg values: T): Matrix = column(values.size, values::get) \ No newline at end of file +public fun LinearSpace>.column(vararg values: T): Matrix = column(values.size, values::get) + +public object SymmetricMatrixFeature : MatrixFeature + +/** + * Naive implementation of a symmetric matrix builder, that adds a [SymmetricMatrixFeature] tag. The resulting matrix contains + * full `size^2` number of elements, but caches elements during calls to save [builder] calls. [builder] is always called in the + * upper triangle region meaning that `i <= j` + */ +public fun > MatrixBuilder.symmetric( + builder: (i: Int, j: Int) -> T, +): Matrix { + require(columns == rows) { "In order to build symmetric matrix, number of rows $rows should be equal to number of columns $columns" } + return with(BufferAccessor2D(rows, rows, MutableBuffer.Companion::boxing)) { + val cache = factory(rows * rows) { null } + linearSpace.buildMatrix(rows, rows) { i, j -> + val cached = cache[i, j] + if (cached == null) { + val value = if (i <= j) builder(i, j) else builder(j, i) + cache[i, j] = value + cache[j, i] = value + value + } else { + cached + } + }.withFeature(SymmetricMatrixFeature) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixFeatures.kt index 37c93d249..4c2b5c73c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixFeatures.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear @@ -31,11 +31,11 @@ public object ZeroFeature : DiagonalFeature public object UnitFeature : DiagonalFeature /** - * Matrices with this feature can be inverted: [inverse] = `a`-1 where `a` is the owning matrix. + * Matrices with this feature can be inverted: *[inverse] = a−1* where *a* is the owning matrix. * * @param T the type of matrices' items. */ -public interface InverseMatrixFeature : MatrixFeature { +public interface InverseMatrixFeature : MatrixFeature { /** * The inverse matrix of the matrix that owns this feature. */ @@ -47,7 +47,7 @@ public interface InverseMatrixFeature : MatrixFeature { * * @param T the type of matrices' items. */ -public interface DeterminantFeature : MatrixFeature { +public interface DeterminantFeature : MatrixFeature { /** * The determinant of the matrix that owns this feature. */ @@ -80,7 +80,7 @@ public object UFeature : MatrixFeature * * @param T the type of matrices' items. */ -public interface LUDecompositionFeature : MatrixFeature { +public interface LUDecompositionFeature : MatrixFeature { /** * The lower triangular matrix in this decomposition. It may have [LFeature]. */ @@ -98,7 +98,7 @@ public interface LUDecompositionFeature : MatrixFeature { * * @param T the type of matrices' items. */ -public interface LupDecompositionFeature : MatrixFeature { +public interface LupDecompositionFeature : MatrixFeature { /** * The lower triangular matrix in this decomposition. It may have [LFeature]. */ @@ -126,7 +126,7 @@ public object OrthogonalFeature : MatrixFeature * * @param T the type of matrices' items. */ -public interface QRDecompositionFeature : MatrixFeature { +public interface QRDecompositionFeature : MatrixFeature { /** * The orthogonal matrix in this decomposition. It may have [OrthogonalFeature]. */ @@ -144,7 +144,7 @@ public interface QRDecompositionFeature : MatrixFeature { * * @param T the type of matrices' items. */ -public interface CholeskyDecompositionFeature : MatrixFeature { +public interface CholeskyDecompositionFeature : MatrixFeature { /** * The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature]. */ @@ -157,7 +157,7 @@ public interface CholeskyDecompositionFeature : MatrixFeature { * * @param T the type of matrices' items. */ -public interface SingularValueDecompositionFeature : MatrixFeature { +public interface SingularValueDecompositionFeature : MatrixFeature { /** * The matrix in this decomposition. It is unitary, and it consists from left singular vectors. */ diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt index 16aadab3b..a40c0384c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt @@ -1,13 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear +import space.kscience.kmath.misc.FeatureSet import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.StructureFeature -import space.kscience.kmath.nd.getFeature import space.kscience.kmath.operations.Ring import kotlin.reflect.KClass @@ -18,19 +18,18 @@ import kotlin.reflect.KClass */ public class MatrixWrapper internal constructor( public val origin: Matrix, - public val features: Set, + public val features: FeatureSet, ) : Matrix by origin { /** - * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria + * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the + * criteria. */ - @UnstableKMathAPI @Suppress("UNCHECKED_CAST") - public override fun getFeature(type: KClass): F? = - features.singleOrNull(type::isInstance) as? F - ?: origin.getFeature(type) + override fun getFeature(type: KClass): F? = + features.getFeature(type) ?: origin.getFeature(type) - public override fun toString(): String = "MatrixWrapper(matrix=$origin, features=$features)" + override fun toString(): String = "MatrixWrapper(matrix=$origin, features=$features)" } /** @@ -44,31 +43,34 @@ public val Matrix.origin: Matrix /** * Add a single feature to a [Matrix] */ -public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(origin, features + newFeature) +public fun Matrix.withFeature(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { + MatrixWrapper(origin, features.with(newFeature)) } else { - MatrixWrapper(this, setOf(newFeature)) + MatrixWrapper(this, FeatureSet.of(newFeature)) } +@Deprecated("To be replaced by withFeature") +public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = withFeature(newFeature) + /** * Add a collection of features to a [Matrix] */ -public operator fun Matrix.plus(newFeatures: Collection): MatrixWrapper = +public fun Matrix.withFeatures(newFeatures: Iterable): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(origin, features + newFeatures) + MatrixWrapper(origin, features.with(newFeatures)) } else { - MatrixWrapper(this, newFeatures.toSet()) + MatrixWrapper(this, FeatureSet.of(newFeatures)) } /** - * Diagonal matrix of ones. The matrix is virtual no actual matrix is created + * Diagonal matrix of ones. The matrix is virtual no actual matrix is created. */ public fun LinearSpace>.one( rows: Int, columns: Int, ): Matrix = VirtualMatrix(rows, columns) { i, j -> if (i == j) elementAlgebra.one else elementAlgebra.zero -} + UnitFeature +}.withFeature(UnitFeature) /** @@ -79,15 +81,14 @@ public fun LinearSpace>.zero( columns: Int, ): Matrix = VirtualMatrix(rows, columns) { _, _ -> elementAlgebra.zero -} + ZeroFeature +}.withFeature(ZeroFeature) public class TransposedFeature(public val original: Matrix) : MatrixFeature /** * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` */ +@Suppress("UNCHECKED_CAST") @OptIn(UnstableKMathAPI::class) -public fun Matrix.transpose(): Matrix = getFeature>()?.original ?: VirtualMatrix( - colNum, - rowNum, -) { i, j -> get(j, i) } + TransposedFeature(this) \ No newline at end of file +public fun Matrix.transpose(): Matrix = getFeature(TransposedFeature::class)?.original as? Matrix + ?: VirtualMatrix(colNum, rowNum) { i, j -> get(j, i) }.withFeature(TransposedFeature(this)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt index 3751bd33b..be1677ecd 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/VirtualMatrix.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear @@ -20,3 +20,6 @@ public class VirtualMatrix( override operator fun get(i: Int, j: Int): T = generator(i, j) } + +public fun MatrixBuilder.virtual(generator: (i: Int, j: Int) -> T): VirtualMatrix = + VirtualMatrix(rows, columns, generator) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/Featured.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/Featured.kt new file mode 100644 index 000000000..29b7caec6 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/Featured.kt @@ -0,0 +1,61 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.misc + +import kotlin.jvm.JvmInline +import kotlin.reflect.KClass + +/** + * A entity that contains a set of features defined by their types + */ +public interface Featured { + public fun getFeature(type: FeatureKey): T? +} + +public typealias FeatureKey = KClass + +public interface Feature> { + + /** + * A key used for extraction + */ + @Suppress("UNCHECKED_CAST") + public val key: FeatureKey + get() = this::class as FeatureKey +} + +/** + * A container for a set of features + */ +@JvmInline +public value class FeatureSet> private constructor(public val features: Map, F>) : Featured { + @Suppress("UNCHECKED_CAST") + override fun getFeature(type: FeatureKey): T? = features[type]?.let { it as T } + + public inline fun getFeature(): T? = getFeature(T::class) + + public fun with(feature: T, type: FeatureKey = feature.key): FeatureSet = + FeatureSet(features + (type to feature)) + + public fun with(other: FeatureSet): FeatureSet = FeatureSet(features + other.features) + + public fun with(vararg otherFeatures: F): FeatureSet = + FeatureSet(features + otherFeatures.associateBy { it.key }) + + public fun with(otherFeatures: Iterable): FeatureSet = + FeatureSet(features + otherFeatures.associateBy { it.key }) + + public operator fun iterator(): Iterator = features.values.iterator() + + override fun toString(): String = features.values.joinToString(prefix = "[ ", postfix = " ]") + + + public companion object { + public fun > of(vararg features: F): FeatureSet = FeatureSet(features.associateBy { it.key }) + public fun > of(features: Iterable): FeatureSet = + FeatureSet(features.associateBy { it.key }) + } +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt index e521e6237..2b3a4ab03 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt @@ -1,27 +1,27 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.misc /** * Marks declarations that are still experimental in the KMath APIs, which means that the design of the corresponding - * declarations has open issues which may (or may not) lead to their changes in the future. Roughly speaking, there is - * a chance that those declarations will be deprecated in the near future or the semantics of their behavior may change + * declarations has open issues that may (or may not) lead to their changes in the future. Roughly speaking, there is + * a chance of those declarations will be deprecated in the near future or the semantics of their behavior may change * in some way that may break some code. */ @MustBeDocumented -@Retention(value = AnnotationRetention.BINARY) +@Retention(value = AnnotationRetention.SOURCE) @RequiresOptIn("This API is unstable and could change in future", RequiresOptIn.Level.WARNING) public annotation class UnstableKMathAPI /** - * Marks API which could cause performance problems. The code, marked by this API is not necessary slow, but could cause + * Marks API that could cause performance problems. The code marked by this API is unnecessary slow but could cause * slow-down in some cases. Refer to the documentation and benchmark it to be sure. */ @MustBeDocumented -@Retention(value = AnnotationRetention.BINARY) +@Retention(value = AnnotationRetention.SOURCE) @RequiresOptIn( "Refer to the documentation to use this API in performance-critical code", RequiresOptIn.Level.WARNING diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt index 889eb4f22..413f44960 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/cumulative.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.misc @@ -34,7 +34,7 @@ public inline fun Iterable.cumulative(initial: R, crossinline operatio public inline fun Sequence.cumulative(initial: R, crossinline operation: (R, T) -> R): Sequence = Sequence { this@cumulative.iterator().cumulative(initial, operation) } -public fun List.cumulative(initial: R, operation: (R, T) -> R): List = +public inline fun List.cumulative(initial: R, crossinline operation: (R, T) -> R): List = iterator().cumulative(initial, operation).asSequence().toList() //Cumulative sum diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/logging.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/logging.kt new file mode 100644 index 000000000..9dfc564c3 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/logging.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.misc + +import space.kscience.kmath.misc.Loggable.Companion.INFO + +public fun interface Loggable { + public fun log(tag: String, block: () -> String) + + public companion object { + public const val INFO: String = "INFO" + + public val console: Loggable = Loggable { tag, block -> + println("[$tag] ${block()}") + } + } +} + +public fun Loggable.log(block: () -> String): Unit = log(INFO, block) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/numbers.kt new file mode 100644 index 000000000..e048eb746 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/numbers.kt @@ -0,0 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.misc + +public expect fun Long.toIntExact(): Int diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index 35bbc44f6..4e52c8ba9 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -1,17 +1,17 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* -import space.kscience.kmath.structures.* import kotlin.reflect.KClass /** - * An exception is thrown when the expected ans actual shape of NDArray differs. + * An exception is thrown when the expected and actual shape of NDArray differ. * * @property expected the expected shape. * @property actual the actual shape. @@ -19,52 +19,72 @@ import kotlin.reflect.KClass public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") +public typealias Shape = IntArray + +public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest) + +public interface WithShape { + public val shape: Shape + + public val indices: ShapeIndexer get() = DefaultStrides(shape) +} + /** * The base interface for all ND-algebra implementations. * * @param T the type of ND-structure element. * @param C the type of the element context. - * @param N the type of the structure. */ public interface AlgebraND> { - /** - * The shape of ND-structures this algebra operates on. - */ - public val shape: IntArray - /** * The algebra over elements of ND structure. */ - public val elementContext: C + public val elementAlgebra: C /** - * Produces a new NDStructure using given initializer function. + * Produces a new [StructureND] using given initializer function. */ - public fun produce(initializer: C.(IntArray) -> T): StructureND + public fun structureND(shape: Shape, initializer: C.(IntArray) -> T): StructureND /** * Maps elements from one structure to another one by applying [transform] to them. */ - public fun StructureND.map(transform: C.(T) -> T): StructureND + @PerformancePitfall("Very slow on remote execution algebras") + public fun StructureND.map(transform: C.(T) -> T): StructureND = structureND(shape) { index -> + elementAlgebra.transform(get(index)) + } /** * Maps elements from one structure to another one by applying [transform] to them alongside with their indices. */ - public fun StructureND.mapIndexed(transform: C.(index: IntArray, T) -> T): StructureND + @PerformancePitfall("Very slow on remote execution algebras") + public fun StructureND.mapIndexed(transform: C.(index: IntArray, T) -> T): StructureND = + structureND(shape) { index -> + elementAlgebra.transform(index, get(index)) + } /** * Combines two structures into one. */ - public fun combine(a: StructureND, b: StructureND, transform: C.(T, T) -> T): StructureND + @PerformancePitfall("Very slow on remote execution algebras") + public fun zip(left: StructureND, right: StructureND, transform: C.(T, T) -> T): StructureND { + require(left.shape.contentEquals(right.shape)) { + "Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}" + } + return structureND(left.shape) { index -> + elementAlgebra.transform(left[index], right[index]) + } + } /** * Element-wise invocation of function working on [T] on a [StructureND]. */ + @PerformancePitfall public operator fun Function1.invoke(structure: StructureND): StructureND = structure.map { value -> this@invoke(value) } /** - * Get a feature of the structure in this scope. Structure features take precedence other context features + * Get a feature of the structure in this scope. Structure features take precedence other context features. * * @param F the type of feature. * @param structure the structure. @@ -78,9 +98,8 @@ public interface AlgebraND> { public companion object } - /** - * Get a feature of the structure in this scope. Structure features take precedence other context features + * Get a feature of the structure in this scope. Structure features take precedence other context features. * * @param T the type of items in the matrices. * @param F the type of feature. @@ -90,56 +109,23 @@ public interface AlgebraND> { public inline fun AlgebraND.getFeature(structure: StructureND): F? = getFeature(structure, F::class) -/** - * Checks if given elements are consistent with this context. - * - * @param structures the structures to check. - * @return the array of valid structures. - */ -internal fun > AlgebraND.checkShape(vararg structures: StructureND): Array> = - structures - .map(StructureND::shape) - .singleOrNull { !shape.contentEquals(it) } - ?.let>> { throw ShapeMismatchException(shape, it) } - ?: structures - -/** - * Checks if given element is consistent with this context. - * - * @param element the structure to check. - * @return the valid structure. - */ -internal fun > AlgebraND.checkShape(element: StructureND): StructureND { - if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape) - return element -} - /** * Space of [StructureND]. * * @param T the type of the element contained in ND structure. - * @param N the type of ND structure. - * @param S the type of space of structure elements. + * @param A the type of group over structure elements. */ -public interface GroupND> : Group>, AlgebraND { +public interface GroupOpsND> : GroupOps>, AlgebraND { /** * Element-wise addition. * - * @param a the augend. - * @param b the addend. + * @param left the augend. + * @param right the addend. * @return the sum. */ - public override fun add(a: StructureND, b: StructureND): StructureND = - combine(a, b) { aValue, bValue -> add(aValue, bValue) } - -// /** -// * Element-wise multiplication by scalar. -// * -// * @param a the multiplicand. -// * @param k the multiplier. -// * @return the product. -// */ -// public override fun multiply(a: NDStructure, k: Number): NDStructure = a.map { multiply(it, k) } + @OptIn(PerformancePitfall::class) + override fun add(left: StructureND, right: StructureND): StructureND = + zip(left, right) { aValue, bValue -> add(aValue, bValue) } // TODO move to extensions after KEEP-176 @@ -150,6 +136,7 @@ public interface GroupND> : Group>, AlgebraND * @param arg the addend. * @return the sum. */ + @OptIn(PerformancePitfall::class) public operator fun StructureND.plus(arg: T): StructureND = this.map { value -> add(arg, value) } /** @@ -159,6 +146,7 @@ public interface GroupND> : Group>, AlgebraND * @param arg the divisor. * @return the quotient. */ + @OptIn(PerformancePitfall::class) public operator fun StructureND.minus(arg: T): StructureND = this.map { value -> add(arg, -value) } /** @@ -168,6 +156,7 @@ public interface GroupND> : Group>, AlgebraND * @param arg the addend. * @return the sum. */ + @OptIn(PerformancePitfall::class) public operator fun T.plus(arg: StructureND): StructureND = arg.map { value -> add(this@plus, value) } /** @@ -177,28 +166,33 @@ public interface GroupND> : Group>, AlgebraND * @param arg the divisor. * @return the quotient. */ + @OptIn(PerformancePitfall::class) public operator fun T.minus(arg: StructureND): StructureND = arg.map { value -> add(-this@minus, value) } public companion object } +public interface GroupND> : Group>, GroupOpsND, WithShape { + override val zero: StructureND get() = structureND(shape) { elementAlgebra.zero } +} + /** * Ring of [StructureND]. * * @param T the type of the element contained in ND structure. - * @param N the type of ND structure. - * @param R the type of ring of structure elements. + * @param A the type of ring over structure elements. */ -public interface RingND> : Ring>, GroupND { +public interface RingOpsND> : RingOps>, GroupOpsND { /** * Element-wise multiplication. * - * @param a the multiplicand. - * @param b the multiplier. + * @param left the multiplicand. + * @param right the multiplier. * @return the product. */ - public override fun multiply(a: StructureND, b: StructureND): StructureND = - combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } + @OptIn(PerformancePitfall::class) + override fun multiply(left: StructureND, right: StructureND): StructureND = + zip(left, right) { aValue, bValue -> multiply(aValue, bValue) } //TODO move to extensions after KEEP-176 @@ -209,6 +203,7 @@ public interface RingND> : Ring>, GroupND.times(arg: T): StructureND = this.map { value -> multiply(arg, value) } /** @@ -218,29 +213,39 @@ public interface RingND> : Ring>, GroupND): StructureND = arg.map { value -> multiply(this@times, value) } public companion object } +public interface RingND> : Ring>, RingOpsND, GroupND, WithShape { + override val one: StructureND get() = structureND(shape) { elementAlgebra.one } +} + + /** * Field of [StructureND]. * * @param T the type of the element contained in ND structure. - * @param F the type field of structure elements. + * @param A the type field over structure elements. */ -public interface FieldND> : Field>, RingND, ScaleOperations> { +public interface FieldOpsND> : + FieldOps>, + RingOpsND, + ScaleOperations> { /** * Element-wise division. * - * @param a the dividend. - * @param b the divisor. + * @param left the dividend. + * @param right the divisor. * @return the quotient. */ - public override fun divide(a: StructureND, b: StructureND): StructureND = - combine(a, b) { aValue, bValue -> divide(aValue, bValue) } + @OptIn(PerformancePitfall::class) + override fun divide(left: StructureND, right: StructureND): StructureND = + zip(left, right) { aValue, bValue -> divide(aValue, bValue) } - //TODO move to extensions after KEEP-176 + //TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md /** * Divides an ND structure by an element of it. * @@ -248,6 +253,7 @@ public interface FieldND> : Field>, RingND.div(arg: T): StructureND = this.map { value -> divide(arg, value) } /** @@ -257,35 +263,13 @@ public interface FieldND> : Field>, RingND): StructureND = arg.map { divide(it, this@div) } -// @ThreadLocal -// public companion object { -// private val realNDFieldCache: MutableMap = hashMapOf() -// -// /** -// * Create a nd-field for [Double] values or pull it from cache if it was created previously. -// */ -// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) } -// -// /** -// * Create an ND field with boxing generic buffer. -// */ -// public fun > boxing( -// field: F, -// vararg shape: Int, -// bufferFactory: BufferFactory = Buffer.Companion::boxing, -// ): BufferedNDField = BufferedNDField(shape, field, bufferFactory) -// -// /** -// * Create a most suitable implementation for nd-field using reified class. -// */ -// @Suppress("UNCHECKED_CAST") -// public inline fun > auto(field: F, vararg shape: Int): NDField = -// when { -// T::class == Double::class -> real(*shape) as NDField -// T::class == Complex::class -> complex(*shape) as BufferedNDField -// else -> BoxingNDField(shape, field, Buffer.Companion::auto) -// } -// } + @OptIn(PerformancePitfall::class) + override fun scale(a: StructureND, value: Double): StructureND = a.map { scale(it, value) } } + +public interface FieldND> : Field>, FieldOpsND, RingND, WithShape { + override val one: StructureND get() = structureND(shape) { elementAlgebra.one } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt index 2b82a36ae..0e094a8c7 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt @@ -1,142 +1,185 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ +@file:OptIn(UnstableKMathAPI::class) + package space.kscience.kmath.nd +import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* -import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract public interface BufferAlgebraND> : AlgebraND { - public val strides: Strides - public val bufferFactory: BufferFactory + public val indexerBuilder: (IntArray) -> ShapeIndexer + public val bufferAlgebra: BufferAlgebra + override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra - override fun produce(initializer: A.(IntArray) -> T): BufferND = BufferND( - strides, - bufferFactory(strides.linearSize) { offset -> - elementContext.initializer(strides.index(offset)) + override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): BufferND { + val indexer = indexerBuilder(shape) + return BufferND( + indexer, + bufferAlgebra.buffer(indexer.linearSize) { offset -> + elementAlgebra.initializer(indexer.index(offset)) + } + ) + } + + public fun StructureND.toBufferND(): BufferND = when (this) { + is BufferND -> this + else -> { + val indexer = indexerBuilder(shape) + BufferND(indexer, bufferAlgebra.buffer(indexer.linearSize) { offset -> get(indexer.index(offset)) }) + } + } + + @PerformancePitfall + override fun StructureND.map(transform: A.(T) -> T): BufferND = mapInline(toBufferND(), transform) + + @PerformancePitfall + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND = + mapIndexedInline(toBufferND(), transform) + + @PerformancePitfall + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): BufferND = + zipInline(left.toBufferND(), right.toBufferND(), transform) + + public companion object { + public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke + } +} + +public inline fun > BufferAlgebraND.mapInline( + arg: BufferND, + crossinline transform: A.(T) -> T +): BufferND { + val indexes = arg.indices + return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform)) +} + +internal inline fun > BufferAlgebraND.mapIndexedInline( + arg: BufferND, + crossinline transform: A.(index: IntArray, arg: T) -> T +): BufferND { + val indexes = arg.indices + return BufferND( + indexes, + bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value -> + transform(indexes.index(offset), value) } ) - - public val StructureND.buffer: Buffer - get() = when { - !shape.contentEquals(this@BufferAlgebraND.shape) -> throw ShapeMismatchException( - this@BufferAlgebraND.shape, - shape - ) - this is BufferND && this.strides == this@BufferAlgebraND.strides -> this.buffer - else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) } - } - - override fun StructureND.map(transform: A.(T) -> T): BufferND { - val buffer = bufferFactory(strides.linearSize) { offset -> - elementContext.transform(buffer[offset]) - } - return BufferND(strides, buffer) - } - - override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND { - val buffer = bufferFactory(strides.linearSize) { offset -> - elementContext.transform( - strides.index(offset), - buffer[offset] - ) - } - return BufferND(strides, buffer) - } - - override fun combine(a: StructureND, b: StructureND, transform: A.(T, T) -> T): BufferND { - val buffer = bufferFactory(strides.linearSize) { offset -> - elementContext.transform(a.buffer[offset], b.buffer[offset]) - } - return BufferND(strides, buffer) - } } -public open class BufferedGroupND>( - final override val shape: IntArray, - final override val elementContext: A, - final override val bufferFactory: BufferFactory, -) : GroupND, BufferAlgebraND { - override val strides: Strides = DefaultStrides(shape) - override val zero: BufferND by lazy { produce { zero } } - override fun StructureND.unaryMinus(): StructureND = produce { -get(it) } +internal inline fun > BufferAlgebraND.zipInline( + l: BufferND, + r: BufferND, + crossinline block: A.(l: T, r: T) -> T +): BufferND { + require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } + val indexes = l.indices + return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block)) } -public open class BufferedRingND>( - shape: IntArray, - elementContext: R, - bufferFactory: BufferFactory, -) : BufferedGroupND(shape, elementContext, bufferFactory), RingND { - override val one: BufferND by lazy { produce { one } } +@OptIn(PerformancePitfall::class) +public open class BufferedGroupNDOps>( + override val bufferAlgebra: BufferAlgebra, + override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder +) : GroupOpsND, BufferAlgebraND { + override fun StructureND.unaryMinus(): StructureND = map { -it } } -public open class BufferedFieldND>( - shape: IntArray, - elementContext: R, - bufferFactory: BufferFactory, -) : BufferedRingND(shape, elementContext, bufferFactory), FieldND { +public open class BufferedRingOpsND>( + bufferAlgebra: BufferAlgebra, + indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder +) : BufferedGroupNDOps(bufferAlgebra, indexerBuilder), RingOpsND +public open class BufferedFieldOpsND>( + bufferAlgebra: BufferAlgebra, + indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder +) : BufferedRingOpsND(bufferAlgebra, indexerBuilder), FieldOpsND { + + public constructor( + elementAlgebra: A, + bufferFactory: BufferFactory, + indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder + ) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder) + + @OptIn(PerformancePitfall::class) override fun scale(a: StructureND, value: Double): StructureND = a.map { it * value } } -// group factories -public fun > AlgebraND.Companion.group( - space: A, - bufferFactory: BufferFactory, - vararg shape: Int, -): BufferedGroupND = BufferedGroupND(shape, space, bufferFactory) +public val > BufferAlgebra.nd: BufferedGroupNDOps get() = BufferedGroupNDOps(this) +public val > BufferAlgebra.nd: BufferedRingOpsND get() = BufferedRingOpsND(this) +public val > BufferAlgebra.nd: BufferedFieldOpsND get() = BufferedFieldOpsND(this) -public inline fun , R> A.ndGroup( - noinline bufferFactory: BufferFactory, - vararg shape: Int, - action: BufferedGroupND.() -> R, -): R { - contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } - return AlgebraND.group(this, bufferFactory, *shape).run(action) -} -//ring factories -public fun > AlgebraND.Companion.ring( - ring: A, - bufferFactory: BufferFactory, +public fun > BufferAlgebraND.structureND( vararg shape: Int, -): BufferedRingND = BufferedRingND(shape, ring, bufferFactory) + initializer: A.(IntArray) -> T +): BufferND = structureND(shape, initializer) -public inline fun , R> A.ndRing( - noinline bufferFactory: BufferFactory, - vararg shape: Int, - action: BufferedRingND.() -> R, -): R { - contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } - return AlgebraND.ring(this, bufferFactory, *shape).run(action) -} +public fun , A> A.structureND( + initializer: EA.(IntArray) -> T +): BufferND where A : BufferAlgebraND, A : WithShape = structureND(shape, initializer) -//field factories -public fun > AlgebraND.Companion.field( - field: A, - bufferFactory: BufferFactory, - vararg shape: Int, -): BufferedFieldND = BufferedFieldND(shape, field, bufferFactory) +//// group factories +//public fun > A.ndAlgebra( +// bufferAlgebra: BufferAlgebra, +// vararg shape: Int, +//): BufferedGroupNDOps = BufferedGroupNDOps(bufferAlgebra) +// +//@JvmName("withNdGroup") +//public inline fun , R> A.withNdAlgebra( +// noinline bufferFactory: BufferFactory, +// vararg shape: Int, +// action: BufferedGroupNDOps.() -> R, +//): R { +// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } +// return ndAlgebra(bufferFactory, *shape).run(action) +//} -@Suppress("UNCHECKED_CAST") -public inline fun > AlgebraND.Companion.auto( - field: A, - vararg shape: Int, -): FieldND = when (field) { - DoubleField -> DoubleFieldND(shape) as FieldND - else -> BufferedFieldND(shape, field, Buffer.Companion::auto) -} - -public inline fun , R> A.ndField( - noinline bufferFactory: BufferFactory, - vararg shape: Int, - action: BufferedFieldND.() -> R, -): R { - contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } - return AlgebraND.field(this, bufferFactory, *shape).run(action) -} \ No newline at end of file +////ring factories +//public fun > A.ndAlgebra( +// bufferFactory: BufferFactory, +// vararg shape: Int, +//): BufferedRingNDOps = BufferedRingNDOps(shape, this, bufferFactory) +// +//@JvmName("withNdRing") +//public inline fun , R> A.withNdAlgebra( +// noinline bufferFactory: BufferFactory, +// vararg shape: Int, +// action: BufferedRingNDOps.() -> R, +//): R { +// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } +// return ndAlgebra(bufferFactory, *shape).run(action) +//} +// +////field factories +//public fun > A.ndAlgebra( +// bufferFactory: BufferFactory, +// vararg shape: Int, +//): BufferedFieldNDOps = BufferedFieldNDOps(shape, this, bufferFactory) +// +///** +// * Create a [FieldND] for this [Field] inferring proper buffer factory from the type +// */ +//@UnstableKMathAPI +//@Suppress("UNCHECKED_CAST") +//public inline fun > A.autoNdAlgebra( +// vararg shape: Int, +//): FieldND = when (this) { +// DoubleField -> DoubleFieldND(shape) as FieldND +// else -> BufferedFieldNDOps(shape, this, Buffer.Companion::auto) +//} +// +//@JvmName("withNdField") +//public inline fun , R> A.withNdAlgebra( +// noinline bufferFactory: BufferFactory, +// vararg shape: Int, +// action: BufferedFieldNDOps.() -> R, +//): R { +// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } +// return ndAlgebra(bufferFactory, *shape).run(action) +//} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt index 904419302..19924616d 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd @@ -15,26 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory * Represents [StructureND] over [Buffer]. * * @param T the type of items. - * @param strides The strides to access elements of [Buffer] by linear indices. + * @param indices The strides to access elements of [Buffer] by linear indices. * @param buffer The underlying buffer. */ public open class BufferND( - public val strides: Strides, - public val buffer: Buffer, + override val indices: ShapeIndexer, + public open val buffer: Buffer, ) : StructureND { - init { - if (strides.linearSize != buffer.size) { - error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") - } - } + override operator fun get(index: IntArray): T = buffer[indices.offset(index)] - override operator fun get(index: IntArray): T = buffer[strides.offset(index)] - - override val shape: IntArray get() = strides.shape + override val shape: IntArray get() = indices.shape @PerformancePitfall - override fun elements(): Sequence> = strides.indices().map { + override fun elements(): Sequence> = indices.asSequence().map { it to this[it] } @@ -49,7 +43,7 @@ public inline fun StructureND.mapToBuffer( crossinline transform: (T) -> R, ): BufferND { return if (this is BufferND) - BufferND(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) + BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) }) else { val strides = DefaultStrides(shape) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) @@ -61,14 +55,14 @@ public inline fun StructureND.mapToBuffer( * * @param T the type of items. * @param strides The strides to access elements of [MutableBuffer] by linear indices. - * @param mutableBuffer The underlying buffer. + * @param buffer The underlying buffer. */ public class MutableBufferND( - strides: Strides, - public val mutableBuffer: MutableBuffer, -) : MutableStructureND, BufferND(strides, mutableBuffer) { + strides: ShapeIndexer, + override val buffer: MutableBuffer, +) : MutableStructureND, BufferND(strides, buffer) { override fun set(index: IntArray, value: T) { - mutableBuffer[strides.offset(index)] = value + buffer[indices.offset(index)] = value } } @@ -80,7 +74,7 @@ public inline fun MutableStructureND.mapToMutableBuffer( crossinline transform: (T) -> R, ): MutableBufferND { return if (this is MutableBufferND) - MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) }) + MutableBufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) }) else { val strides = DefaultStrides(shape) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index 71532594e..7285fdb24 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -1,114 +1,195 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations -import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.DoubleBuffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract +import kotlin.math.pow -@OptIn(UnstableKMathAPI::class) -public class DoubleFieldND( - shape: IntArray, -) : BufferedFieldND(shape, DoubleField, ::DoubleBuffer), - NumbersAddOperations>, - ScaleOperations>, - ExtendedField> { +public class DoubleBufferND( + indexes: ShapeIndexer, + override val buffer: DoubleBuffer, +) : BufferND(indexes, buffer) - public override val zero: BufferND by lazy { produce { zero } } - public override val one: BufferND by lazy { produce { one } } - public override fun number(value: Number): BufferND { - val d = value.toDouble() // minimize conversions - return produce { d } - } +public sealed class DoubleFieldOpsND : BufferedFieldOpsND(DoubleField.bufferAlgebra), + ScaleOperations>, ExtendedFieldOps> { - public override val StructureND.buffer: DoubleBuffer - get() = when { - !shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException( - this@DoubleFieldND.shape, - shape - ) - this is BufferND && this.strides == this@DoubleFieldND.strides -> this.buffer as DoubleBuffer - else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } + override fun StructureND.toBufferND(): DoubleBufferND = when (this) { + is DoubleBufferND -> this + else -> { + val indexer = indexerBuilder(shape) + DoubleBufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) }) } - - @Suppress("OVERRIDE_BY_INLINE") - public override inline fun StructureND.map( - transform: DoubleField.(Double) -> Double, - ): BufferND { - val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) } - return BufferND(strides, buffer) } - @Suppress("OVERRIDE_BY_INLINE") - public override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND { - val array = DoubleArray(strides.linearSize) { offset -> - val index = strides.index(offset) - DoubleField.initializer(index) - } - return BufferND(strides, DoubleBuffer(array)) + private inline fun mapInline( + arg: DoubleBufferND, + transform: (Double) -> Double + ): DoubleBufferND { + val indexes = arg.indices + val array = arg.buffer.array + return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) }) } - @Suppress("OVERRIDE_BY_INLINE") - public override inline fun StructureND.mapIndexed( - transform: DoubleField.(index: IntArray, Double) -> Double, - ): BufferND = BufferND( - strides, - buffer = DoubleBuffer(strides.linearSize) { offset -> - DoubleField.transform( - strides.index(offset), - buffer.array[offset] - ) - }) - - @Suppress("OVERRIDE_BY_INLINE") - public override inline fun combine( - a: StructureND, - b: StructureND, - transform: DoubleField.(Double, Double) -> Double, - ): BufferND { - val buffer = DoubleBuffer(strides.linearSize) { offset -> - DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset]) - } - return BufferND(strides, buffer) + private inline fun zipInline( + l: DoubleBufferND, + r: DoubleBufferND, + block: (l: Double, r: Double) -> Double + ): DoubleBufferND { + require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } + val indexes = l.indices + val lArray = l.buffer.array + val rArray = r.buffer.array + return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) }) } - public override fun scale(a: StructureND, value: Double): StructureND = a.map { it * value } + @OptIn(PerformancePitfall::class) + override fun StructureND.map(transform: DoubleField.(Double) -> Double): BufferND = + mapInline(toBufferND()) { DoubleField.transform(it) } - public override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } - public override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } - public override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } + @OptIn(PerformancePitfall::class) + override fun zip( + left: StructureND, + right: StructureND, + transform: DoubleField.(Double, Double) -> Double + ): BufferND = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) } - public override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } - public override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } - public override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } - public override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } - public override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } - public override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } + override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND { + val indexer = indexerBuilder(shape) + return DoubleBufferND( + indexer, + DoubleBuffer(indexer.linearSize) { offset -> + elementAlgebra.initializer(indexer.index(offset)) + } + ) + } - public override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } - public override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } - public override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } - public override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } - public override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } - public override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } + override fun add(left: StructureND, right: StructureND): DoubleBufferND = + zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l + r } + + override fun multiply(left: StructureND, right: StructureND): DoubleBufferND = + zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r } + + override fun StructureND.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it } + + override fun StructureND.div(arg: StructureND): DoubleBufferND = + zipInline(toBufferND(), arg.toBufferND()) { l, r -> l / r } + + override fun divide(left: StructureND, right: StructureND): DoubleBufferND = + zipInline(left.toBufferND(), right.toBufferND()) { l: Double, r: Double -> l / r } + + override fun StructureND.div(arg: Double): DoubleBufferND = + mapInline(toBufferND()) { it / arg } + + override fun Double.div(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { this / it } + + override fun StructureND.unaryPlus(): DoubleBufferND = toBufferND() + + override fun StructureND.plus(arg: StructureND): DoubleBufferND = + zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l + r } + + override fun StructureND.minus(arg: StructureND): DoubleBufferND = + zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l - r } + + override fun StructureND.times(arg: StructureND): DoubleBufferND = + zipInline(toBufferND(), arg.toBufferND()) { l: Double, r: Double -> l * r } + + override fun StructureND.times(k: Number): DoubleBufferND = + mapInline(toBufferND()) { it * k.toDouble() } + + override fun StructureND.div(k: Number): DoubleBufferND = + mapInline(toBufferND()) { it / k.toDouble() } + + override fun Number.times(arg: StructureND): DoubleBufferND = arg * this + + override fun StructureND.plus(arg: Double): DoubleBufferND = mapInline(toBufferND()) { it + arg } + + override fun StructureND.minus(arg: Double): StructureND = mapInline(toBufferND()) { it - arg } + + override fun Double.plus(arg: StructureND): StructureND = arg + this + + override fun Double.minus(arg: StructureND): StructureND = mapInline(arg.toBufferND()) { this - it } + + override fun scale(a: StructureND, value: Double): DoubleBufferND = + mapInline(a.toBufferND()) { it * value } + + override fun power(arg: StructureND, pow: Number): DoubleBufferND = + mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) } + + override fun exp(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.exp(it) } + + override fun ln(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.ln(it) } + + override fun sin(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.sin(it) } + + override fun cos(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.cos(it) } + + override fun tan(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.tan(it) } + + override fun asin(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.asin(it) } + + override fun acos(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.acos(it) } + + override fun atan(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.atan(it) } + + override fun sinh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.sinh(it) } + + override fun cosh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.cosh(it) } + + override fun tanh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.tanh(it) } + + override fun asinh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.asinh(it) } + + override fun acosh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.acosh(it) } + + override fun atanh(arg: StructureND): DoubleBufferND = + mapInline(arg.toBufferND()) { kotlin.math.atanh(it) } + + public companion object : DoubleFieldOpsND() } -public fun AlgebraND.Companion.real(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape) +@OptIn(UnstableKMathAPI::class) +public class DoubleFieldND(override val shape: Shape) : + DoubleFieldOpsND(), FieldND, NumbersAddOps> { + + override fun number(value: Number): DoubleBufferND { + val d = value.toDouble() // minimize conversions + return structureND(shape) { d } + } +} + +public val DoubleField.ndAlgebra: DoubleFieldOpsND get() = DoubleFieldOpsND + +public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape) /** * Produce a context for n-dimensional operations inside this real field */ -public inline fun DoubleField.nd(vararg shape: Int, action: DoubleFieldND.() -> R): R { +@UnstableKMathAPI +public inline fun DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R { contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } return DoubleFieldND(shape).run(action) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndexer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndexer.kt new file mode 100644 index 000000000..20e180dd1 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShapeIndexer.kt @@ -0,0 +1,121 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.nd + +import kotlin.native.concurrent.ThreadLocal + +/** + * A converter from linear index to multivariate index + */ +public interface ShapeIndexer: Iterable{ + public val shape: Shape + + /** + * Get linear index from multidimensional index + */ + public fun offset(index: IntArray): Int + + /** + * Get multidimensional from linear + */ + public fun index(offset: Int): IntArray + + /** + * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides + */ + public val linearSize: Int + + // TODO introduce a fast way to calculate index of the next element? + + /** + * Iterate over ND indices in a natural order + */ + public fun asSequence(): Sequence + + override fun iterator(): Iterator = asSequence().iterator() + + override fun equals(other: Any?): Boolean + override fun hashCode(): Int +} + +/** + * Linear transformation of indexes + */ +public abstract class Strides: ShapeIndexer { + /** + * Array strides + */ + public abstract val strides: IntArray + + public override fun offset(index: IntArray): Int = index.mapIndexed { i, value -> + if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") + value * strides[i] + }.sum() + + // TODO introduce a fast way to calculate index of the next element? + + /** + * Iterate over ND indices in a natural order + */ + public override fun asSequence(): Sequence = (0 until linearSize).asSequence().map(::index) +} + +/** + * Simple implementation of [Strides]. + */ +public class DefaultStrides private constructor(override val shape: IntArray) : Strides() { + override val linearSize: Int get() = strides[shape.size] + + /** + * Strides for memory access + */ + override val strides: IntArray by lazy { + sequence { + var current = 1 + yield(1) + + shape.forEach { + current *= it + yield(current) + } + }.toList().toIntArray() + } + + override fun index(offset: Int): IntArray { + val res = IntArray(shape.size) + var current = offset + var strideIndex = strides.size - 2 + + while (strideIndex >= 0) { + res[strideIndex] = (current / strides[strideIndex]) + current %= strides[strideIndex] + strideIndex-- + } + + return res + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is DefaultStrides) return false + if (!shape.contentEquals(other.shape)) return false + return true + } + + override fun hashCode(): Int = shape.contentHashCode() + + + public companion object { + /** + * Cached builder for default strides + */ + public operator fun invoke(shape: IntArray): Strides = + defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } + } +} + +@ThreadLocal +private val defaultStridesCache = HashMap() \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt index 720a06ace..827f0e21e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/ShortRingND.kt @@ -1,41 +1,33 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.NumbersAddOperations +import space.kscience.kmath.operations.NumbersAddOps import space.kscience.kmath.operations.ShortRing -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.ShortBuffer +import space.kscience.kmath.operations.bufferAlgebra import kotlin.contracts.InvocationKind import kotlin.contracts.contract +public sealed class ShortRingOpsND : BufferedRingOpsND(ShortRing.bufferAlgebra) { + public companion object : ShortRingOpsND() +} + @OptIn(UnstableKMathAPI::class) public class ShortRingND( - shape: IntArray, -) : BufferedRingND(shape, ShortRing, Buffer.Companion::auto), - NumbersAddOperations> { - - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } + override val shape: Shape +) : ShortRingOpsND(), RingND, NumbersAddOps> { override fun number(value: Number): BufferND { val d = value.toShort() // minimize conversions - return produce { d } + return structureND(shape) { d } } } -/** - * Fast element production using function inlining. - */ -public inline fun BufferedRingND.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND { - return BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) })) -} - -public inline fun ShortRing.nd(vararg shape: Int, action: ShortRingND.() -> R): R { +public inline fun ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R { contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } return ShortRingND(shape).run(action) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt index 150ebf6fb..3dcc77334 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure1D.kt @@ -1,29 +1,29 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.operations.asSequence import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.asMutableBuffer -import space.kscience.kmath.structures.asSequence import kotlin.jvm.JvmInline /** * A structure that is guaranteed to be one-dimensional */ public interface Structure1D : StructureND, Buffer { - public override val dimension: Int get() = 1 + override val dimension: Int get() = 1 - public override operator fun get(index: IntArray): T { + override operator fun get(index: IntArray): T { require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" } return get(index[0]) } - public override operator fun iterator(): Iterator = (0 until size).asSequence().map(::get).iterator() + override operator fun iterator(): Iterator = (0 until size).asSequence().map(::get).iterator() public companion object } @@ -32,7 +32,7 @@ public interface Structure1D : StructureND, Buffer { * A mutable structure that is guaranteed to be one-dimensional */ public interface MutableStructure1D : Structure1D, MutableStructureND, MutableBuffer { - public override operator fun set(index: IntArray, value: T) { + override operator fun set(index: IntArray, value: T) { require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" } set(index[0], value) } @@ -67,12 +67,14 @@ private class MutableStructure1DWrapper(val structure: MutableStructureND) structure[intArrayOf(index)] = value } - @PerformancePitfall + @OptIn(PerformancePitfall::class) override fun copy(): MutableBuffer = structure .elements() .map(Pair::second) .toMutableList() .asMutableBuffer() + + override fun toString(): String = Buffer.toString(this) } @@ -107,10 +109,12 @@ internal class MutableBuffer1DWrapper(val buffer: MutableBuffer) : Mutable } override fun copy(): MutableBuffer = buffer.copy() + + override fun toString(): String = Buffer.toString(this) } /** - * Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch + * Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch. */ public fun StructureND.as1D(): Structure1D = this as? Structure1D ?: if (shape.size == 1) { when (this) { @@ -132,7 +136,7 @@ public fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) /** * Expose inner buffer of this [Structure1D] if possible */ -internal fun Structure1D.unwrap(): Buffer = when { +internal fun Structure1D.asND(): Buffer = when { this is Buffer1DWrapper -> buffer this is Structure1DWrapper && structure is BufferND -> structure.buffer else -> this diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt index f353b6974..e3552c02e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/Structure2D.kt @@ -1,12 +1,11 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableListBuffer import space.kscience.kmath.structures.VirtualBuffer @@ -29,7 +28,7 @@ public interface Structure2D : StructureND { */ public val colNum: Int - public override val shape: IntArray get() = intArrayOf(rowNum, colNum) + override val shape: IntArray get() = intArrayOf(rowNum, colNum) /** * The buffer of rows of this structure. It gets elements from the structure dynamically. @@ -86,7 +85,7 @@ public interface MutableStructure2D : Structure2D, MutableStructureND { */ @PerformancePitfall override val rows: List> - get() = List(rowNum) { i -> MutableBuffer1DWrapper(MutableListBuffer(colNum) { j -> get(i, j) })} + get() = List(rowNum) { i -> MutableBuffer1DWrapper(MutableListBuffer(colNum) { j -> get(i, j) }) } /** * The buffer of columns of this structure. It gets elements from the structure dynamically. @@ -101,14 +100,13 @@ public interface MutableStructure2D : Structure2D, MutableStructureND { */ @JvmInline private value class Structure2DWrapper(val structure: StructureND) : Structure2D { - override val shape: IntArray get() = structure.shape + override val shape: Shape get() = structure.shape override val rowNum: Int get() = shape[0] override val colNum: Int get() = shape[1] override operator fun get(i: Int, j: Int): T = structure[i, j] - @UnstableKMathAPI override fun getFeature(type: KClass): F? = structure.getFeature(type) @PerformancePitfall @@ -118,9 +116,8 @@ private value class Structure2DWrapper(val structure: StructureND) : S /** * A 2D wrapper for a mutable nd-structure */ -private class MutableStructure2DWrapper(val structure: MutableStructureND): MutableStructure2D -{ - override val shape: IntArray get() = structure.shape +private class MutableStructure2DWrapper(val structure: MutableStructureND) : MutableStructure2D { + override val shape: Shape get() = structure.shape override val rowNum: Int get() = shape[0] override val colNum: Int get() = shape[1] @@ -131,7 +128,7 @@ private class MutableStructure2DWrapper(val structure: MutableStructureND) structure[index] = value } - override operator fun set(i: Int, j: Int, value: T){ + override operator fun set(i: Int, j: Int, value: T) { structure[intArrayOf(i, j)] = value } @@ -144,25 +141,29 @@ private class MutableStructure2DWrapper(val structure: MutableStructureND) } /** - * Represent a [StructureND] as [Structure1D]. Throw error in case of dimension mismatch + * Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch. */ public fun StructureND.as2D(): Structure2D = this as? Structure2D ?: when (shape.size) { 2 -> Structure2DWrapper(this) else -> error("Can't create 2d-structure from ${shape.size}d-structure") } -public fun MutableStructureND.as2D(): MutableStructure2D = this as? MutableStructure2D ?: when (shape.size) { - 2 -> MutableStructure2DWrapper(this) - else -> error("Can't create 2d-structure from ${shape.size}d-structure") -} +/** + * Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch. + */ +public fun MutableStructureND.as2D(): MutableStructure2D = + this as? MutableStructure2D ?: when (shape.size) { + 2 -> MutableStructure2DWrapper(this) + else -> error("Can't create 2d-structure from ${shape.size}d-structure") + } /** * Expose inner [StructureND] if possible */ -internal fun Structure2D.unwrap(): StructureND = +internal fun Structure2D.asND(): StructureND = if (this is Structure2DWrapper) structure else this -internal fun MutableStructure2D.unwrap(): MutableStructureND = +internal fun MutableStructure2D.asND(): MutableStructureND = if (this is MutableStructure2DWrapper) structure else this diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt index 7fc91e321..614d97950 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt @@ -1,22 +1,26 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd +import space.kscience.kmath.linear.LinearSpace +import space.kscience.kmath.misc.Feature +import space.kscience.kmath.misc.Featured import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import kotlin.jvm.JvmName -import kotlin.native.concurrent.ThreadLocal +import kotlin.math.abs import kotlin.reflect.KClass -public interface StructureFeature +public interface StructureFeature : Feature /** - * Represents n-dimensional structure, i.e. multidimensional container of items of the same type and size. The number + * Represents n-dimensional structure i.e., multidimensional container of items of the same type and size. The number * of dimensions and items in an array is defined by its shape, which is a sequence of non-negative integers that * specify the sizes of each dimension. * @@ -24,12 +28,12 @@ public interface StructureFeature * * @param T the type of items. */ -public interface StructureND { +public interface StructureND : Featured, WithShape { /** - * The shape of structure, i.e. non-empty sequence of non-negative integers that specify sizes of dimensions of + * The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of * this structure. */ - public val shape: IntArray + override val shape: Shape /** * The count of dimensions in this structure. It should be equal to size of [shape]. @@ -50,14 +54,13 @@ public interface StructureND { * @return the lazy sequence of pairs of indices to values. */ @PerformancePitfall - public fun elements(): Sequence> + public fun elements(): Sequence> = indices.asSequence().map { it to get(it) } /** - * Feature is some additional strucure information which allows to access it special properties or hints. - * If the feature is not present, null is returned. + * Feature is some additional structure information that allows to access it special properties or hints. + * If the feature is not present, `null` is returned. */ - @UnstableKMathAPI - public fun getFeature(type: KClass): F? = null + override fun getFeature(type: KClass): F? = null public companion object { /** @@ -68,13 +71,29 @@ public interface StructureND { if (st1 === st2) return true // fast comparison of buffers if possible - if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides) + if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices) return Buffer.contentEquals(st1.buffer, st2.buffer) //element by element comparison if it could not be avoided return st1.elements().all { (index, value) -> value == st2[index] } } + @PerformancePitfall + public fun contentEquals( + st1: StructureND, + st2: StructureND, + tolerance: Double = 1e-11 + ): Boolean { + if (st1 === st2) return true + + // fast comparison of buffers if possible + if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices) + return Buffer.contentEquals(st1.buffer, st2.buffer) + + //element by element comparison if it could not be avoided + return st1.elements().all { (index, value) -> abs(value - st2[index]) < tolerance } + } + /** * Debug output to string */ @@ -145,6 +164,44 @@ public interface StructureND { } } +/** + * Indicates whether some [StructureND] is equal to another one. + */ +@PerformancePitfall +public fun > AlgebraND>.contentEquals( + st1: StructureND, + st2: StructureND, +): Boolean = StructureND.contentEquals(st1, st2) + +/** + * Indicates whether some [StructureND] is equal to another one. + */ +@PerformancePitfall +public fun > LinearSpace>.contentEquals( + st1: StructureND, + st2: StructureND, +): Boolean = StructureND.contentEquals(st1, st2) + +/** + * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. + */ +@PerformancePitfall +public fun > GroupOpsND>.contentEquals( + st1: StructureND, + st2: StructureND, + absoluteTolerance: T, +): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance } + +/** + * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. + */ +@PerformancePitfall +public fun > LinearSpace>.contentEquals( + st1: StructureND, + st2: StructureND, + absoluteTolerance: T, +): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance } + /** * Returns the value at the specified indices. * @@ -153,8 +210,8 @@ public interface StructureND { */ public operator fun StructureND.get(vararg index: Int): T = get(index) -@UnstableKMathAPI -public inline fun StructureND<*>.getFeature(): T? = getFeature(T::class) +//@UnstableKMathAPI +//public inline fun StructureND<*>.getFeature(): T? = getFeature(T::class) /** * Represents mutable [StructureND]. @@ -173,107 +230,10 @@ public interface MutableStructureND : StructureND { * Transform a structure element-by element in place. */ @OptIn(PerformancePitfall::class) -public inline fun MutableStructureND.mapInPlace(action: (IntArray, T) -> T): Unit = +public inline fun MutableStructureND.mapInPlace(action: (index: IntArray, t: T) -> T): Unit = elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } -/** - * A way to convert ND index to linear one and back. - */ -public interface Strides { - /** - * Shape of NDStructure - */ - public val shape: IntArray - - /** - * Array strides - */ - public val strides: IntArray - - /** - * Get linear index from multidimensional index - */ - public fun offset(index: IntArray): Int = index.mapIndexed { i, value -> - if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})") - value * strides[i] - }.sum() - - /** - * Get multidimensional from linear - */ - public fun index(offset: Int): IntArray - - /** - * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides - */ - public val linearSize: Int - - // TODO introduce a fast way to calculate index of the next element? - - /** - * Iterate over ND indices in a natural order - */ - public fun indices(): Sequence = (0 until linearSize).asSequence().map(::index) -} - -/** - * Simple implementation of [Strides]. - */ -public class DefaultStrides private constructor(override val shape: IntArray) : Strides { - override val linearSize: Int - get() = strides[shape.size] - - /** - * Strides for memory access - */ - override val strides: IntArray by lazy { - sequence { - var current = 1 - yield(1) - - shape.forEach { - current *= it - yield(current) - } - }.toList().toIntArray() - } - - override fun index(offset: Int): IntArray { - val res = IntArray(shape.size) - var current = offset - var strideIndex = strides.size - 2 - - while (strideIndex >= 0) { - res[strideIndex] = (current / strides[strideIndex]) - current %= strides[strideIndex] - strideIndex-- - } - - return res - } - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is DefaultStrides) return false - if (!shape.contentEquals(other.shape)) return false - return true - } - - override fun hashCode(): Int = shape.contentHashCode() - - @ThreadLocal - public companion object { - private val defaultStridesCache = HashMap() - - /** - * Cached builder for default strides - */ - public operator fun invoke(shape: IntArray): Strides = - defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } - } -} - -public inline fun StructureND.combine( +public inline fun StructureND.zip( struct: StructureND, crossinline block: (T, T) -> T, ): StructureND { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/algebraNDExtentions.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/algebraNDExtentions.kt new file mode 100644 index 000000000..0e694bcb3 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/algebraNDExtentions.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.nd + +import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.operations.Group +import space.kscience.kmath.operations.Ring +import kotlin.jvm.JvmName + + +public fun > AlgebraND.structureND( + shapeFirst: Int, + vararg shapeRest: Int, + initializer: A.(IntArray) -> T +): StructureND = structureND(Shape(shapeFirst, *shapeRest), initializer) + +public fun > AlgebraND.zero(shape: Shape): StructureND = structureND(shape) { zero } + +@JvmName("zeroVarArg") +public fun > AlgebraND.zero( + shapeFirst: Int, + vararg shapeRest: Int, +): StructureND = structureND(shapeFirst, *shapeRest) { zero } + +public fun > AlgebraND.one(shape: Shape): StructureND = structureND(shape) { one } + +@JvmName("oneVarArg") +public fun > AlgebraND.one( + shapeFirst: Int, + vararg shapeRest: Int, +): StructureND = structureND(shapeFirst, *shapeRest) { one } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index 3a1ec430e..0e5c6de1f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations @@ -23,15 +23,13 @@ public interface Algebra { * Wraps a raw string to [T] object. This method is designed for three purposes: * * 1. Mathematical constants (`e`, `pi`). - * 2. Variables for expression-like contexts (`a`, `b`, `c`...). - * 3. Literals (`{1, 2}`, (`(3; 4)`)). + * 1. Variables for expression-like contexts (`a`, `b`, `c`…). + * 1. Literals (`{1, 2}`, (`(3; 4)`)). * - * In case if algebra can't parse the string, this method must throw [kotlin.IllegalStateException]. - * - * Returns `null` if symbol could not be bound to the context + * If algebra can't parse the string, then this method must throw [kotlin.IllegalStateException]. * * @param value the raw string. - * @return an object. + * @return an object or `null` if symbol could not be bound to the context. */ public fun bindSymbolOrNull(value: String): T? = null @@ -42,13 +40,12 @@ public interface Algebra { bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this") /** - * Dynamically dispatches an unary operation with the certain name. + * Dynamically dispatches a unary operation with the certain name. * - * This function must has two features: + * Implementations must fulfil the following requirements: * - * 1. In case operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with second `unaryOperation` overload: - * i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. + * 1. If operation is not defined in the structure, then the function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [unaryOperation]: for any `a` and `b`, `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. * * @param operation the name of operation. * @return an operation. @@ -57,13 +54,13 @@ public interface Algebra { error("Unary operation $operation not defined in $this") /** - * Dynamically invokes an unary operation with the certain name. + * Dynamically invokes a unary operation with the certain name. * - * This function must follow two properties: + * Implementations must fulfil the following requirements: * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with second [unaryOperationFunction] overload: - * i.e. `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. + * 1. If operation is not defined in the structure, then the function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [unaryOperationFunction]: i.e., for any `a` and `b`, + * `unaryOperationFunction(a)(b) == unaryOperation(a, b)`. * * @param operation the name of operation. * @param arg the argument of operation. @@ -74,11 +71,11 @@ public interface Algebra { /** * Dynamically dispatches a binary operation with the certain name. * - * This function must follow two properties: + * Implementations must fulfil the following requirements: * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with second [binaryOperationFunction] overload: - * i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. + * 1. If operation is not defined in the structure, then the function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [binaryOperation]: for any `a`, `b`, and `c`, + * `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. * * @param operation the name of operation. * @return an operation. @@ -89,11 +86,11 @@ public interface Algebra { /** * Dynamically invokes a binary operation with the certain name. * - * This function must follow two properties: + * Implementations must fulfil the following requirements: * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with second [binaryOperationFunction] overload: - * i.e. `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. + * 1. If operation is not defined in the structure, then the function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [binaryOperationFunction]: for any `a`, `b`, and `c`, + * `binaryOperationFunction(a)(b, c) == binaryOperation(a, b, c)`. * * @param operation the name of operation. * @param left the first argument of operation. @@ -115,22 +112,22 @@ public fun Algebra.bindSymbol(symbol: Symbol): T = bindSymbol(symbol.iden public inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** - * Represents group without neutral element (also known as inverse semigroup), i.e. algebraic structure with + * Represents group without neutral element (also known as inverse semigroup) i.e., algebraic structure with * associative, binary operation [add]. * * @param T the type of element of this semispace. */ -public interface GroupOperations : Algebra { +public interface GroupOps : Algebra { /** * Addition of two elements. * - * @param a the augend. - * @param b the addend. + * @param left the augend. + * @param right the addend. * @return the sum. */ - public fun add(a: T, b: T): T + public fun add(left: T, right: T): T - // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176 + // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176. /** * The negation of this element. @@ -152,27 +149,27 @@ public interface GroupOperations : Algebra { * Addition of two elements. * * @receiver the augend. - * @param b the addend. + * @param arg the addend. * @return the sum. */ - public operator fun T.plus(b: T): T = add(this, b) + public operator fun T.plus(arg: T): T = add(this, arg) /** * Subtraction of two elements. * * @receiver the minuend. - * @param b the subtrahend. + * @param arg the subtrahend. * @return the difference. */ - public operator fun T.minus(b: T): T = add(this, -b) - - public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { + public operator fun T.minus(arg: T): T = add(this, -arg) + // Dynamic dispatch of operations + override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { PLUS_OPERATION -> { arg -> +arg } MINUS_OPERATION -> { arg -> -arg } else -> super.unaryOperationFunction(operation) } - public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { PLUS_OPERATION -> ::add MINUS_OPERATION -> { left, right -> left - right } else -> super.binaryOperationFunction(operation) @@ -192,11 +189,11 @@ public interface GroupOperations : Algebra { } /** - * Represents group, i.e. algebraic structure with associative, binary operation [add]. + * Represents group i.e., algebraic structure with associative, binary operation [add]. * * @param T the type of element of this semispace. */ -public interface Group : GroupOperations { +public interface Group : GroupOps { /** * The neutral element of addition. */ @@ -204,29 +201,29 @@ public interface Group : GroupOperations { } /** - * Represents ring without multiplicative and additive identities, i.e. algebraic structure with + * Represents ring without multiplicative and additive identities i.e., algebraic structure with * associative, binary, commutative operation [add] and associative, operation [multiply] distributive over [add]. * * @param T the type of element of this semiring. */ -public interface RingOperations : GroupOperations { +public interface RingOps : GroupOps { /** * Multiplies two elements. * - * @param a the multiplier. - * @param b the multiplicand. + * @param left the multiplier. + * @param right the multiplicand. */ - public fun multiply(a: T, b: T): T + public fun multiply(left: T, right: T): T /** * Multiplies this element by scalar. * * @receiver the multiplier. - * @param b the multiplicand. + * @param arg the multiplicand. */ - public operator fun T.times(b: T): T = multiply(this, b) + public operator fun T.times(arg: T): T = multiply(this, arg) - public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { TIMES_OPERATION -> ::multiply else -> super.binaryOperationFunction(operation) } @@ -240,12 +237,12 @@ public interface RingOperations : GroupOperations { } /** - * Represents ring, i.e. algebraic structure with two associative binary operations called "addition" and + * Represents ring i.e., algebraic structure with two associative binary operations called "addition" and * "multiplication" and their neutral elements. * * @param T the type of element of this ring. */ -public interface Ring : Group, RingOperations { +public interface Ring : Group, RingOps { /** * The neutral element of multiplication */ @@ -253,31 +250,32 @@ public interface Ring : Group, RingOperations { } /** - * Represents field without without multiplicative and additive identities, i.e. algebraic structure with associative, binary, commutative operations - * [add] and [multiply]; binary operation [divide] as multiplication of left operand by reciprocal of right one. + * Represents field without multiplicative and additive identities i.e., algebraic structure with associative, binary, + * commutative operations [add] and [multiply]; binary operation [divide] as multiplication of left operand by + * reciprocal of right one. * * @param T the type of element of this semifield. */ -public interface FieldOperations : RingOperations { +public interface FieldOps : RingOps { /** * Division of two elements. * - * @param a the dividend. - * @param b the divisor. + * @param left the dividend. + * @param right the divisor. * @return the quotient. */ - public fun divide(a: T, b: T): T + public fun divide(left: T, right: T): T /** * Division of two elements. * * @receiver the dividend. - * @param b the divisor. + * @param arg the divisor. * @return the quotient. */ - public operator fun T.div(b: T): T = divide(this, b) + public operator fun T.div(arg: T): T = divide(this, arg) - public override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { + override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { DIV_OPERATION -> ::divide else -> super.binaryOperationFunction(operation) } @@ -291,12 +289,12 @@ public interface FieldOperations : RingOperations { } /** - * Represents field, i.e. algebraic structure with three operations: associative, commutative addition and + * Represents field i.e., algebraic structure with three operations: associative, commutative addition and * multiplication, and division. **This interface differs from the eponymous mathematical definition: fields in KMath * also support associative multiplication by scalar.** * * @param T the type of element of this field. */ -public interface Field : Ring, FieldOperations, ScaleOperations, NumericAlgebra { - public override fun number(value: Number): T = scale(one, value.toDouble()) +public interface Field : Ring, FieldOps, ScaleOperations, NumericAlgebra { + override fun number(value: Number): T = scale(one, value.toDouble()) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/AlgebraElements.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/AlgebraElements.kt deleted file mode 100644 index cc058d3fc..000000000 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/AlgebraElements.kt +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.operations - -import space.kscience.kmath.misc.UnstableKMathAPI - -/** - * The generic mathematics elements which is able to store its context - * - * @param C the type of mathematical context for this element. - * @param T the type wrapped by this wrapper. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public interface AlgebraElement> { - /** - * The context this element belongs to. - */ - public val context: C -} -// -///** -// * Divides this element by number. -// * -// * @param k the divisor. -// * @return the quotient. -// */ -//public operator fun , S : Space> T.div(k: Number): T = -// context.multiply(this, 1.0 / k.toDouble()) -// -///** -// * Multiplies this element by number. -// * -// * @param k the multiplicand. -// * @return the product. -// */ -//public operator fun , S : Space> T.times(k: Number): T = -// context.multiply(this, k.toDouble()) - -/** - * Subtracts element from this one. - * - * @param b the subtrahend. - * @return the difference. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public operator fun , S : NumbersAddOperations> T.minus(b: T): T = - context.add(this, context.run { -b }) - -/** - * Adds element to this one. - * - * @receiver the augend. - * @param b the addend. - * @return the sum. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public operator fun , S : Ring> T.plus(b: T): T = - context.add(this, b) - -///** -// * Number times element -// */ -//public operator fun , S : Space> Number.times(element: T): T = -// element.times(this) - -/** - * Multiplies this element by another one. - * - * @receiver the multiplicand. - * @param b the multiplier. - * @return the product. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public operator fun , R : Ring> T.times(b: T): T = - context.multiply(this, b) - - -/** - * Divides this element by another one. - * - * @param b the divisor. - * @return the quotient. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public operator fun , F : Field> T.div(b: T): T = - context.divide(this, b) - - -/** - * The element of [Group]. - * - * @param T the type of space operation results. - * @param I self type of the element. Needed for static type checking. - * @param S the type of space. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public interface GroupElement, S : Group> : AlgebraElement - -/** - * The element of [Ring]. - * - * @param T the type of ring operation results. - * @param I self type of the element. Needed for static type checking. - * @param R the type of ring. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public interface RingElement, R : Ring> : GroupElement - -/** - * The element of [Field]. - * - * @param T the type of field operation results. - * @param I self type of the element. Needed for static type checking. - * @param F the type of field. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public interface FieldElement, F : Field> : RingElement diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt index ac53c4d5e..5a713049e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt @@ -1,13 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.AlgebraND -import space.kscience.kmath.nd.BufferedRingND +import space.kscience.kmath.nd.BufferedRingOpsND import space.kscience.kmath.operations.BigInt.Companion.BASE import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE import space.kscience.kmath.structures.Buffer @@ -27,7 +26,7 @@ private typealias TBase = ULong * @author Peter Klimai */ @OptIn(UnstableKMathAPI::class) -public object BigIntField : Field, NumbersAddOperations, ScaleOperations { +public object BigIntField : Field, NumbersAddOps, ScaleOperations { override val zero: BigInt = BigInt.ZERO override val one: BigInt = BigInt.ONE @@ -35,10 +34,10 @@ public object BigIntField : Field, NumbersAddOperations, ScaleOp @Suppress("EXTENSION_SHADOWED_BY_MEMBER") override fun BigInt.unaryMinus(): BigInt = -this - override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b) + override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right) override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value)) - override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) - override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b) + override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right) + override fun divide(left: BigInt, right: BigInt): BigInt = left.div(right) public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer") public operator fun String.unaryMinus(): BigInt = @@ -49,16 +48,16 @@ public class BigInt internal constructor( private val sign: Byte, private val magnitude: Magnitude, ) : Comparable { - public override fun compareTo(other: BigInt): Int = when { + override fun compareTo(other: BigInt): Int = when { (sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0 sign < other.sign -> -1 sign > other.sign -> 1 else -> sign * compareMagnitudes(magnitude, other.magnitude) } - public override fun equals(other: Any?): Boolean = other is BigInt && compareTo(other) == 0 + override fun equals(other: Any?): Boolean = other is BigInt && compareTo(other) == 0 - public override fun hashCode(): Int = magnitude.hashCode() + sign + override fun hashCode(): Int = magnitude.hashCode() + sign public fun abs(): BigInt = if (sign == 0.toByte()) this else BigInt(1, magnitude) @@ -121,7 +120,7 @@ public class BigInt internal constructor( var r = ZERO val bitSize = - (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.lastOrNull()?.toFloat() ?: 0f + 1)).toInt() + (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.lastOrNull()?.toFloat() ?: (0f + 1))).toInt() for (i in bitSize downTo 0) { r = r shl 1 @@ -442,10 +441,10 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt { } /** - * Returns null if a valid number can not be read from a string + * Returns `null` if a valid number cannot be read from a string */ public fun String.parseBigInteger(): BigInt? { - if (this.isEmpty()) return null + if (isEmpty()) return null val sign: Int val positivePartIndex = when (this[0]) { @@ -527,11 +526,21 @@ public fun String.parseBigInteger(): BigInt? { } } +public val BigInt.algebra: BigIntField get() = BigIntField + +@Deprecated("Use BigInt::buffer") public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer = boxing(size, initializer) +public inline fun BigInt.Companion.buffer(size: Int, initializer: (Int) -> BigInt): Buffer = + Buffer.boxing(size, initializer) + +@Deprecated("Use BigInt::mutableBuffer") public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer = boxing(size, initializer) -public fun AlgebraND.Companion.bigInt(vararg shape: Int): BufferedRingND = - BufferedRingND(shape, BigIntField, Buffer.Companion::bigInt) +public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer = + Buffer.boxing(size, initializer) + +public val BigIntField.nd: BufferedRingOpsND + get() = BufferedRingOpsND(BufferRingOps(BigIntField, BigInt::buffer)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt new file mode 100644 index 000000000..bc05f3904 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BufferAlgebra.kt @@ -0,0 +1,195 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.operations + +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.BufferFactory +import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.ShortBuffer + +public interface WithSize { + public val size: Int +} + +/** + * An algebra over [Buffer] + */ +public interface BufferAlgebra> : Algebra> { + public val elementAlgebra: A + public val bufferFactory: BufferFactory + + public fun buffer(size: Int, vararg elements: T): Buffer { + require(elements.size == size) { "Expected $size elements but found ${elements.size}" } + return bufferFactory(size) { elements[it] } + } + + //TODO move to multi-receiver inline extension + public fun Buffer.map(block: A.(T) -> T): Buffer = mapInline(this, block) + + public fun Buffer.mapIndexed(block: A.(index: Int, arg: T) -> T): Buffer = mapIndexedInline(this, block) + + public fun Buffer.zip(other: Buffer, block: A.(left: T, right: T) -> T): Buffer = + zipInline(this, other, block) + + override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer { + val operationFunction = elementAlgebra.unaryOperationFunction(operation) + return { arg -> bufferFactory(arg.size) { operationFunction(arg[it]) } } + } + + override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer { + val operationFunction = elementAlgebra.binaryOperationFunction(operation) + return { left, right -> + bufferFactory(left.size) { operationFunction(left[it], right[it]) } + } + } +} + +/** + * Inline map + */ +public inline fun > BufferAlgebra.mapInline( + buffer: Buffer, + crossinline block: A.(T) -> T +): Buffer = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) } + +/** + * Inline map + */ +public inline fun > BufferAlgebra.mapIndexedInline( + buffer: Buffer, + crossinline block: A.(index: Int, arg: T) -> T +): Buffer = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) } + +/** + * Inline zip + */ +public inline fun > BufferAlgebra.zipInline( + l: Buffer, + r: Buffer, + crossinline block: A.(l: T, r: T) -> T +): Buffer { + require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" } + return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) } +} + +public fun BufferAlgebra.buffer(size: Int, initializer: (Int) -> T): Buffer { + return bufferFactory(size, initializer) +} + +public fun A.buffer(initializer: (Int) -> T): Buffer where A : BufferAlgebra, A : WithSize { + return bufferFactory(size, initializer) +} + +public fun > BufferAlgebra.sin(arg: Buffer): Buffer = + mapInline(arg) { sin(it) } + +public fun > BufferAlgebra.cos(arg: Buffer): Buffer = + mapInline(arg) { cos(it) } + +public fun > BufferAlgebra.tan(arg: Buffer): Buffer = + mapInline(arg) { tan(it) } + +public fun > BufferAlgebra.asin(arg: Buffer): Buffer = + mapInline(arg) { asin(it) } + +public fun > BufferAlgebra.acos(arg: Buffer): Buffer = + mapInline(arg) { acos(it) } + +public fun > BufferAlgebra.atan(arg: Buffer): Buffer = + mapInline(arg) { atan(it) } + +public fun > BufferAlgebra.exp(arg: Buffer): Buffer = + mapInline(arg) { exp(it) } + +public fun > BufferAlgebra.ln(arg: Buffer): Buffer = + mapInline(arg) { ln(it) } + +public fun > BufferAlgebra.sinh(arg: Buffer): Buffer = + mapInline(arg) { sinh(it) } + +public fun > BufferAlgebra.cosh(arg: Buffer): Buffer = + mapInline(arg) { cosh(it) } + +public fun > BufferAlgebra.tanh(arg: Buffer): Buffer = + mapInline(arg) { tanh(it) } + +public fun > BufferAlgebra.asinh(arg: Buffer): Buffer = + mapInline(arg) { asinh(it) } + +public fun > BufferAlgebra.acosh(arg: Buffer): Buffer = + mapInline(arg) { acosh(it) } + +public fun > BufferAlgebra.atanh(arg: Buffer): Buffer = + mapInline(arg) { atanh(it) } + +public fun > BufferAlgebra.pow(arg: Buffer, pow: Number): Buffer = + mapInline(arg) { power(it, pow) } + + +public open class BufferRingOps>( + override val elementAlgebra: A, + override val bufferFactory: BufferFactory, +) : BufferAlgebra, RingOps>{ + + override fun add(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l + r } + override fun multiply(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l * r } + override fun Buffer.unaryMinus(): Buffer = map { -it } + + override fun unaryOperationFunction(operation: String): (arg: Buffer) -> Buffer = + super.unaryOperationFunction(operation) + + override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = + super.binaryOperationFunction(operation) +} + +public val ShortRing.bufferAlgebra: BufferRingOps + get() = BufferRingOps(ShortRing, ::ShortBuffer) + +public open class BufferFieldOps>( + elementAlgebra: A, + bufferFactory: BufferFactory, +) : BufferRingOps(elementAlgebra, bufferFactory), BufferAlgebra, FieldOps>, ScaleOperations> { + + override fun add(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l + r } + override fun multiply(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l * r } + override fun divide(left: Buffer, right: Buffer): Buffer = zipInline(left, right) { l, r -> l / r } + + override fun scale(a: Buffer, value: Double): Buffer = a.map { scale(it, value) } + override fun Buffer.unaryMinus(): Buffer = map { -it } + + override fun binaryOperationFunction(operation: String): (left: Buffer, right: Buffer) -> Buffer = + super.binaryOperationFunction(operation) +} + +public class BufferField>( + elementAlgebra: A, + bufferFactory: BufferFactory, + override val size: Int +) : BufferFieldOps(elementAlgebra, bufferFactory), Field>, WithSize { + + override val zero: Buffer = bufferFactory(size) { elementAlgebra.zero } + override val one: Buffer = bufferFactory(size) { elementAlgebra.one } +} + +/** + * Generate full buffer field from given buffer operations + */ +public fun > BufferFieldOps.withSize(size: Int): BufferField = + BufferField(elementAlgebra, bufferFactory, size) + +//Double buffer specialization + +public fun BufferField.buffer(vararg elements: Number): Buffer { + require(elements.size == size) { "Expected $size elements but found ${elements.size}" } + return bufferFactory(size) { elements[it].toDouble() } +} + +public fun > A.bufferAlgebra(bufferFactory: BufferFactory): BufferFieldOps = + BufferFieldOps(this, bufferFactory) + +public val DoubleField.bufferAlgebra: BufferFieldOps + get() = BufferFieldOps(DoubleField, ::DoubleBuffer) + diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt new file mode 100644 index 000000000..060ea5a7e --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferField.kt @@ -0,0 +1,130 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.operations + +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.DoubleBuffer + +/** + * [ExtendedField] over [DoubleBuffer]. + * + * @property size the size of buffers to operate on. + */ +public class DoubleBufferField(public val size: Int) : ExtendedField>, DoubleBufferOps() { + override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } + override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } + + override fun sinh(arg: Buffer): DoubleBuffer = super.sinh(arg) + + override fun cosh(arg: Buffer): DoubleBuffer = super.cosh(arg) + + override fun tanh(arg: Buffer): DoubleBuffer = super.tanh(arg) + + override fun asinh(arg: Buffer): DoubleBuffer = super.asinh(arg) + + override fun acosh(arg: Buffer): DoubleBuffer = super.acosh(arg) + + override fun atanh(arg: Buffer): DoubleBuffer= super.atanh(arg) + + // override fun number(value: Number): Buffer = DoubleBuffer(size) { value.toDouble() } +// +// override fun Buffer.unaryMinus(): Buffer = DoubleBufferOperations.run { +// -this@unaryMinus +// } +// +// override fun add(a: Buffer, b: Buffer): DoubleBuffer { +// require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } +// return DoubleBufferOperations.add(a, b) +// } +// + +// +// override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { +// require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } +// return DoubleBufferOperations.multiply(a, b) +// } +// +// override fun divide(a: Buffer, b: Buffer): DoubleBuffer { +// require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } +// return DoubleBufferOperations.divide(a, b) +// } +// +// override fun sin(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.sin(arg) +// } +// +// override fun cos(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.cos(arg) +// } +// +// override fun tan(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.tan(arg) +// } +// +// override fun asin(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.asin(arg) +// } +// +// override fun acos(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.acos(arg) +// } +// +// override fun atan(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.atan(arg) +// } +// +// override fun sinh(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.sinh(arg) +// } +// +// override fun cosh(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.cosh(arg) +// } +// +// override fun tanh(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.tanh(arg) +// } +// +// override fun asinh(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.asinh(arg) +// } +// +// override fun acosh(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.acosh(arg) +// } +// +// override fun atanh(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.atanh(arg) +// } +// +// override fun power(arg: Buffer, pow: Number): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.power(arg, pow) +// } +// +// override fun exp(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.exp(arg) +// } +// +// override fun ln(arg: Buffer): DoubleBuffer { +// require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } +// return DoubleBufferOperations.ln(arg) +// } + +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt new file mode 100644 index 000000000..3d51b3d32 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt @@ -0,0 +1,195 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.operations + +import space.kscience.kmath.linear.Point +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.DoubleBuffer + +import kotlin.math.* + +/** + * [ExtendedFieldOps] over [DoubleBuffer]. + */ +public abstract class DoubleBufferOps : ExtendedFieldOps>, Norm, Double> { + + override fun Buffer.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) { + DoubleBuffer(size) { -array[it] } + } else { + DoubleBuffer(size) { -get(it) } + } + + override fun add(left: Buffer, right: Buffer): DoubleBuffer { + require(right.size == left.size) { + "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " + } + + return if (left is DoubleBuffer && right is DoubleBuffer) { + val aArray = left.array + val bArray = right.array + DoubleBuffer(DoubleArray(left.size) { aArray[it] + bArray[it] }) + } else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] }) + } + + override fun Buffer.plus(arg: Buffer): DoubleBuffer = add(this, arg) + + override fun Buffer.minus(arg: Buffer): DoubleBuffer { + require(arg.size == this.size) { + "The size of the first buffer ${this.size} should be the same as for second one: ${arg.size} " + } + + return if (this is DoubleBuffer && arg is DoubleBuffer) { + val aArray = this.array + val bArray = arg.array + DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] }) + } else DoubleBuffer(DoubleArray(this.size) { this[it] - arg[it] }) + } + + // +// override fun multiply(a: Buffer, k: Number): RealBuffer { +// val kValue = k.toDouble() +// +// return if (a is RealBuffer) { +// val aArray = a.array +// RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) +// } else RealBuffer(DoubleArray(a.size) { a[it] * kValue }) +// } +// +// override fun divide(a: Buffer, k: Number): RealBuffer { +// val kValue = k.toDouble() +// +// return if (a is RealBuffer) { +// val aArray = a.array +// RealBuffer(DoubleArray(a.size) { aArray[it] / kValue }) +// } else RealBuffer(DoubleArray(a.size) { a[it] / kValue }) +// } + + override fun multiply(left: Buffer, right: Buffer): DoubleBuffer { + require(right.size == left.size) { + "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " + } + + return if (left is DoubleBuffer && right is DoubleBuffer) { + val aArray = left.array + val bArray = right.array + DoubleBuffer(DoubleArray(left.size) { aArray[it] * bArray[it] }) + } else + DoubleBuffer(DoubleArray(left.size) { left[it] * right[it] }) + } + + override fun divide(left: Buffer, right: Buffer): DoubleBuffer { + require(right.size == left.size) { + "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} " + } + + return if (left is DoubleBuffer && right is DoubleBuffer) { + val aArray = left.array + val bArray = right.array + DoubleBuffer(DoubleArray(left.size) { aArray[it] / bArray[it] }) + } else DoubleBuffer(DoubleArray(left.size) { left[it] / right[it] }) + } + + override fun sin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) + } else DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) + + override fun cos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) + } else DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + + override fun tan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { tan(array[it]) }) + } else DoubleBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + + override fun asin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { asin(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { asin(arg[it]) }) + + override fun acos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { acos(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { acos(arg[it]) }) + + override fun atan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { atan(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { atan(arg[it]) }) + + override fun sinh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { sinh(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) + + override fun cosh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { cosh(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) + + override fun tanh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { tanh(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) + + override fun asinh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { asinh(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) + + override fun acosh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { acosh(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) + + override fun atanh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { atanh(array[it]) }) + } else + DoubleBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) + + override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + } else + DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + + override fun exp(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) + } else DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + + override fun ln(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) + } + + override fun norm(arg: Buffer): Double = DoubleL2Norm.norm(arg) + + override fun scale(a: Buffer, value: Double): DoubleBuffer = if (a is DoubleBuffer) { + val aArray = a.array + DoubleBuffer(DoubleArray(a.size) { aArray[it] * value }) + } else DoubleBuffer(DoubleArray(a.size) { a[it] * value }) + + public companion object : DoubleBufferOps() +} + +public object DoubleL2Norm : Norm, Double> { + override fun norm(arg: Point): Double = sqrt(arg.fold(0.0) { acc: Double, d: Double -> acc + d.pow(2) }) +} + diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt index 9037525e1..d50f1e79e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/LogicAlgebra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt index deeb07e0e..9d9fc0885 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations @@ -26,11 +26,11 @@ public interface NumericAlgebra : Algebra { /** * Dynamically dispatches a binary operation with the certain name with numeric first argument. * - * This function must follow two properties: + * Implementations must fulfil the following requirements: * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with the other [leftSideNumberOperation] overload: - * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`. + * 1. If operation is not defined in the structure, then function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [leftSideNumberOperation]: for any `a`, `b`, and `c`, + * `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`. * * @param operation the name of operation. * @return an operation. @@ -41,11 +41,11 @@ public interface NumericAlgebra : Algebra { /** * Dynamically invokes a binary operation with the certain name with numeric first argument. * - * This function must follow two properties: + * Implementations must fulfil the following requirements: * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with second [leftSideNumberOperation] overload: - * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * 1. If operation is not defined in the structure, then the function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [leftSideNumberOperation]: for any `a`, `b`, and `c`, + * `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. * * @param operation the name of operation. * @param left the first argument of operation. @@ -58,11 +58,11 @@ public interface NumericAlgebra : Algebra { /** * Dynamically dispatches a binary operation with the certain name with numeric first argument. * - * This function must follow two properties: + * Implementations must fulfil the following requirements: * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: - * i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * 1. If operation is not defined in the structure, then the function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [rightSideNumberOperation]: for any `a`, `b`, and `c`, + * `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. * * @param operation the name of operation. * @return an operation. @@ -73,11 +73,11 @@ public interface NumericAlgebra : Algebra { /** * Dynamically invokes a binary operation with the certain name with numeric second argument. * - * This function must follow two properties: + * Implementations must fulfil the following requirements: * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: - * i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`. + * 1. If operation is not defined in the structure, then the function throws [kotlin.IllegalStateException]. + * 1. Equivalence to [rightSideNumberOperationFunction]: for any `a`, `b`, and `c`, + * `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`. * * @param operation the name of operation. * @param left the first argument of operation. @@ -87,7 +87,7 @@ public interface NumericAlgebra : Algebra { public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = rightSideNumberOperationFunction(operation)(left, right) - public override fun bindSymbolOrNull(value: String): T? = when (value) { + override fun bindSymbolOrNull(value: String): T? = when (value) { "pi" -> number(PI) "e" -> number(E) else -> super.bindSymbolOrNull(value) @@ -139,10 +139,10 @@ public interface ScaleOperations : Algebra { * Multiplication of this number by element. * * @receiver the multiplier. - * @param b the multiplicand. + * @param arg the multiplicand. * @return the product. */ - public operator fun Number.times(b: T): T = b * this + public operator fun Number.times(arg: T): T = arg * this } /** @@ -150,38 +150,38 @@ public interface ScaleOperations : Algebra { * TODO to be removed and replaced by extensions after multiple receivers are there */ @UnstableKMathAPI -public interface NumbersAddOperations : Ring, NumericAlgebra { +public interface NumbersAddOps : RingOps, NumericAlgebra { /** * Addition of element and scalar. * * @receiver the augend. - * @param b the addend. + * @param other the addend. */ - public operator fun T.plus(b: Number): T = this + number(b) + public operator fun T.plus(other: Number): T = this + number(other) /** * Addition of scalar and element. * * @receiver the augend. - * @param b the addend. + * @param other the addend. */ - public operator fun Number.plus(b: T): T = b + this + public operator fun Number.plus(other: T): T = other + this /** * Subtraction of element from number. * * @receiver the minuend. - * @param b the subtrahend. + * @param other the subtrahend. * @receiver the difference. */ - public operator fun T.minus(b: Number): T = this - number(b) + public operator fun T.minus(other: Number): T = this - number(other) /** * Subtraction of number from element. * * @receiver the minuend. - * @param b the subtrahend. + * @param other the subtrahend. * @receiver the difference. */ - public operator fun Number.minus(b: T): T = -b + this + public operator fun Number.minus(other: T): T = -other + this } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt index 86365394f..d32e03533 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/OptionalOperations.kt @@ -1,12 +1,10 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations -import space.kscience.kmath.misc.UnstableKMathAPI - /** * A container for trigonometric operations for specific type. * @@ -76,48 +74,6 @@ public interface TrigonometricOperations : Algebra { } } -/** - * Computes the sine of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> sin(arg: T): T = arg.context.sin(arg) - -/** - * Computes the cosine of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> cos(arg: T): T = arg.context.cos(arg) - -/** - * Computes the tangent of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> tan(arg: T): T = arg.context.tan(arg) - -/** - * Computes the inverse sine of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> asin(arg: T): T = arg.context.asin(arg) - -/** - * Computes the inverse cosine of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> acos(arg: T): T = arg.context.acos(arg) - -/** - * Computes the inverse tangent of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> atan(arg: T): T = arg.context.atan(arg) - /** * A context extension to include power operations based on exponentiation. * @@ -152,31 +108,6 @@ public interface PowerOperations : Algebra { } } -/** - * Raises this element to the power [power]. - * - * @receiver the base. - * @param power the exponent. - * @return the base raised to the power. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public infix fun >> T.pow(power: Double): T = context.power(this, power) - -/** - * Computes the square root of the value [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> sqrt(arg: T): T = arg pow 0.5 - -/** - * Computes the square of the value [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> sqr(arg: T): T = arg pow 2.0 - /** * A container for operations related to `exp` and `ln` functions. * @@ -266,62 +197,6 @@ public interface ExponentialOperations : Algebra { } } -/** - * The identifier of exponential function. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> exp(arg: T): T = arg.context.exp(arg) - -/** - * The identifier of natural logarithm. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> ln(arg: T): T = arg.context.ln(arg) - - -/** - * Computes the hyperbolic sine of [arg]. - */ -@UnstableKMathAPI -public fun >> sinh(arg: T): T = arg.context.sinh(arg) - -/** - * Computes the hyperbolic cosine of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> cosh(arg: T): T = arg.context.cosh(arg) - -/** - * Computes the hyperbolic tangent of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> tanh(arg: T): T = arg.context.tanh(arg) - -/** - * Computes the inverse hyperbolic sine of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> asinh(arg: T): T = arg.context.asinh(arg) - -/** - * Computes the inverse hyperbolic cosine of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> acosh(arg: T): T = arg.context.acosh(arg) - -/** - * Computes the inverse hyperbolic tangent of [arg]. - */ -@UnstableKMathAPI -@Deprecated("AlgebraElements are considered odd and will be removed in future releases.") -public fun >> atanh(arg: T): T = arg.context.atanh(arg) - /** * A container for norm functional on element. * @@ -330,13 +205,8 @@ public fun >> atanh(arg: T): */ public interface Norm { /** - * Computes the norm of [arg] (i.e. absolute value or vector length). + * Computes the norm of [arg] (i.e., absolute value or vector length). */ public fun norm(arg: T): R } -/** - * Computes the norm of [arg] (i.e. absolute value or vector length). - */ -@UnstableKMathAPI -public fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt index d52be943a..b26ebb2ea 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/algebraExtensions.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/bufferOperation.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/bufferOperation.kt similarity index 94% rename from kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/bufferOperation.kt rename to kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/bufferOperation.kt index 1b89e7838..6bf3266e3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/bufferOperation.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/bufferOperation.kt @@ -1,11 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ -package space.kscience.kmath.structures +package space.kscience.kmath.operations import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.structures.* /** * Typealias for buffer transformations. @@ -67,9 +68,9 @@ public inline fun Buffer.map(block: (T) -> R): Buf * Create a new buffer from this one with the given mapping function. * Provided [bufferFactory] is used to construct the new buffer. */ -public fun Buffer.map( +public inline fun Buffer.map( bufferFactory: BufferFactory, - block: (T) -> R, + crossinline block: (T) -> R, ): Buffer = bufferFactory(size) { block(get(it)) } /** diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index 36c13d6ec..ceb85f3ab 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations @@ -10,15 +10,16 @@ import kotlin.math.pow as kpow /** * Advanced Number-like semifield that implements basic operations. */ -public interface ExtendedFieldOperations : - FieldOperations, +public interface ExtendedFieldOps : + FieldOps, TrigonometricOperations, PowerOperations, - ExponentialOperations { - public override fun tan(arg: T): T = sin(arg) / cos(arg) - public override fun tanh(arg: T): T = sinh(arg) / cosh(arg) + ExponentialOperations, + ScaleOperations { + override fun tan(arg: T): T = sin(arg) / cos(arg) + override fun tanh(arg: T): T = sinh(arg) / cosh(arg) - public override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { + override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { TrigonometricOperations.COS_OPERATION -> ::cos TrigonometricOperations.SIN_OPERATION -> ::sin TrigonometricOperations.TAN_OPERATION -> ::tan @@ -34,22 +35,22 @@ public interface ExtendedFieldOperations : ExponentialOperations.ACOSH_OPERATION -> ::acosh ExponentialOperations.ASINH_OPERATION -> ::asinh ExponentialOperations.ATANH_OPERATION -> ::atanh - else -> super.unaryOperationFunction(operation) + else -> super.unaryOperationFunction(operation) } } /** * Advanced Number-like field that implements basic operations. */ -public interface ExtendedField : ExtendedFieldOperations, Field, NumericAlgebra, ScaleOperations { - public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 - public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 - public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) - public override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg) - public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) - public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2.0 +public interface ExtendedField : ExtendedFieldOps, Field, NumericAlgebra{ + override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 + override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 + override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) + override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg) + override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) + override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2.0 - public override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = + override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = when (operation) { PowerOperations.POW_OPERATION -> ::power else -> super.rightSideNumberOperationFunction(operation) @@ -61,187 +62,199 @@ public interface ExtendedField : ExtendedFieldOperations, Field, Numeri */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object DoubleField : ExtendedField, Norm, ScaleOperations { - public override inline val zero: Double get() = 0.0 - public override inline val one: Double get() = 1.0 + override inline val zero: Double get() = 0.0 + override inline val one: Double get() = 1.0 - public override inline fun number(value: Number): Double = value.toDouble() + override inline fun number(value: Number): Double = value.toDouble() - public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = + override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { PowerOperations.POW_OPERATION -> ::power else -> super.binaryOperationFunction(operation) } - public override inline fun add(a: Double, b: Double): Double = a + b + override inline fun add(left: Double, right: Double): Double = left + right - public override inline fun multiply(a: Double, b: Double): Double = a * b - public override inline fun divide(a: Double, b: Double): Double = a / b + override inline fun multiply(left: Double, right: Double): Double = left * right + override inline fun divide(left: Double, right: Double): Double = left / right - public override inline fun scale(a: Double, value: Double): Double = a * value + override inline fun scale(a: Double, value: Double): Double = a * value - public override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) - public override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) - public override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) - public override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) - public override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) - public override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) + override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) + override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) + override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) + override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) + override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) + override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) - public override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg) - public override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg) - public override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg) - public override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg) - public override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) - public override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) + override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg) + override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg) + override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg) + override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg) + override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) + override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) - public override inline fun sqrt(arg: Double): Double = kotlin.math.sqrt(arg) - public override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) - public override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) - public override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) + override inline fun sqrt(arg: Double): Double = kotlin.math.sqrt(arg) + override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) + override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) + override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) - public override inline fun norm(arg: Double): Double = abs(arg) + override inline fun norm(arg: Double): Double = abs(arg) - public override inline fun Double.unaryMinus(): Double = -this - public override inline fun Double.plus(b: Double): Double = this + b - public override inline fun Double.minus(b: Double): Double = this - b - public override inline fun Double.times(b: Double): Double = this * b - public override inline fun Double.div(b: Double): Double = this / b + override inline fun Double.unaryMinus(): Double = -this + override inline fun Double.plus(arg: Double): Double = this + arg + override inline fun Double.minus(arg: Double): Double = this - arg + override inline fun Double.times(arg: Double): Double = this * arg + override inline fun Double.div(arg: Double): Double = this / arg } +public val Double.Companion.algebra: DoubleField get() = DoubleField + /** * A field for [Float] without boxing. Does not produce appropriate field element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object FloatField : ExtendedField, Norm { - public override inline val zero: Float get() = 0.0f - public override inline val one: Float get() = 1.0f + override inline val zero: Float get() = 0.0f + override inline val one: Float get() = 1.0f - public override fun number(value: Number): Float = value.toFloat() + override fun number(value: Number): Float = value.toFloat() - public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = + override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = when (operation) { PowerOperations.POW_OPERATION -> ::power else -> super.binaryOperationFunction(operation) } - public override inline fun add(a: Float, b: Float): Float = a + b - public override fun scale(a: Float, value: Double): Float = a * value.toFloat() + override inline fun add(left: Float, right: Float): Float = left + right + override fun scale(a: Float, value: Double): Float = a * value.toFloat() - public override inline fun multiply(a: Float, b: Float): Float = a * b + override inline fun multiply(left: Float, right: Float): Float = left * right - public override inline fun divide(a: Float, b: Float): Float = a / b + override inline fun divide(left: Float, right: Float): Float = left / right - public override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) - public override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) - public override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) - public override inline fun acos(arg: Float): Float = kotlin.math.acos(arg) - public override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) - public override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) + override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) + override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) + override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) + override inline fun acos(arg: Float): Float = kotlin.math.acos(arg) + override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) + override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) - public override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg) - public override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg) - public override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg) - public override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg) - public override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg) - public override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg) + override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg) + override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg) + override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg) + override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg) + override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg) + override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg) - public override inline fun sqrt(arg: Float): Float = kotlin.math.sqrt(arg) - public override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) - public override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) - public override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) + override inline fun sqrt(arg: Float): Float = kotlin.math.sqrt(arg) + override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) + override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) + override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) - public override inline fun norm(arg: Float): Float = abs(arg) + override inline fun norm(arg: Float): Float = abs(arg) - public override inline fun Float.unaryMinus(): Float = -this - public override inline fun Float.plus(b: Float): Float = this + b - public override inline fun Float.minus(b: Float): Float = this - b - public override inline fun Float.times(b: Float): Float = this * b - public override inline fun Float.div(b: Float): Float = this / b + override inline fun Float.unaryMinus(): Float = -this + override inline fun Float.plus(arg: Float): Float = this + arg + override inline fun Float.minus(arg: Float): Float = this - arg + override inline fun Float.times(arg: Float): Float = this * arg + override inline fun Float.div(arg: Float): Float = this / arg } +public val Float.Companion.algebra: FloatField get() = FloatField + /** * A field for [Int] without boxing. Does not produce corresponding ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object IntRing : Ring, Norm, NumericAlgebra { - public override inline val zero: Int + override inline val zero: Int get() = 0 - public override inline val one: Int + override inline val one: Int get() = 1 - public override fun number(value: Number): Int = value.toInt() - public override inline fun add(a: Int, b: Int): Int = a + b - public override inline fun multiply(a: Int, b: Int): Int = a * b - public override inline fun norm(arg: Int): Int = abs(arg) + override fun number(value: Number): Int = value.toInt() + override inline fun add(left: Int, right: Int): Int = left + right + override inline fun multiply(left: Int, right: Int): Int = left * right + override inline fun norm(arg: Int): Int = abs(arg) - public override inline fun Int.unaryMinus(): Int = -this - public override inline fun Int.plus(b: Int): Int = this + b - public override inline fun Int.minus(b: Int): Int = this - b - public override inline fun Int.times(b: Int): Int = this * b + override inline fun Int.unaryMinus(): Int = -this + override inline fun Int.plus(arg: Int): Int = this + arg + override inline fun Int.minus(arg: Int): Int = this - arg + override inline fun Int.times(arg: Int): Int = this * arg } +public val Int.Companion.algebra: IntRing get() = IntRing + /** * A field for [Short] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object ShortRing : Ring, Norm, NumericAlgebra { - public override inline val zero: Short + override inline val zero: Short get() = 0 - public override inline val one: Short + override inline val one: Short get() = 1 - public override fun number(value: Number): Short = value.toShort() - public override inline fun add(a: Short, b: Short): Short = (a + b).toShort() - public override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() - public override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() + override fun number(value: Number): Short = value.toShort() + override inline fun add(left: Short, right: Short): Short = (left + right).toShort() + override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort() + override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() - public override inline fun Short.unaryMinus(): Short = (-this).toShort() - public override inline fun Short.plus(b: Short): Short = (this + b).toShort() - public override inline fun Short.minus(b: Short): Short = (this - b).toShort() - public override inline fun Short.times(b: Short): Short = (this * b).toShort() + override inline fun Short.unaryMinus(): Short = (-this).toShort() + override inline fun Short.plus(arg: Short): Short = (this + arg).toShort() + override inline fun Short.minus(arg: Short): Short = (this - arg).toShort() + override inline fun Short.times(arg: Short): Short = (this * arg).toShort() } +public val Short.Companion.algebra: ShortRing get() = ShortRing + /** * A field for [Byte] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object ByteRing : Ring, Norm, NumericAlgebra { - public override inline val zero: Byte + override inline val zero: Byte get() = 0 - public override inline val one: Byte + override inline val one: Byte get() = 1 - public override fun number(value: Number): Byte = value.toByte() - public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() - public override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() - public override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() + override fun number(value: Number): Byte = value.toByte() + override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte() + override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte() + override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() - public override inline fun Byte.unaryMinus(): Byte = (-this).toByte() - public override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() - public override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() - public override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() + override inline fun Byte.unaryMinus(): Byte = (-this).toByte() + override inline fun Byte.plus(arg: Byte): Byte = (this + arg).toByte() + override inline fun Byte.minus(arg: Byte): Byte = (this - arg).toByte() + override inline fun Byte.times(arg: Byte): Byte = (this * arg).toByte() } +public val Byte.Companion.algebra: ByteRing get() = ByteRing + /** * A field for [Double] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object LongRing : Ring, Norm, NumericAlgebra { - public override inline val zero: Long + override inline val zero: Long get() = 0L - public override inline val one: Long + override inline val one: Long get() = 1L - public override fun number(value: Number): Long = value.toLong() - public override inline fun add(a: Long, b: Long): Long = a + b - public override inline fun multiply(a: Long, b: Long): Long = a * b - public override fun norm(arg: Long): Long = abs(arg) + override fun number(value: Number): Long = value.toLong() + override inline fun add(left: Long, right: Long): Long = left + right + override inline fun multiply(left: Long, right: Long): Long = left * right + override fun norm(arg: Long): Long = abs(arg) - public override inline fun Long.unaryMinus(): Long = (-this) - public override inline fun Long.plus(b: Long): Long = (this + b) - public override inline fun Long.minus(b: Long): Long = (this - b) - public override inline fun Long.times(b: Long): Long = (this * b) + override inline fun Long.unaryMinus(): Long = (-this) + override inline fun Long.plus(arg: Long): Long = (this + arg) + override inline fun Long.minus(arg: Long): Long = (this - arg) + override inline fun Long.times(arg: Long): Long = (this * arg) } + +public val Long.Companion.algebra: LongRing get() = LongRing diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ArrayBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ArrayBuffer.kt new file mode 100644 index 000000000..393ee99d6 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ArrayBuffer.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.structures + +/** + * [MutableBuffer] implementation over [Array]. + * + * @param T the type of elements contained in the buffer. + * @property array The underlying array. + */ +public class ArrayBuffer(internal val array: Array) : MutableBuffer { + // Can't inline because array is invariant + override val size: Int get() = array.size + + override operator fun get(index: Int): T = array[index] + + override operator fun set(index: Int, value: T) { + array[index] = value + } + + override operator fun iterator(): Iterator = array.iterator() + override fun copy(): MutableBuffer = ArrayBuffer(array.copyOf()) + + override fun toString(): String = Buffer.toString(this) +} + + +/** + * Returns an [ArrayBuffer] that wraps the original array. + */ +public fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt index 82f17b807..c68bca2d9 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt @@ -1,10 +1,11 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures +import space.kscience.kmath.operations.asSequence import kotlin.jvm.JvmInline import kotlin.reflect.KClass @@ -45,7 +46,13 @@ public interface Buffer { */ public operator fun iterator(): Iterator + override fun toString(): String + public companion object { + + public fun toString(buffer: Buffer<*>): String = + buffer.asSequence().joinToString(prefix = "[", separator = ", ", postfix = "]") + /** * Check the element-by-element match of content of two buffers. */ @@ -98,170 +105,6 @@ public interface Buffer { */ public val Buffer<*>.indices: IntRange get() = 0 until size -/** - * A generic mutable random-access structure for both primitives and objects. - * - * @param T the type of elements contained in the buffer. - */ -public interface MutableBuffer : Buffer { - /** - * Sets the array element at the specified [index] to the specified [value]. - */ - public operator fun set(index: Int, value: T) - - /** - * Returns a shallow copy of the buffer. - */ - public fun copy(): MutableBuffer - - public companion object { - /** - * Creates a [DoubleBuffer] with the specified [size], where each element is calculated by calling the specified - * [initializer] function. - */ - public inline fun double(size: Int, initializer: (Int) -> Double): DoubleBuffer = - DoubleBuffer(size, initializer) - - /** - * Creates a [ShortBuffer] with the specified [size], where each element is calculated by calling the specified - * [initializer] function. - */ - public inline fun short(size: Int, initializer: (Int) -> Short): ShortBuffer = - ShortBuffer(size, initializer) - - /** - * Creates a [IntBuffer] with the specified [size], where each element is calculated by calling the specified - * [initializer] function. - */ - public inline fun int(size: Int, initializer: (Int) -> Int): IntBuffer = - IntBuffer(size, initializer) - - /** - * Creates a [LongBuffer] with the specified [size], where each element is calculated by calling the specified - * [initializer] function. - */ - public inline fun long(size: Int, initializer: (Int) -> Long): LongBuffer = - LongBuffer(size, initializer) - - - /** - * Creates a [FloatBuffer] with the specified [size], where each element is calculated by calling the specified - * [initializer] function. - */ - public inline fun float(size: Int, initializer: (Int) -> Float): FloatBuffer = - FloatBuffer(size, initializer) - - - /** - * Create a boxing mutable buffer of given type - */ - public inline fun boxing(size: Int, initializer: (Int) -> T): MutableBuffer = - MutableListBuffer(MutableList(size, initializer)) - - /** - * Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used - * ([IntBuffer], [DoubleBuffer], etc.), [ListBuffer] is returned otherwise. - * - * The [size] is specified, and each element is calculated by calling the specified [initializer] function. - */ - @Suppress("UNCHECKED_CAST") - public inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer = - when (type) { - Double::class -> double(size) { initializer(it) as Double } as MutableBuffer - Short::class -> short(size) { initializer(it) as Short } as MutableBuffer - Int::class -> int(size) { initializer(it) as Int } as MutableBuffer - Float::class -> float(size) { initializer(it) as Float } as MutableBuffer - Long::class -> long(size) { initializer(it) as Long } as MutableBuffer - else -> boxing(size, initializer) - } - - /** - * Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used - * ([IntBuffer], [DoubleBuffer], etc.), [ListBuffer] is returned otherwise. - * - * The [size] is specified, and each element is calculated by calling the specified [initializer] function. - */ - @Suppress("UNCHECKED_CAST") - public inline fun auto(size: Int, initializer: (Int) -> T): MutableBuffer = - auto(T::class, size, initializer) - } -} - -/** - * [Buffer] implementation over [List]. - * - * @param T the type of elements contained in the buffer. - * @property list The underlying list. - */ -public class ListBuffer(public val list: List) : Buffer { - - public constructor(size: Int, initializer: (Int) -> T) : this(List(size, initializer)) - - override val size: Int get() = list.size - - override operator fun get(index: Int): T = list[index] - override operator fun iterator(): Iterator = list.iterator() -} - -/** - * Returns an [ListBuffer] that wraps the original list. - */ -public fun List.asBuffer(): ListBuffer = ListBuffer(this) - -/** - * [MutableBuffer] implementation over [MutableList]. - * - * @param T the type of elements contained in the buffer. - * @property list The underlying list. - */ -@JvmInline -public value class MutableListBuffer(public val list: MutableList) : MutableBuffer { - - public constructor(size: Int, initializer: (Int) -> T) : this(MutableList(size, initializer)) - - override val size: Int get() = list.size - - override operator fun get(index: Int): T = list[index] - - override operator fun set(index: Int, value: T) { - list[index] = value - } - - override operator fun iterator(): Iterator = list.iterator() - override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) -} - -/** - * Returns an [MutableListBuffer] that wraps the original list. - */ -public fun MutableList.asMutableBuffer(): MutableListBuffer = MutableListBuffer(this) - -/** - * [MutableBuffer] implementation over [Array]. - * - * @param T the type of elements contained in the buffer. - * @property array The underlying array. - */ -public class ArrayBuffer(internal val array: Array) : MutableBuffer { - // Can't inline because array is invariant - override val size: Int get() = array.size - - override operator fun get(index: Int): T = array[index] - - override operator fun set(index: Int, value: T) { - array[index] = value - } - - override operator fun iterator(): Iterator = array.iterator() - override fun copy(): MutableBuffer = ArrayBuffer(array.copyOf()) -} - - -/** - * Returns an [ArrayBuffer] that wraps the original array. - */ -public fun Array.asBuffer(): ArrayBuffer = ArrayBuffer(this) - /** * Immutable wrapper for [MutableBuffer]. * @@ -290,6 +133,8 @@ public class VirtualBuffer(override val size: Int, private val generator: } override operator fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() + + override fun toString(): String = Buffer.toString(this) } /** diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt index 352c75956..d6a48f42d 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/BufferAccessor2D.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -13,31 +13,31 @@ import space.kscience.kmath.nd.as2D /** * A context that allows to operate on a [MutableBuffer] as on 2d array */ -internal class BufferAccessor2D( - public val rowNum: Int, - public val colNum: Int, +internal class BufferAccessor2D( + val rowNum: Int, + val colNum: Int, val factory: MutableBufferFactory, ) { - public operator fun Buffer.get(i: Int, j: Int): T = get(i * colNum + j) + operator fun Buffer.get(i: Int, j: Int): T = get(i * colNum + j) - public operator fun MutableBuffer.set(i: Int, j: Int, value: T) { + operator fun MutableBuffer.set(i: Int, j: Int, value: T) { set(i * colNum + j, value) } - public inline fun create(crossinline init: (i: Int, j: Int) -> T): MutableBuffer = + inline fun create(crossinline init: (i: Int, j: Int) -> T): MutableBuffer = factory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } - public fun create(mat: Structure2D): MutableBuffer = create { i, j -> mat[i, j] } + fun create(mat: Structure2D): MutableBuffer = create { i, j -> mat[i, j] } //TODO optimize wrapper - public fun MutableBuffer.collect(): Structure2D = StructureND.buffered( + fun MutableBuffer.collect(): Structure2D = StructureND.buffered( DefaultStrides(intArrayOf(rowNum, colNum)), factory ) { (i, j) -> get(i, j) }.as2D() - public inner class Row(public val buffer: MutableBuffer, public val rowIndex: Int) : MutableBuffer { + inner class Row(val buffer: MutableBuffer, val rowIndex: Int) : MutableBuffer { override val size: Int get() = colNum override operator fun get(index: Int): T = buffer[rowIndex, index] @@ -49,10 +49,12 @@ internal class BufferAccessor2D( override fun copy(): MutableBuffer = factory(colNum) { get(it) } override operator fun iterator(): Iterator = (0 until colNum).map(::get).iterator() + override fun toString(): String = Buffer.toString(this) + } /** * Get row */ - public fun MutableBuffer.row(i: Int): Row = Row(this, i) + fun MutableBuffer.row(i: Int): Row = Row(this, i) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/DoubleBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/DoubleBuffer.kt index b4ef37598..3b554ab07 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/DoubleBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/DoubleBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -25,6 +25,12 @@ public value class DoubleBuffer(public val array: DoubleArray) : MutableBuffer Double): DoubleBuffer = DoubleBuffer(DoubleArray(size) { init(it) }) @@ -47,7 +53,7 @@ public fun DoubleBuffer(vararg doubles: Double): DoubleBuffer = DoubleBuffer(dou public fun DoubleBuffer.contentEquals(vararg doubles: Double): Boolean = array.contentEquals(doubles) /** - * Returns a new [DoubleArray] containing all of the elements of this [Buffer]. + * Returns a new [DoubleArray] containing all the elements of this [Buffer]. */ public fun Buffer.toDoubleArray(): DoubleArray = when (this) { is DoubleBuffer -> array.copyOf() diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/DoubleBufferField.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/DoubleBufferField.kt deleted file mode 100644 index 34b5e373b..000000000 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/DoubleBufferField.kt +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.structures - -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.ExtendedFieldOperations -import kotlin.math.* - -/** - * [ExtendedFieldOperations] over [DoubleBuffer]. - */ -public object DoubleBufferFieldOperations : ExtendedFieldOperations> { - override fun Buffer.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) { - DoubleBuffer(size) { -array[it] } - } else { - DoubleBuffer(size) { -get(it) } - } - - public override fun add(a: Buffer, b: Buffer): DoubleBuffer { - require(b.size == a.size) { - "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " - } - - return if (a is DoubleBuffer && b is DoubleBuffer) { - val aArray = a.array - val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) - } -// -// public override fun multiply(a: Buffer, k: Number): RealBuffer { -// val kValue = k.toDouble() -// -// return if (a is RealBuffer) { -// val aArray = a.array -// RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) -// } else RealBuffer(DoubleArray(a.size) { a[it] * kValue }) -// } -// -// public override fun divide(a: Buffer, k: Number): RealBuffer { -// val kValue = k.toDouble() -// -// return if (a is RealBuffer) { -// val aArray = a.array -// RealBuffer(DoubleArray(a.size) { aArray[it] / kValue }) -// } else RealBuffer(DoubleArray(a.size) { a[it] / kValue }) -// } - - public override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { - require(b.size == a.size) { - "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " - } - - return if (a is DoubleBuffer && b is DoubleBuffer) { - val aArray = a.array - val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) - } else - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) - } - - public override fun divide(a: Buffer, b: Buffer): DoubleBuffer { - require(b.size == a.size) { - "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " - } - - return if (a is DoubleBuffer && b is DoubleBuffer) { - val aArray = a.array - val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) - } - - public override fun sin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - - public override fun cos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) - - public override fun tan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { tan(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { tan(arg[it]) }) - - public override fun asin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { asin(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { asin(arg[it]) }) - - public override fun acos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { acos(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - - public override fun atan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { atan(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - - public override fun sinh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sinh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) - - public override fun cosh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cosh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) - - public override fun tanh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { tanh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) - - public override fun asinh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { asinh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) - - public override fun acosh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { acosh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) - - public override fun atanh(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { atanh(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) - - public override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - - public override fun exp(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - - public override fun ln(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) -} - -/** - * [ExtendedField] over [DoubleBuffer]. - * - * @property size the size of buffers to operate on. - */ -public class DoubleBufferField(public val size: Int) : ExtendedField> { - public override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } - public override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } - - override fun number(value: Number): Buffer = DoubleBuffer(size) { value.toDouble() } - - override fun Buffer.unaryMinus(): Buffer = DoubleBufferFieldOperations.run { - -this@unaryMinus - } - - public override fun add(a: Buffer, b: Buffer): DoubleBuffer { - require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } - return DoubleBufferFieldOperations.add(a, b) - } - - public override fun scale(a: Buffer, value: Double): DoubleBuffer { - require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } - - return if (a is DoubleBuffer) { - val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * value }) - } else DoubleBuffer(DoubleArray(a.size) { a[it] * value }) - } - - public override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { - require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } - return DoubleBufferFieldOperations.multiply(a, b) - } - - public override fun divide(a: Buffer, b: Buffer): DoubleBuffer { - require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } - return DoubleBufferFieldOperations.divide(a, b) - } - - public override fun sin(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.sin(arg) - } - - public override fun cos(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.cos(arg) - } - - public override fun tan(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.tan(arg) - } - - public override fun asin(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.asin(arg) - } - - public override fun acos(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.acos(arg) - } - - public override fun atan(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.atan(arg) - } - - public override fun sinh(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.sinh(arg) - } - - public override fun cosh(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.cosh(arg) - } - - public override fun tanh(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.tanh(arg) - } - - public override fun asinh(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.asinh(arg) - } - - public override fun acosh(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.acosh(arg) - } - - public override fun atanh(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.atanh(arg) - } - - public override fun power(arg: Buffer, pow: Number): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.power(arg, pow) - } - - public override fun exp(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.exp(arg) - } - - public override fun ln(arg: Buffer): DoubleBuffer { - require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } - return DoubleBufferFieldOperations.ln(arg) - } -} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt index 0b16a3afc..700a4f17f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FlaggedBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -51,10 +51,12 @@ public fun FlaggedBuffer<*>.hasFlag(index: Int, flag: ValueFlag): Boolean = (get public fun FlaggedBuffer<*>.isMissing(index: Int): Boolean = hasFlag(index, ValueFlag.MISSING) /** - * A real buffer which supports flags for each value like NaN or Missing + * A [Double] buffer that supports flags for each value like `NaN` or Missing. */ -public class FlaggedDoubleBuffer(public val values: DoubleArray, public val flags: ByteArray) : FlaggedBuffer, - Buffer { +public class FlaggedDoubleBuffer( + public val values: DoubleArray, + public val flags: ByteArray +) : FlaggedBuffer, Buffer { init { require(values.size == flags.size) { "Values and flags must have the same dimensions" } } @@ -68,6 +70,8 @@ public class FlaggedDoubleBuffer(public val values: DoubleArray, public val flag override operator fun iterator(): Iterator = values.indices.asSequence().map { if (isValid(it)) values[it] else null }.iterator() + + override fun toString(): String = Buffer.toString(this) } public inline fun FlaggedDoubleBuffer.forEachValid(block: (Double) -> Unit) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FloatBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FloatBuffer.kt index 58b7c6aea..dc7903cbf 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FloatBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/FloatBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -34,7 +34,7 @@ public value class FloatBuffer(public val array: FloatArray) : MutableBuffer Float): FloatBuffer = FloatBuffer(FloatArray(size) { init(it) }) @@ -44,7 +44,7 @@ public inline fun FloatBuffer(size: Int, init: (Int) -> Float): FloatBuffer = Fl public fun FloatBuffer(vararg floats: Float): FloatBuffer = FloatBuffer(floats) /** - * Returns a new [FloatArray] containing all of the elements of this [Buffer]. + * Returns a new [FloatArray] containing all the elements of this [Buffer]. */ public fun Buffer.toFloatArray(): FloatArray = when (this) { is FloatBuffer -> array.copyOf() diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/IntBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/IntBuffer.kt index 57b6cfde3..ca078746c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/IntBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/IntBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -33,7 +33,7 @@ public value class IntBuffer(public val array: IntArray) : MutableBuffer { * [init] function. * * The function [init] is called for each array element sequentially starting from the first one. - * It should return the value for an buffer element given its index. + * It should return the value for a buffer element given its index. */ public inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffer(IntArray(size) { init(it) }) @@ -43,7 +43,7 @@ public inline fun IntBuffer(size: Int, init: (Int) -> Int): IntBuffer = IntBuffe public fun IntBuffer(vararg ints: Int): IntBuffer = IntBuffer(ints) /** - * Returns a new [IntArray] containing all of the elements of this [Buffer]. + * Returns a new [IntArray] containing all the elements of this [Buffer]. */ public fun Buffer.toIntArray(): IntArray = when (this) { is IntBuffer -> array.copyOf() diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ListBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ListBuffer.kt new file mode 100644 index 000000000..666722177 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ListBuffer.kt @@ -0,0 +1,60 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.structures + +import kotlin.jvm.JvmInline + +/** + * [Buffer] implementation over [List]. + * + * @param T the type of elements contained in the buffer. + * @property list The underlying list. + */ +public class ListBuffer(public val list: List) : Buffer { + + public constructor(size: Int, initializer: (Int) -> T) : this(List(size, initializer)) + + override val size: Int get() = list.size + + override operator fun get(index: Int): T = list[index] + override operator fun iterator(): Iterator = list.iterator() + + override fun toString(): String = Buffer.toString(this) +} + + +/** + * Returns an [ListBuffer] that wraps the original list. + */ +public fun List.asBuffer(): ListBuffer = ListBuffer(this) + +/** + * [MutableBuffer] implementation over [MutableList]. + * + * @param T the type of elements contained in the buffer. + * @property list The underlying list. + */ +@JvmInline +public value class MutableListBuffer(public val list: MutableList) : MutableBuffer { + + public constructor(size: Int, initializer: (Int) -> T) : this(MutableList(size, initializer)) + + override val size: Int get() = list.size + + override operator fun get(index: Int): T = list[index] + + override operator fun set(index: Int, value: T) { + list[index] = value + } + + override operator fun iterator(): Iterator = list.iterator() + override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) +} + +/** + * Returns an [MutableListBuffer] that wraps the original list. + */ +public fun MutableList.asMutableBuffer(): MutableListBuffer = MutableListBuffer(this) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/LongBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/LongBuffer.kt index 57affa1c5..a0b5c78fa 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/LongBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/LongBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -33,7 +33,7 @@ public value class LongBuffer(public val array: LongArray) : MutableBuffer * [init] function. * * The function [init] is called for each array element sequentially starting from the first one. - * It should return the value for an buffer element given its index. + * It should return the value for a buffer element given its index. */ public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongBuffer(LongArray(size) { init(it) }) @@ -43,7 +43,7 @@ public inline fun LongBuffer(size: Int, init: (Int) -> Long): LongBuffer = LongB public fun LongBuffer(vararg longs: Long): LongBuffer = LongBuffer(longs) /** - * Returns a new [LongArray] containing all of the elements of this [Buffer]. + * Returns a new [LongArray] containing all the elements of this [Buffer]. */ public fun Buffer.toLongArray(): LongArray = when (this) { is LongBuffer -> array.copyOf() diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MemoryBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MemoryBuffer.kt index 8c98ab9c8..3e08dbbb1 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MemoryBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MemoryBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -22,6 +22,8 @@ public open class MemoryBuffer(protected val memory: Memory, protected override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index) override operator fun iterator(): Iterator = (0 until size).asSequence().map { get(it) }.iterator() + override fun toString(): String = Buffer.toString(this) + public companion object { public fun create(spec: MemorySpec, size: Int): MemoryBuffer = MemoryBuffer(Memory.allocate(size * spec.objectSize), spec) @@ -48,8 +50,8 @@ public class MutableMemoryBuffer(memory: Memory, spec: MemorySpec) : private val writer: MemoryWriter = memory.writer() - public override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) - public override fun copy(): MutableBuffer = MutableMemoryBuffer(memory.copy(), spec) + override operator fun set(index: Int, value: T): Unit = writer.write(spec, spec.objectSize * index, value) + override fun copy(): MutableBuffer = MutableMemoryBuffer(memory.copy(), spec) public companion object { public fun create(spec: MemorySpec, size: Int): MutableMemoryBuffer = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MutableBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MutableBuffer.kt new file mode 100644 index 000000000..97185b918 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/MutableBuffer.kt @@ -0,0 +1,97 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.structures + +import kotlin.reflect.KClass + +/** + * A generic mutable random-access structure for both primitives and objects. + * + * @param T the type of elements contained in the buffer. + */ +public interface MutableBuffer : Buffer { + /** + * Sets the array element at the specified [index] to the specified [value]. + */ + public operator fun set(index: Int, value: T) + + /** + * Returns a shallow copy of the buffer. + */ + public fun copy(): MutableBuffer + + public companion object { + /** + * Creates a [DoubleBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun double(size: Int, initializer: (Int) -> Double): DoubleBuffer = + DoubleBuffer(size, initializer) + + /** + * Creates a [ShortBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun short(size: Int, initializer: (Int) -> Short): ShortBuffer = + ShortBuffer(size, initializer) + + /** + * Creates a [IntBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun int(size: Int, initializer: (Int) -> Int): IntBuffer = + IntBuffer(size, initializer) + + /** + * Creates a [LongBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun long(size: Int, initializer: (Int) -> Long): LongBuffer = + LongBuffer(size, initializer) + + + /** + * Creates a [FloatBuffer] with the specified [size], where each element is calculated by calling the specified + * [initializer] function. + */ + public inline fun float(size: Int, initializer: (Int) -> Float): FloatBuffer = + FloatBuffer(size, initializer) + + + /** + * Create a boxing mutable buffer of given type + */ + public inline fun boxing(size: Int, initializer: (Int) -> T): MutableBuffer = + MutableListBuffer(MutableList(size, initializer)) + + /** + * Creates a [MutableBuffer] of given [type]. If the type is primitive, specialized buffers are used + * ([IntBuffer], [DoubleBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(type: KClass, size: Int, initializer: (Int) -> T): MutableBuffer = + when (type) { + Double::class -> double(size) { initializer(it) as Double } as MutableBuffer + Short::class -> short(size) { initializer(it) as Short } as MutableBuffer + Int::class -> int(size) { initializer(it) as Int } as MutableBuffer + Float::class -> float(size) { initializer(it) as Float } as MutableBuffer + Long::class -> long(size) { initializer(it) as Long } as MutableBuffer + else -> boxing(size, initializer) + } + + /** + * Creates a [MutableBuffer] of given type [T]. If the type is primitive, specialized buffers are used + * ([IntBuffer], [DoubleBuffer], etc.), [ListBuffer] is returned otherwise. + * + * The [size] is specified, and each element is calculated by calling the specified [initializer] function. + */ + @Suppress("UNCHECKED_CAST") + public inline fun auto(size: Int, initializer: (Int) -> T): MutableBuffer = + auto(T::class, size, initializer) + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ShortBuffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ShortBuffer.kt index 3d4c68b3c..1d2b0188a 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ShortBuffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/ShortBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -14,16 +14,16 @@ import kotlin.jvm.JvmInline */ @JvmInline public value class ShortBuffer(public val array: ShortArray) : MutableBuffer { - public override val size: Int get() = array.size + override val size: Int get() = array.size - public override operator fun get(index: Int): Short = array[index] + override operator fun get(index: Int): Short = array[index] - public override operator fun set(index: Int, value: Short) { + override operator fun set(index: Int, value: Short) { array[index] = value } - public override operator fun iterator(): ShortIterator = array.iterator() - public override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) + override operator fun iterator(): ShortIterator = array.iterator() + override fun copy(): MutableBuffer = ShortBuffer(array.copyOf()) } /** @@ -31,7 +31,7 @@ public value class ShortBuffer(public val array: ShortArray) : MutableBuffer Short): ShortBuffer = ShortBuffer(ShortArray(size) { init(it) }) @@ -41,7 +41,7 @@ public inline fun ShortBuffer(size: Int, init: (Int) -> Short): ShortBuffer = Sh public fun ShortBuffer(vararg shorts: Short): ShortBuffer = ShortBuffer(shorts) /** - * Returns a new [ShortArray] containing all of the elements of this [Buffer]. + * Returns a new [ShortArray] containing all the elements of this [Buffer]. */ public fun Buffer.toShortArray(): ShortArray = when (this) { is ShortBuffer -> array.copyOf() diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt index 4d1b00b3d..d0b3c7751 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/ExpressionFieldTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions @@ -16,7 +16,7 @@ class ExpressionFieldTest { @Test fun testExpression() { val expression = with(FunctionalExpressionField(DoubleField)) { - val x by binding() + val x by binding x * x + 2 * x + one } @@ -27,7 +27,7 @@ class ExpressionFieldTest { @Test fun separateContext() { fun FunctionalExpressionField.expression(): Expression { - val x by binding() + val x by binding return x * x + 2 * x + one } @@ -38,7 +38,7 @@ class ExpressionFieldTest { @Test fun valueExpression() { val expressionBuilder: FunctionalExpressionField.() -> Expression = { - val x by binding() + val x by binding x * x + 2 * x + one } diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt index 156334b2e..8bf852653 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/InterpretTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt index 201890933..7d8ff6202 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.expressions diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt index 2d2a0952b..79153d95d 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/DoubleLUSolverTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.linear @@ -8,6 +8,7 @@ package space.kscience.kmath.linear import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.algebra import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -21,39 +22,37 @@ fun assertMatrixEquals(expected: StructureND, actual: StructureND = VirtualMatrix(6, 6) { row, col -> when { col == 0 -> .50 @@ -48,7 +50,7 @@ class MatrixTest { infix fun Matrix.pow(power: Int): Matrix { var res = this repeat(power - 1) { - res = LinearSpace.real.run { res dot this@pow } + res = res dot this@pow } return res } @@ -57,19 +59,18 @@ class MatrixTest { } @Test - fun test2DDot() { + fun test2DDot() = Double.algebra.linearSpace.run { val firstMatrix = StructureND.auto(2, 3) { (i, j) -> (i + j).toDouble() }.as2D() val secondMatrix = StructureND.auto(3, 2) { (i, j) -> (i + j).toDouble() }.as2D() - LinearSpace.real.run { // val firstMatrix = produce(2, 3) { i, j -> (i + j).toDouble() } // val secondMatrix = produce(3, 2) { i, j -> (i + j).toDouble() } - val result = firstMatrix dot secondMatrix - assertEquals(2, result.rowNum) - assertEquals(2, result.colNum) - assertEquals(8.0, result[0, 1]) - assertEquals(8.0, result[1, 0]) - assertEquals(14.0, result[1, 1]) - } + val result = firstMatrix dot secondMatrix + assertEquals(2, result.rowNum) + assertEquals(2, result.colNum) + assertEquals(8.0, result[0, 1]) + assertEquals(8.0, result[1, 0]) + assertEquals(14.0, result[1, 1]) + } } diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/misc/CumulativeKtTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/misc/CumulativeKtTest.kt index e5f3f337f..aa7abd8ff 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/misc/CumulativeKtTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/misc/CumulativeKtTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.misc diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntAlgebraTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntAlgebraTest.kt index 0527f5252..75100b116 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntAlgebraTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntAlgebraTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConstructorTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConstructorTest.kt index eec3dc3bf..c121c86ae 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConstructorTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConstructorTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConversionsTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConversionsTest.kt index 85f368f3e..78dcdfe19 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConversionsTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntConversionsTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntOperationsTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntOperationsTest.kt index 26d6af224..11b8b161c 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntOperationsTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/BigIntOperationsTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/DoubleFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/DoubleFieldTest.kt index 76171fedd..9be75d68e 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/DoubleFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/operations/DoubleFieldTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt index fdfa49d1d..82172af62 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NDFieldTest.kt @@ -1,13 +1,14 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures -import space.kscience.kmath.nd.AlgebraND import space.kscience.kmath.nd.get -import space.kscience.kmath.nd.real +import space.kscience.kmath.nd.ndAlgebra +import space.kscience.kmath.nd.structureND +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke import space.kscience.kmath.testutils.FieldVerifier import kotlin.test.Test @@ -16,12 +17,12 @@ import kotlin.test.assertEquals internal class NDFieldTest { @Test fun verify() { - (AlgebraND.real(12, 32)) { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } + (DoubleField.ndAlgebra(12, 32)) { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } } @Test fun testStrides() { - val ndArray = AlgebraND.real(10, 10).produce { (it[0] + it[1]).toDouble() } + val ndArray = DoubleField.ndAlgebra.structureND(10, 10) { (it[0] + it[1]).toDouble() } assertEquals(ndArray[5, 5], 10.0) } } diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt index fb51553f7..61eb6acc8 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/structures/NumberNDFieldTest.kt @@ -1,14 +1,16 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures -import space.kscience.kmath.linear.LinearSpace +import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.* +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Norm +import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.invoke import kotlin.math.abs import kotlin.math.pow @@ -17,9 +19,9 @@ import kotlin.test.assertEquals @Suppress("UNUSED_VARIABLE") class NumberNDFieldTest { - val algebra = AlgebraND.real(3, 3) - val array1 = algebra.produce { (i, j) -> (i + j).toDouble() } - val array2 = algebra.produce { (i, j) -> (i - j).toDouble() } + val algebra = DoubleField.ndAlgebra + val array1 = algebra.structureND(3, 3) { (i, j) -> (i + j).toDouble() } + val array2 = algebra.structureND(3, 3) { (i, j) -> (i - j).toDouble() } @Test fun testSum() { @@ -38,17 +40,18 @@ class NumberNDFieldTest { } @Test - fun testGeneration() { + fun testGeneration() = Double.algebra.linearSpace.run { - val array = LinearSpace.real.buildMatrix(3, 3) { i, j -> + val array = buildMatrix(3, 3) { i, j -> (i * 10 + j).toDouble() } - for (i in 0..2) + for (i in 0..2) { for (j in 0..2) { val expected = (i * 10 + j).toDouble() assertEquals(expected, array[i, j], "Error at index [$i, $j]") } + } } @Test @@ -71,7 +74,7 @@ class NumberNDFieldTest { @Test fun combineTest() { - val division = array1.combine(array2, Double::div) + val division = array1.zip(array2, Double::div) } object L2Norm : Norm, Double> { @@ -83,7 +86,7 @@ class NumberNDFieldTest { @Test fun testInternalContext() { algebra { - (AlgebraND.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } } + (DoubleField.ndAlgebra(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } } } } } diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/AlgebraicVerifier.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/AlgebraicVerifier.kt index ddd8fc3ea..544e05707 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/AlgebraicVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/AlgebraicVerifier.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.testutils diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt index bd09ff449..d0a312bb2 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/FieldVerifier.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.testutils @@ -10,7 +10,7 @@ import space.kscience.kmath.operations.invoke import kotlin.test.assertEquals import kotlin.test.assertNotEquals -internal class FieldVerifier>( +internal class FieldVerifier>( algebra: A, a: T, b: T, c: T, x: Number, ) : RingVerifier(algebra, a, b, c, x) { diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt index 885857f04..3b0b49f31 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/RingVerifier.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.testutils @@ -10,7 +10,7 @@ import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.invoke import kotlin.test.assertEquals -internal open class RingVerifier(algebra: A, a: T, b: T, c: T, x: Number) : +internal open class RingVerifier(algebra: A, a: T, b: T, c: T, x: Number) : SpaceVerifier(algebra, a, b, c, x) where A : Ring, A : ScaleOperations { override fun verify() { diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt index 951197fc6..4afa97ce5 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/testutils/SpaceVerifier.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.testutils diff --git a/kmath-core/src/jsMain/kotlin/space/kscience/kmath/misc/numbers.kt b/kmath-core/src/jsMain/kotlin/space/kscience/kmath/misc/numbers.kt new file mode 100644 index 000000000..a24243cb4 --- /dev/null +++ b/kmath-core/src/jsMain/kotlin/space/kscience/kmath/misc/numbers.kt @@ -0,0 +1,12 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.misc + +public actual fun Long.toIntExact(): Int { + val i = toInt() + if (i.toLong() == this) throw ArithmeticException("integer overflow") + return i +} diff --git a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/misc/numbersJVM.kt b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/misc/numbersJVM.kt new file mode 100644 index 000000000..c50919e88 --- /dev/null +++ b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/misc/numbersJVM.kt @@ -0,0 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.misc + +public actual fun Long.toIntExact(): Int = Math.toIntExact(this) diff --git a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt index 9b46369bb..f63efbef2 100644 --- a/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/space/kscience/kmath/operations/BigNumbers.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.operations @@ -13,16 +13,16 @@ import java.math.MathContext * A field over [BigInteger]. */ public object JBigIntegerField : Ring, NumericAlgebra { - public override val zero: BigInteger get() = BigInteger.ZERO + override val zero: BigInteger get() = BigInteger.ZERO - public override val one: BigInteger get() = BigInteger.ONE + override val one: BigInteger get() = BigInteger.ONE - public override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) - public override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) - public override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) - public override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) + override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) + override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right) + override operator fun BigInteger.minus(arg: BigInteger): BigInteger = subtract(arg) + override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right) - public override operator fun BigInteger.unaryMinus(): BigInteger = negate() + override operator fun BigInteger.unaryMinus(): BigInteger = negate() } /** @@ -33,24 +33,24 @@ public object JBigIntegerField : Ring, NumericAlgebra { public abstract class JBigDecimalFieldBase internal constructor( private val mathContext: MathContext = MathContext.DECIMAL64, ) : Field, PowerOperations, NumericAlgebra, ScaleOperations { - public override val zero: BigDecimal + override val zero: BigDecimal get() = BigDecimal.ZERO - public override val one: BigDecimal + override val one: BigDecimal get() = BigDecimal.ONE - public override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) - public override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) - public override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) + override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right) + override operator fun BigDecimal.minus(arg: BigDecimal): BigDecimal = subtract(arg) + override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) - public override fun scale(a: BigDecimal, value: Double): BigDecimal = + override fun scale(a: BigDecimal, value: Double): BigDecimal = a.multiply(value.toBigDecimal(mathContext), mathContext) - public override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) - public override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) - public override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) - public override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) - public override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) + override fun multiply(left: BigDecimal, right: BigDecimal): BigDecimal = left.multiply(right, mathContext) + override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext) + override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) + override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) + override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) } /** diff --git a/kmath-core/src/nativeMain/kotlin/space/kscience/kmath/misc/numbers.kt b/kmath-core/src/nativeMain/kotlin/space/kscience/kmath/misc/numbers.kt new file mode 100644 index 000000000..a24243cb4 --- /dev/null +++ b/kmath-core/src/nativeMain/kotlin/space/kscience/kmath/misc/numbers.kt @@ -0,0 +1,12 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.misc + +public actual fun Long.toIntExact(): Int { + val i = toInt() + if (i.toLong() == this) throw ArithmeticException("integer overflow") + return i +} diff --git a/kmath-coroutines/build.gradle.kts b/kmath-coroutines/build.gradle.kts index 1546e7d96..317691ae5 100644 --- a/kmath-coroutines/build.gradle.kts +++ b/kmath-coroutines/build.gradle.kts @@ -1,6 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") + id("ru.mipt.npm.gradle.native") } kotlin.sourceSets { diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt index 70849f942..a41a30f55 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.chains @@ -32,9 +32,9 @@ public interface BlockingBufferChain : BlockingChain, BufferChain { public fun nextBufferBlocking(size: Int): Buffer - public override fun nextBlocking(): T = nextBufferBlocking(1)[0] + override fun nextBlocking(): T = nextBufferBlocking(1)[0] - public override suspend fun nextBuffer(size: Int): Buffer = nextBufferBlocking(size) + override suspend fun nextBuffer(size: Int): Buffer = nextBufferBlocking(size) override suspend fun fork(): BlockingBufferChain } diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt index 526250cf0..7b4d1f2af 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.chains @@ -15,7 +15,7 @@ public interface BlockingDoubleChain : BlockingBufferChain { /** * Returns an [DoubleArray] chunk of [size] values of [next]. */ - public override fun nextBufferBlocking(size: Int): DoubleBuffer + override fun nextBufferBlocking(size: Int): DoubleBuffer override suspend fun fork(): BlockingDoubleChain @@ -29,4 +29,4 @@ public fun BlockingDoubleChain.map(transform: (Double) -> Double): BlockingDoubl } override suspend fun fork(): BlockingDoubleChain = this@map.fork().map(transform) -} \ No newline at end of file +} diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt index ac0327d0b..f13d9907c 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.chains diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt index b29165e32..403472f28 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.chains @@ -13,7 +13,7 @@ import kotlinx.coroutines.sync.withLock /** * A not-necessary-Markov chain of some type - * @param T - the chain element type + * @param T the chain element type */ public interface Chain : Flow { /** @@ -22,7 +22,7 @@ public interface Chain : Flow { public suspend fun next(): T /** - * Create a copy of current chain state. Consuming resulting chain does not affect initial chain + * Create a copy of current chain state. Consuming resulting chain does not affect initial chain. */ public suspend fun fork(): Chain @@ -39,8 +39,8 @@ public fun Sequence.asChain(): Chain = iterator().asChain() * A simple chain of independent tokens. [fork] returns the same chain. */ public class SimpleChain(private val gen: suspend () -> R) : Chain { - public override suspend fun next(): R = gen() - public override suspend fun fork(): Chain = this + override suspend fun next(): R = gen() + override suspend fun fork(): Chain = this } /** @@ -52,19 +52,21 @@ public class MarkovChain(private val seed: suspend () -> R, private public fun value(): R? = value - public override suspend fun next(): R = mutex.withLock { + override suspend fun next(): R = mutex.withLock { val newValue = gen(value ?: seed()) value = newValue newValue } - public override suspend fun fork(): Chain = MarkovChain(seed = { value ?: seed() }, gen = gen) + override suspend fun fork(): Chain = MarkovChain(seed = { value ?: seed() }, gen = gen) } /** - * A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state - * @param S - the state of the chain - * @param forkState - the function to copy current state without modifying it + * A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share + * the state. + * + * @param S the state of the chain. + * @param forkState the function to copy current state without modifying it. */ public class StatefulChain( private val state: S, @@ -77,26 +79,26 @@ public class StatefulChain( public fun value(): R? = value - public override suspend fun next(): R = mutex.withLock { + override suspend fun next(): R = mutex.withLock { val newValue = state.gen(value ?: state.seed()) value = newValue newValue } - public override suspend fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) + override suspend fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) } /** * A chain that repeats the same value */ public class ConstantChain(public val value: T) : Chain { - public override suspend fun next(): T = value - public override suspend fun fork(): Chain = this + override suspend fun next(): T = value + override suspend fun fork(): Chain = this } /** * Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed - * since mapped chain consumes tokens. Accepts regular transformation function + * since mapped chain consumes tokens. Accepts regular transformation function. */ public fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { override suspend fun next(): R = func(this@map.next()) diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt index ec1203740..1620f029c 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/flowExtra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.chains @@ -10,12 +10,12 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.runningReduce import kotlinx.coroutines.flow.scan -import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.GroupOps import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.invoke -public fun Flow.cumulativeSum(group: GroupOperations): Flow = +public fun Flow.cumulativeSum(group: GroupOps): Flow = group { runningReduce { sum, element -> sum + element } } @ExperimentalCoroutinesApi diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/coroutines/coroutinesExtra.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/coroutines/coroutinesExtra.kt index 49f32f82f..3b90222dd 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/coroutines/coroutinesExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/coroutines/coroutinesExtra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.coroutines @@ -15,30 +15,33 @@ public val Dispatchers.Math: CoroutineDispatcher /** * An imitator of [Deferred] which holds a suspended function block and dispatcher */ -internal class LazyDeferred(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) { +@PublishedApi +internal class LazyDeferred(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) { private var deferred: Deferred? = null - internal fun start(scope: CoroutineScope) { + fun start(scope: CoroutineScope) { if (deferred == null) deferred = scope.async(dispatcher, block = block) } suspend fun await(): T = deferred?.await() ?: error("Coroutine not started") } -public class AsyncFlow internal constructor(internal val deferredFlow: Flow>) : Flow { +public class AsyncFlow @PublishedApi internal constructor( + @PublishedApi internal val deferredFlow: Flow>, +) : Flow { override suspend fun collect(collector: FlowCollector): Unit = deferredFlow.collect { collector.emit((it.await())) } } -public fun Flow.async( +public inline fun Flow.async( dispatcher: CoroutineDispatcher = Dispatchers.Default, - block: suspend CoroutineScope.(T) -> R, + crossinline block: suspend CoroutineScope.(T) -> R, ): AsyncFlow { val flow = map { LazyDeferred(dispatcher) { block(it) } } return AsyncFlow(flow) } -public fun AsyncFlow.map(action: (T) -> R): AsyncFlow = +public inline fun AsyncFlow.map(crossinline action: (T) -> R): AsyncFlow = AsyncFlow(deferredFlow.map { input -> //TODO add function composition LazyDeferred(input.dispatcher) { diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt index 0d6a1178a..914139a3e 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.streaming @@ -76,7 +76,7 @@ public fun Flow.chunked(bufferSize: Int): Flow = flow { /** * Map a flow to a moving window buffer. The window step is one. - * In order to get different steps, one could use skip operation. + * To get different steps, one could use skip operation. */ public fun Flow.windowed(window: Int): Flow> = flow { require(window > 1) { "Window size must be more than one" } diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/RingBuffer.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/RingBuffer.kt index 05f2876e3..573b406e2 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/RingBuffer.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/RingBuffer.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.streaming @@ -22,10 +22,10 @@ public class RingBuffer( ) : Buffer { private val mutex: Mutex = Mutex() - public override var size: Int = size + override var size: Int = size private set - public override operator fun get(index: Int): T { + override operator fun get(index: Int): T { require(index >= 0) { "Index must be positive" } require(index < size) { "Index $index is out of circular buffer size $size" } return buffer[startIndex.forward(index)] as T @@ -36,7 +36,7 @@ public class RingBuffer( /** * Iterator could provide wrong results if buffer is changed in initialization (iteration is safe) */ - public override operator fun iterator(): Iterator = object : AbstractIterator() { + override operator fun iterator(): Iterator = object : AbstractIterator() { private var count = size private var index = startIndex val copy = buffer.copy() @@ -69,6 +69,8 @@ public class RingBuffer( @Suppress("NOTHING_TO_INLINE") private inline fun Int.forward(n: Int): Int = (this + n) % (buffer.size) + override fun toString(): String = Buffer.toString(this) + public companion object { public inline fun build(size: Int, empty: T): RingBuffer { val buffer = MutableBuffer.auto(size) { empty } as MutableBuffer diff --git a/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/chains/ChainExt.kt b/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/chains/ChainExt.kt index dd6e39071..0e36706cf 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/chains/ChainExt.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/chains/ChainExt.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.chains diff --git a/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt b/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt index ded8c9c44..1feb43f33 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/space/kscience/kmath/structures/LazyStructureND.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.structures @@ -13,7 +13,7 @@ import space.kscience.kmath.nd.StructureND public class LazyStructureND( public val scope: CoroutineScope, - public override val shape: IntArray, + override val shape: IntArray, public val function: suspend (IntArray) -> T, ) : StructureND { private val cache: MutableMap> = HashMap() @@ -23,12 +23,12 @@ public class LazyStructureND( } public suspend fun await(index: IntArray): T = deferred(index).await() - public override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() } + override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() } @OptIn(PerformancePitfall::class) - public override fun elements(): Sequence> { + override fun elements(): Sequence> { val strides = DefaultStrides(shape) - val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } } + val res = runBlocking { strides.asSequence().toList().map { index -> index to await(index) } } return res.asSequence() } } diff --git a/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/BufferFlowTest.kt b/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/BufferFlowTest.kt index 9b67f7253..057ac5feb 100644 --- a/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/BufferFlowTest.kt +++ b/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/BufferFlowTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.streaming diff --git a/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/RingBufferTest.kt b/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/RingBufferTest.kt index 32e3b2c74..a3143a1ac 100644 --- a/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/RingBufferTest.kt +++ b/kmath-coroutines/src/jvmTest/kotlin/space/kscience/kmath/streaming/RingBufferTest.kt @@ -1,13 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.streaming import kotlinx.coroutines.flow.* import kotlinx.coroutines.runBlocking -import space.kscience.kmath.structures.asSequence +import space.kscience.kmath.operations.asSequence import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Dimensions.kt b/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt similarity index 64% rename from kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Dimensions.kt rename to kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt index 8b17d252f..53482f020 100644 --- a/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Dimensions.kt +++ b/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.dimensions @@ -8,47 +8,47 @@ package space.kscience.kmath.dimensions import kotlin.reflect.KClass /** - * Represents a quantity of dimensions in certain structure. + * Represents a quantity of dimensions in certain structure. **This interface must be implemented only by objects.** * * @property dim The number of dimensions. */ public interface Dimension { - public val dim: UInt + public val dim: Int public companion object } -public fun KClass.dim(): UInt = Dimension.resolve(this).dim +public fun KClass.dim(): Int = Dimension.resolve(this).dim public expect fun Dimension.Companion.resolve(type: KClass): D /** * Finds or creates [Dimension] with [Dimension.dim] equal to [dim]. */ -public expect fun Dimension.Companion.of(dim: UInt): Dimension +public expect fun Dimension.Companion.of(dim: Int): Dimension /** * Finds [Dimension.dim] of given type [D]. */ -public inline fun Dimension.Companion.dim(): UInt = D::class.dim() +public inline fun Dimension.Companion.dim(): Int = D::class.dim() /** * Type representing 1 dimension. */ public object D1 : Dimension { - override val dim: UInt get() = 1U + override val dim: Int get() = 1 } /** * Type representing 2 dimensions. */ public object D2 : Dimension { - override val dim: UInt get() = 2U + override val dim: Int get() = 2 } /** * Type representing 3 dimensions. */ public object D3 : Dimension { - override val dim: UInt get() = 3U + override val dim: Int get() = 3 } diff --git a/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt index 2ebcc454d..c47f43723 100644 --- a/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt @@ -1,33 +1,31 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.dimensions -import space.kscience.kmath.linear.LinearSpace -import space.kscience.kmath.linear.Matrix -import space.kscience.kmath.linear.Point -import space.kscience.kmath.linear.transpose +import space.kscience.kmath.linear.* import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.algebra import kotlin.jvm.JvmInline /** * A matrix with compile-time controlled dimension */ -public interface DMatrix : Structure2D { +public interface DMatrix : Structure2D { public companion object { /** - * Coerces a regular matrix to a matrix with type-safe dimensions and throws a error if coercion failed + * Coerces a regular matrix to a matrix with type-safe dimensions and throws an error if coercion failed */ public inline fun coerce(structure: Structure2D): DMatrix { - require(structure.rowNum == Dimension.dim().toInt()) { + require(structure.rowNum == Dimension.dim()) { "Row number mismatch: expected ${Dimension.dim()} but found ${structure.rowNum}" } - require(structure.colNum == Dimension.dim().toInt()) { + require(structure.colNum == Dimension.dim()) { "Column number mismatch: expected ${Dimension.dim()} but found ${structure.colNum}" } @@ -35,7 +33,7 @@ public interface DMatrix : Structure2D { } /** - * The same as [DMatrix.coerce] but without dimension checks. Use with caution + * The same as [DMatrix.coerce] but without dimension checks. Use with caution. */ public fun coerceUnsafe(structure: Structure2D): DMatrix = DMatrixWrapper(structure) @@ -46,7 +44,7 @@ public interface DMatrix : Structure2D { * An inline wrapper for a Matrix */ @JvmInline -public value class DMatrixWrapper( +public value class DMatrixWrapper( private val structure: Structure2D, ) : DMatrix { override val shape: IntArray get() = structure.shape @@ -58,10 +56,10 @@ public value class DMatrixWrapper( /** * Dimension-safe point */ -public interface DPoint : Point { +public interface DPoint : Point { public companion object { public inline fun coerce(point: Point): DPoint { - require(point.size == Dimension.dim().toInt()) { + require(point.size == Dimension.dim()) { "Vector dimension mismatch: expected ${Dimension.dim()}, but found ${point.size}" } @@ -76,7 +74,7 @@ public interface DPoint : Point { * Dimension-safe point wrapper */ @JvmInline -public value class DPointWrapper(public val point: Point) : +public value class DPointWrapper(public val point: Point) : DPoint { override val size: Int get() = point.size @@ -92,11 +90,11 @@ public value class DPointWrapper(public val point: Point) : @JvmInline public value class DMatrixContext>(public val context: LinearSpace) { public inline fun Matrix.coerce(): DMatrix { - require(rowNum == Dimension.dim().toInt()) { + require(rowNum == Dimension.dim()) { "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" } - require(colNum == Dimension.dim().toInt()) { + require(colNum == Dimension.dim()) { "Column number mismatch: expected ${Dimension.dim()} but found $colNum" } @@ -111,7 +109,7 @@ public value class DMatrixContext>(public val context: ): DMatrix { val rows = Dimension.dim() val cols = Dimension.dim() - return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce() + return context.buildMatrix(rows, cols, initializer).coerce() } public inline fun point(noinline initializer: A.(Int) -> T): DPoint { @@ -119,7 +117,7 @@ public value class DMatrixContext>(public val context: return DPoint.coerceUnsafe( context.buildVector( - size.toInt(), + size, initializer ) ) @@ -151,7 +149,7 @@ public value class DMatrixContext>(public val context: context.run { (this@transpose as Matrix).transpose() }.coerce() public companion object { - public val real: DMatrixContext = DMatrixContext(LinearSpace.real) + public val real: DMatrixContext = DMatrixContext(Double.algebra.linearSpace) } } @@ -167,4 +165,4 @@ public inline fun DMatrixContext.on public inline fun DMatrixContext.zero(): DMatrix = produce { _, _ -> 0.0 - } \ No newline at end of file + } diff --git a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/space/kscience/dimensions/DMatrixContextTest.kt similarity index 95% rename from kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt rename to kmath-dimensions/src/commonTest/kotlin/space/kscience/dimensions/DMatrixContextTest.kt index 59260fe73..efa3170a3 100644 --- a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/space/kscience/dimensions/DMatrixContextTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.dimensions diff --git a/kmath-dimensions/src/jsMain/kotlin/space/kscience/kmath/dimensions/dimJs.kt b/kmath-dimensions/src/jsMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt similarity index 56% rename from kmath-dimensions/src/jsMain/kotlin/space/kscience/kmath/dimensions/dimJs.kt rename to kmath-dimensions/src/jsMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt index 27912f5bc..324c78108 100644 --- a/kmath-dimensions/src/jsMain/kotlin/space/kscience/kmath/dimensions/dimJs.kt +++ b/kmath-dimensions/src/jsMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt @@ -1,23 +1,23 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.dimensions import kotlin.reflect.KClass -private val dimensionMap: MutableMap = hashMapOf(1u to D1, 2u to D2, 3u to D3) +private val dimensionMap: MutableMap = hashMapOf(1 to D1, 2 to D2, 3 to D3) @Suppress("UNCHECKED_CAST") public actual fun Dimension.Companion.resolve(type: KClass): D = dimensionMap .entries - .map(MutableMap.MutableEntry::value) + .map(MutableMap.MutableEntry::value) .find { it::class == type } as? D ?: error("Can't resolve dimension $type") -public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) { +public actual fun Dimension.Companion.of(dim: Int): Dimension = dimensionMap.getOrPut(dim) { object : Dimension { - override val dim: UInt get() = dim + override val dim: Int get() = dim } } diff --git a/kmath-dimensions/src/jvmMain/kotlin/space/kscience/kmath/dimensions/dimJvm.kt b/kmath-dimensions/src/jvmMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt similarity index 62% rename from kmath-dimensions/src/jvmMain/kotlin/space/kscience/kmath/dimensions/dimJvm.kt rename to kmath-dimensions/src/jvmMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt index f21a3e18f..8fc683ed6 100644 --- a/kmath-dimensions/src/jvmMain/kotlin/space/kscience/kmath/dimensions/dimJvm.kt +++ b/kmath-dimensions/src/jvmMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt @@ -1,8 +1,10 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ +@file:JvmName("DimensionJVM") + package space.kscience.kmath.dimensions import kotlin.reflect.KClass @@ -10,12 +12,12 @@ import kotlin.reflect.KClass public actual fun Dimension.Companion.resolve(type: KClass): D = type.objectInstance ?: error("No object instance for dimension class") -public actual fun Dimension.Companion.of(dim: UInt): Dimension = when (dim) { - 1u -> D1 - 2u -> D2 - 3u -> D3 +public actual fun Dimension.Companion.of(dim: Int): Dimension = when (dim) { + 1 -> D1 + 2 -> D2 + 3 -> D3 else -> object : Dimension { - override val dim: UInt get() = dim + override val dim: Int get() = dim } -} \ No newline at end of file +} diff --git a/kmath-dimensions/src/nativeMain/kotlin/space/kscience/kmath/dimensions/dimNative.kt b/kmath-dimensions/src/nativeMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt similarity index 59% rename from kmath-dimensions/src/nativeMain/kotlin/space/kscience/kmath/dimensions/dimNative.kt rename to kmath-dimensions/src/nativeMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt index 9aa58e64a..001d68935 100644 --- a/kmath-dimensions/src/nativeMain/kotlin/space/kscience/kmath/dimensions/dimNative.kt +++ b/kmath-dimensions/src/nativeMain/kotlin/space/kscience/kmath/dimensions/Dimension.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.dimensions @@ -9,17 +9,17 @@ import kotlin.native.concurrent.ThreadLocal import kotlin.reflect.KClass @ThreadLocal -private val dimensionMap: MutableMap = hashMapOf(1u to D1, 2u to D2, 3u to D3) +private val dimensionMap: MutableMap = hashMapOf(1 to D1, 2 to D2, 3 to D3) @Suppress("UNCHECKED_CAST") public actual fun Dimension.Companion.resolve(type: KClass): D = dimensionMap .entries - .map(MutableMap.MutableEntry::value) + .map(MutableMap.MutableEntry::value) .find { it::class == type } as? D ?: error("Can't resolve dimension $type") -public actual fun Dimension.Companion.of(dim: UInt): Dimension = dimensionMap.getOrPut(dim) { +public actual fun Dimension.Companion.of(dim: Int): Dimension = dimensionMap.getOrPut(dim) { object : Dimension { - override val dim: UInt get() = dim + override val dim: Int get() = dim } } diff --git a/kmath-ejml/README.md b/kmath-ejml/README.md index 10e7bd606..f88f53000 100644 --- a/kmath-ejml/README.md +++ b/kmath-ejml/README.md @@ -9,7 +9,7 @@ EJML based linear algebra implementation. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-14`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-ejml:0.3.0-dev-13' + implementation 'space.kscience:kmath-ejml:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -30,6 +30,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-ejml:0.3.0-dev-13") + implementation("space.kscience:kmath-ejml:0.3.0-dev-14") } ``` diff --git a/kmath-ejml/build.gradle.kts b/kmath-ejml/build.gradle.kts index 5107cfb68..727d21e3a 100644 --- a/kmath-ejml/build.gradle.kts +++ b/kmath-ejml/build.gradle.kts @@ -6,10 +6,10 @@ plugins { } dependencies { - api("org.ejml:ejml-ddense:0.40") - api("org.ejml:ejml-fdense:0.40") - api("org.ejml:ejml-dsparse:0.40") - api("org.ejml:ejml-fsparse:0.40") + api("org.ejml:ejml-ddense:0.41") + api("org.ejml:ejml-fdense:0.41") + api("org.ejml:ejml-dsparse:0.41") + api("org.ejml:ejml-fsparse:0.41") api(project(":kmath-core")) } diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt index f88e83369..25333157a 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt @@ -1,13 +1,16 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ejml +import space.kscience.kmath.linear.InverseMatrixFeature import space.kscience.kmath.linear.LinearSpace import space.kscience.kmath.linear.Matrix import space.kscience.kmath.linear.Point +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.operations.Ring /** @@ -36,4 +39,9 @@ public abstract class EjmlLinearSpace, out M : org.ejml ): EjmlMatrix public abstract override fun buildVector(size: Int, initializer: A.(Int) -> T): EjmlVector + + @Suppress("UNCHECKED_CAST") + @UnstableKMathAPI + public fun EjmlMatrix.inverse(): Structure2D = + computeFeature(this, InverseMatrixFeature::class)?.inverse as Structure2D } diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt index cec31eb7d..9ad0f9c77 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlMatrix.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ejml @@ -17,6 +17,6 @@ import space.kscience.kmath.nd.Structure2D * @author Iaroslav Postovalov */ public abstract class EjmlMatrix(public open val origin: M) : Structure2D { - public override val rowNum: Int get() = origin.numRows - public override val colNum: Int get() = origin.numCols + override val rowNum: Int get() = origin.numRows + override val colNum: Int get() = origin.numCols } diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt index 5d10d1fbb..a6de1b657 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlVector.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ejml @@ -17,10 +17,10 @@ import space.kscience.kmath.linear.Point * @author Iaroslav Postovalov */ public abstract class EjmlVector(public open val origin: M) : Point { - public override val size: Int + override val size: Int get() = origin.numCols - public override operator fun iterator(): Iterator = object : Iterator { + override operator fun iterator(): Iterator = object : Iterator { private var cursor: Int = 0 override fun next(): T { @@ -31,5 +31,5 @@ public abstract class EjmlVector(public open val origin: override fun hasNext(): Boolean = cursor < origin.numCols * origin.numRows } - public override fun toString(): String = "EjmlVector(origin=$origin)" + override fun toString(): String = "EjmlVector(origin=$origin)" } diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt index 139c55697..dce739dc2 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ /* This file is generated with buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt */ @@ -34,37 +34,37 @@ import kotlin.reflect.cast /** * [EjmlVector] specialization for [Double]. */ -public class EjmlDoubleVector(public override val origin: M) : EjmlVector(origin) { +public class EjmlDoubleVector(override val origin: M) : EjmlVector(origin) { init { require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } } - public override operator fun get(index: Int): Double = origin[0, index] + override operator fun get(index: Int): Double = origin[0, index] } /** * [EjmlVector] specialization for [Float]. */ -public class EjmlFloatVector(public override val origin: M) : EjmlVector(origin) { +public class EjmlFloatVector(override val origin: M) : EjmlVector(origin) { init { require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } } - public override operator fun get(index: Int): Float = origin[0, index] + override operator fun get(index: Int): Float = origin[0, index] } /** * [EjmlMatrix] specialization for [Double]. */ -public class EjmlDoubleMatrix(public override val origin: M) : EjmlMatrix(origin) { - public override operator fun get(i: Int, j: Int): Double = origin[i, j] +public class EjmlDoubleMatrix(override val origin: M) : EjmlMatrix(origin) { + override operator fun get(i: Int, j: Int): Double = origin[i, j] } /** * [EjmlMatrix] specialization for [Float]. */ -public class EjmlFloatMatrix(public override val origin: M) : EjmlMatrix(origin) { - public override operator fun get(i: Int, j: Int): Float = origin[i, j] +public class EjmlFloatMatrix(override val origin: M) : EjmlMatrix(origin) { + override operator fun get(i: Int, j: Int): Float = origin[i, j] } /** @@ -75,23 +75,23 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace.toEjml(): EjmlDoubleMatrix = when { + override fun Matrix.toEjml(): EjmlDoubleMatrix = when { this is EjmlDoubleMatrix<*> && origin is DMatrixRMaj -> this as EjmlDoubleMatrix else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } } @Suppress("UNCHECKED_CAST") - public override fun Point.toEjml(): EjmlDoubleVector = when { + override fun Point.toEjml(): EjmlDoubleVector = when { this is EjmlDoubleVector<*> && origin is DMatrixRMaj -> this as EjmlDoubleVector else -> EjmlDoubleVector(DMatrixRMaj(size, 1).also { (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } }) } - public override fun buildMatrix( + override fun buildMatrix( rows: Int, columns: Int, initializer: DoubleField.(i: Int, j: Int) -> Double, @@ -101,7 +101,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace Double, ): EjmlDoubleVector = EjmlDoubleVector(DMatrixRMaj(size, 1).also { @@ -111,21 +111,21 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace T.wrapMatrix() = EjmlDoubleMatrix(this) private fun T.wrapVector() = EjmlDoubleVector(this) - public override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - public override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { + override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { val out = DMatrixRMaj(1, 1) CommonOps_DDRM.mult(toEjml().origin, other.toEjml().origin, out) return out.wrapMatrix() } - public override fun Matrix.dot(vector: Point): EjmlDoubleVector { + override fun Matrix.dot(vector: Point): EjmlDoubleVector { val out = DMatrixRMaj(1, 1) CommonOps_DDRM.mult(toEjml().origin, vector.toEjml().origin, out) return out.wrapVector() } - public override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { + override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { val out = DMatrixRMaj(1, 1) CommonOps_DDRM.add( @@ -139,19 +139,19 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace.times(value: Double): EjmlDoubleMatrix { + override operator fun Matrix.times(value: Double): EjmlDoubleMatrix { val res = DMatrixRMaj(1, 1) CommonOps_DDRM.scale(value, toEjml().origin, res) return res.wrapMatrix() } - public override fun Point.unaryMinus(): EjmlDoubleVector { + override fun Point.unaryMinus(): EjmlDoubleVector { val res = DMatrixRMaj(1, 1) CommonOps_DDRM.changeSign(toEjml().origin, res) return res.wrapVector() } - public override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { + override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { val out = DMatrixRMaj(1, 1) CommonOps_DDRM.add( @@ -165,7 +165,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace.plus(other: Point): EjmlDoubleVector { + override fun Point.plus(other: Point): EjmlDoubleVector { val out = DMatrixRMaj(1, 1) CommonOps_DDRM.add( @@ -179,7 +179,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace.minus(other: Point): EjmlDoubleVector { + override fun Point.minus(other: Point): EjmlDoubleVector { val out = DMatrixRMaj(1, 1) CommonOps_DDRM.add( @@ -193,18 +193,18 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace): EjmlDoubleMatrix = m * this + override fun Double.times(m: Matrix): EjmlDoubleMatrix = m * this - public override fun Point.times(value: Double): EjmlDoubleVector { + override fun Point.times(value: Double): EjmlDoubleVector { val res = DMatrixRMaj(1, 1) CommonOps_DDRM.scale(value, toEjml().origin, res) return res.wrapVector() } - public override fun Double.times(v: Point): EjmlDoubleVector = v * this + override fun Double.times(v: Point): EjmlDoubleVector = v * this @UnstableKMathAPI - public override fun getFeature(structure: Matrix, type: KClass): F? { + override fun computeFeature(structure: Matrix, type: KClass): F? { structure.getFeature(type)?.let { return it } val origin = structure.toEjml().origin @@ -239,10 +239,10 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace by lazy { - qr.getQ(null, false).wrapMatrix() + OrthogonalFeature + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) } - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix() + UFeature } + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } } CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { @@ -250,7 +250,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace by lazy { - lup.getLower(null).wrapMatrix() + LFeature + lup.getLower(null).wrapMatrix().withFeature(LFeature) } override val u: Matrix by lazy { - lup.getUpper(null).wrapMatrix() + UFeature + lup.getUpper(null).wrapMatrix().withFeature(UFeature) } override val p: Matrix by lazy { lup.getRowPivot(null).wrapMatrix() } @@ -309,23 +309,23 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace.toEjml(): EjmlFloatMatrix = when { + override fun Matrix.toEjml(): EjmlFloatMatrix = when { this is EjmlFloatMatrix<*> && origin is FMatrixRMaj -> this as EjmlFloatMatrix else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } } @Suppress("UNCHECKED_CAST") - public override fun Point.toEjml(): EjmlFloatVector = when { + override fun Point.toEjml(): EjmlFloatVector = when { this is EjmlFloatVector<*> && origin is FMatrixRMaj -> this as EjmlFloatVector else -> EjmlFloatVector(FMatrixRMaj(size, 1).also { (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } }) } - public override fun buildMatrix( + override fun buildMatrix( rows: Int, columns: Int, initializer: FloatField.(i: Int, j: Int) -> Float, @@ -335,7 +335,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace Float, ): EjmlFloatVector = EjmlFloatVector(FMatrixRMaj(size, 1).also { @@ -345,21 +345,21 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace T.wrapMatrix() = EjmlFloatMatrix(this) private fun T.wrapVector() = EjmlFloatVector(this) - public override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - public override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { + override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { val out = FMatrixRMaj(1, 1) CommonOps_FDRM.mult(toEjml().origin, other.toEjml().origin, out) return out.wrapMatrix() } - public override fun Matrix.dot(vector: Point): EjmlFloatVector { + override fun Matrix.dot(vector: Point): EjmlFloatVector { val out = FMatrixRMaj(1, 1) CommonOps_FDRM.mult(toEjml().origin, vector.toEjml().origin, out) return out.wrapVector() } - public override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { + override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { val out = FMatrixRMaj(1, 1) CommonOps_FDRM.add( @@ -373,19 +373,19 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace.times(value: Float): EjmlFloatMatrix { + override operator fun Matrix.times(value: Float): EjmlFloatMatrix { val res = FMatrixRMaj(1, 1) CommonOps_FDRM.scale(value, toEjml().origin, res) return res.wrapMatrix() } - public override fun Point.unaryMinus(): EjmlFloatVector { + override fun Point.unaryMinus(): EjmlFloatVector { val res = FMatrixRMaj(1, 1) CommonOps_FDRM.changeSign(toEjml().origin, res) return res.wrapVector() } - public override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { + override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { val out = FMatrixRMaj(1, 1) CommonOps_FDRM.add( @@ -399,7 +399,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace.plus(other: Point): EjmlFloatVector { + override fun Point.plus(other: Point): EjmlFloatVector { val out = FMatrixRMaj(1, 1) CommonOps_FDRM.add( @@ -413,7 +413,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace.minus(other: Point): EjmlFloatVector { + override fun Point.minus(other: Point): EjmlFloatVector { val out = FMatrixRMaj(1, 1) CommonOps_FDRM.add( @@ -427,18 +427,18 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace): EjmlFloatMatrix = m * this + override fun Float.times(m: Matrix): EjmlFloatMatrix = m * this - public override fun Point.times(value: Float): EjmlFloatVector { + override fun Point.times(value: Float): EjmlFloatVector { val res = FMatrixRMaj(1, 1) CommonOps_FDRM.scale(value, toEjml().origin, res) return res.wrapVector() } - public override fun Float.times(v: Point): EjmlFloatVector = v * this + override fun Float.times(v: Point): EjmlFloatVector = v * this @UnstableKMathAPI - public override fun getFeature(structure: Matrix, type: KClass): F? { + override fun computeFeature(structure: Matrix, type: KClass): F? { structure.getFeature(type)?.let { return it } val origin = structure.toEjml().origin @@ -473,10 +473,10 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace by lazy { - qr.getQ(null, false).wrapMatrix() + OrthogonalFeature + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) } - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix() + UFeature } + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } } CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { @@ -484,7 +484,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace by lazy { - lup.getLower(null).wrapMatrix() + LFeature + lup.getLower(null).wrapMatrix().withFeature(LFeature) } override val u: Matrix by lazy { - lup.getUpper(null).wrapMatrix() + UFeature + lup.getUpper(null).wrapMatrix().withFeature(UFeature) } override val p: Matrix by lazy { lup.getRowPivot(null).wrapMatrix() } @@ -543,23 +543,23 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace.toEjml(): EjmlDoubleMatrix = when { + override fun Matrix.toEjml(): EjmlDoubleMatrix = when { this is EjmlDoubleMatrix<*> && origin is DMatrixSparseCSC -> this as EjmlDoubleMatrix else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } } @Suppress("UNCHECKED_CAST") - public override fun Point.toEjml(): EjmlDoubleVector = when { + override fun Point.toEjml(): EjmlDoubleVector = when { this is EjmlDoubleVector<*> && origin is DMatrixSparseCSC -> this as EjmlDoubleVector else -> EjmlDoubleVector(DMatrixSparseCSC(size, 1).also { (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } }) } - public override fun buildMatrix( + override fun buildMatrix( rows: Int, columns: Int, initializer: DoubleField.(i: Int, j: Int) -> Double, @@ -569,7 +569,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace Double, ): EjmlDoubleVector = EjmlDoubleVector(DMatrixSparseCSC(size, 1).also { @@ -579,21 +579,21 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace T.wrapMatrix() = EjmlDoubleMatrix(this) private fun T.wrapVector() = EjmlDoubleVector(this) - public override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - public override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { + override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { val out = DMatrixSparseCSC(1, 1) CommonOps_DSCC.mult(toEjml().origin, other.toEjml().origin, out) return out.wrapMatrix() } - public override fun Matrix.dot(vector: Point): EjmlDoubleVector { + override fun Matrix.dot(vector: Point): EjmlDoubleVector { val out = DMatrixSparseCSC(1, 1) CommonOps_DSCC.mult(toEjml().origin, vector.toEjml().origin, out) return out.wrapVector() } - public override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { + override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { val out = DMatrixSparseCSC(1, 1) CommonOps_DSCC.add( @@ -609,19 +609,19 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace.times(value: Double): EjmlDoubleMatrix { + override operator fun Matrix.times(value: Double): EjmlDoubleMatrix { val res = DMatrixSparseCSC(1, 1) CommonOps_DSCC.scale(value, toEjml().origin, res) return res.wrapMatrix() } - public override fun Point.unaryMinus(): EjmlDoubleVector { + override fun Point.unaryMinus(): EjmlDoubleVector { val res = DMatrixSparseCSC(1, 1) CommonOps_DSCC.changeSign(toEjml().origin, res) return res.wrapVector() } - public override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { + override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { val out = DMatrixSparseCSC(1, 1) CommonOps_DSCC.add( @@ -637,7 +637,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace.plus(other: Point): EjmlDoubleVector { + override fun Point.plus(other: Point): EjmlDoubleVector { val out = DMatrixSparseCSC(1, 1) CommonOps_DSCC.add( @@ -653,7 +653,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace.minus(other: Point): EjmlDoubleVector { + override fun Point.minus(other: Point): EjmlDoubleVector { val out = DMatrixSparseCSC(1, 1) CommonOps_DSCC.add( @@ -669,18 +669,18 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace): EjmlDoubleMatrix = m * this + override fun Double.times(m: Matrix): EjmlDoubleMatrix = m * this - public override fun Point.times(value: Double): EjmlDoubleVector { + override fun Point.times(value: Double): EjmlDoubleVector { val res = DMatrixSparseCSC(1, 1) CommonOps_DSCC.scale(value, toEjml().origin, res) return res.wrapVector() } - public override fun Double.times(v: Point): EjmlDoubleVector = v * this + override fun Double.times(v: Point): EjmlDoubleVector = v * this @UnstableKMathAPI - public override fun getFeature(structure: Matrix, type: KClass): F? { + override fun computeFeature(structure: Matrix, type: KClass): F? { structure.getFeature(type)?.let { return it } val origin = structure.toEjml().origin @@ -691,10 +691,10 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace by lazy { - qr.getQ(null, false).wrapMatrix() + OrthogonalFeature + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) } - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix() + UFeature } + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } } CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { @@ -702,7 +702,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace by lazy { - lu.getLower(null).wrapMatrix() + LFeature + lu.getLower(null).wrapMatrix().withFeature(LFeature) } override val u: Matrix by lazy { - lu.getUpper(null).wrapMatrix() + UFeature + lu.getUpper(null).wrapMatrix().withFeature(UFeature) } override val inverse: Matrix by lazy { @@ -772,23 +772,23 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace.toEjml(): EjmlFloatMatrix = when { + override fun Matrix.toEjml(): EjmlFloatMatrix = when { this is EjmlFloatMatrix<*> && origin is FMatrixSparseCSC -> this as EjmlFloatMatrix else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } } @Suppress("UNCHECKED_CAST") - public override fun Point.toEjml(): EjmlFloatVector = when { + override fun Point.toEjml(): EjmlFloatVector = when { this is EjmlFloatVector<*> && origin is FMatrixSparseCSC -> this as EjmlFloatVector else -> EjmlFloatVector(FMatrixSparseCSC(size, 1).also { (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } }) } - public override fun buildMatrix( + override fun buildMatrix( rows: Int, columns: Int, initializer: FloatField.(i: Int, j: Int) -> Float, @@ -798,7 +798,7 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace Float, ): EjmlFloatVector = EjmlFloatVector(FMatrixSparseCSC(size, 1).also { @@ -808,21 +808,21 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace T.wrapMatrix() = EjmlFloatMatrix(this) private fun T.wrapVector() = EjmlFloatVector(this) - public override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - public override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { + override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { val out = FMatrixSparseCSC(1, 1) CommonOps_FSCC.mult(toEjml().origin, other.toEjml().origin, out) return out.wrapMatrix() } - public override fun Matrix.dot(vector: Point): EjmlFloatVector { + override fun Matrix.dot(vector: Point): EjmlFloatVector { val out = FMatrixSparseCSC(1, 1) CommonOps_FSCC.mult(toEjml().origin, vector.toEjml().origin, out) return out.wrapVector() } - public override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { + override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { val out = FMatrixSparseCSC(1, 1) CommonOps_FSCC.add( @@ -838,19 +838,19 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace.times(value: Float): EjmlFloatMatrix { + override operator fun Matrix.times(value: Float): EjmlFloatMatrix { val res = FMatrixSparseCSC(1, 1) CommonOps_FSCC.scale(value, toEjml().origin, res) return res.wrapMatrix() } - public override fun Point.unaryMinus(): EjmlFloatVector { + override fun Point.unaryMinus(): EjmlFloatVector { val res = FMatrixSparseCSC(1, 1) CommonOps_FSCC.changeSign(toEjml().origin, res) return res.wrapVector() } - public override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { + override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { val out = FMatrixSparseCSC(1, 1) CommonOps_FSCC.add( @@ -866,7 +866,7 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace.plus(other: Point): EjmlFloatVector { + override fun Point.plus(other: Point): EjmlFloatVector { val out = FMatrixSparseCSC(1, 1) CommonOps_FSCC.add( @@ -882,7 +882,7 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace.minus(other: Point): EjmlFloatVector { + override fun Point.minus(other: Point): EjmlFloatVector { val out = FMatrixSparseCSC(1, 1) CommonOps_FSCC.add( @@ -898,18 +898,18 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace): EjmlFloatMatrix = m * this + override fun Float.times(m: Matrix): EjmlFloatMatrix = m * this - public override fun Point.times(value: Float): EjmlFloatVector { + override fun Point.times(value: Float): EjmlFloatVector { val res = FMatrixSparseCSC(1, 1) CommonOps_FSCC.scale(value, toEjml().origin, res) return res.wrapVector() } - public override fun Float.times(v: Point): EjmlFloatVector = v * this + override fun Float.times(v: Point): EjmlFloatVector = v * this @UnstableKMathAPI - public override fun getFeature(structure: Matrix, type: KClass): F? { + override fun computeFeature(structure: Matrix, type: KClass): F? { structure.getFeature(type)?.let { return it } val origin = structure.toEjml().origin @@ -920,10 +920,10 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace by lazy { - qr.getQ(null, false).wrapMatrix() + OrthogonalFeature + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) } - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix() + UFeature } + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } } CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { @@ -931,7 +931,7 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace by lazy { - lu.getLower(null).wrapMatrix() + LFeature + lu.getLower(null).wrapMatrix().withFeature(LFeature) } override val u: Matrix by lazy { - lu.getUpper(null).wrapMatrix() + UFeature + lu.getUpper(null).wrapMatrix().withFeature(UFeature) } override val inverse: Matrix by lazy { diff --git a/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlMatrixTest.kt index 50675bdac..5b8b2af98 100644 --- a/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlMatrixTest.kt +++ b/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ejml @@ -9,12 +9,11 @@ import org.ejml.data.DMatrixRMaj import org.ejml.dense.row.CommonOps_DDRM import org.ejml.dense.row.RandomMatrices_DDRM import org.ejml.dense.row.factory.DecompositionFactory_DDRM -import space.kscience.kmath.linear.DeterminantFeature -import space.kscience.kmath.linear.LupDecompositionFeature -import space.kscience.kmath.linear.getFeature +import space.kscience.kmath.linear.* import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.algebra import kotlin.random.Random import kotlin.random.asJavaRandom import kotlin.test.* @@ -59,9 +58,9 @@ internal class EjmlMatrixTest { fun features() { val m = randomMatrix val w = EjmlDoubleMatrix(m) - val det: DeterminantFeature = EjmlLinearSpaceDDRM.getFeature(w) ?: fail() + val det: DeterminantFeature = EjmlLinearSpaceDDRM.computeFeature(w) ?: fail() assertEquals(CommonOps_DDRM.det(m), det.determinant) - val lup: LupDecompositionFeature = EjmlLinearSpaceDDRM.getFeature(w) ?: fail() + val lup: LupDecompositionFeature = EjmlLinearSpaceDDRM.computeFeature(w) ?: fail() val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows, m.numCols) .also { it.decompose(m.copy()) } @@ -82,4 +81,24 @@ internal class EjmlMatrixTest { val m = randomMatrix assertSame(m, EjmlDoubleMatrix(m).origin) } + + @Test + fun inverse() = EjmlLinearSpaceDDRM { + val random = Random(1224) + val dim = 20 + + val space = Double.algebra.linearSpace + + //creating invertible matrix + val u = space.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } + val l = space.buildMatrix(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 } + val matrix = space { l dot u } + val inverted = matrix.toEjml().inverse() + + val res = matrix dot inverted + + println(StructureND.toString(res)) + + assertTrue { StructureND.contentEquals(one(dim, dim), res, 1e-3) } + } } diff --git a/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlVectorTest.kt b/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlVectorTest.kt index 9592bfa6c..c87a01436 100644 --- a/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlVectorTest.kt +++ b/kmath-ejml/src/test/kotlin/space/kscience/kmath/ejml/EjmlVectorTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.ejml diff --git a/kmath-for-real/README.md b/kmath-for-real/README.md index 6339782dd..d449b4540 100644 --- a/kmath-for-real/README.md +++ b/kmath-for-real/README.md @@ -9,7 +9,7 @@ Specialization of KMath APIs for Double numbers. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-14`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-for-real:0.3.0-dev-13' + implementation 'space.kscience:kmath-for-real:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -30,6 +30,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-for-real:0.3.0-dev-13") + implementation("space.kscience:kmath-for-real:0.3.0-dev-14") } ``` diff --git a/kmath-for-real/build.gradle.kts b/kmath-for-real/build.gradle.kts index f6d12decd..4cccaef5c 100644 --- a/kmath-for-real/build.gradle.kts +++ b/kmath-for-real/build.gradle.kts @@ -1,6 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") + id("ru.mipt.npm.gradle.native") } kotlin.sourceSets.commonMain { @@ -11,7 +12,7 @@ kotlin.sourceSets.commonMain { readme { description = """ - Extension module that should be used to achieve numpy-like behavior. + Extension module that should be used to achieve numpy-like behavior. All operations are specialized to work with `Double` numbers without declaring algebraic contexts. One can still use generic algebras though. """.trimIndent() diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt index 8023236ea..c1ee8b48f 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:OptIn(PerformancePitfall::class) @@ -12,9 +12,10 @@ import space.kscience.kmath.linear.* import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.algebra +import space.kscience.kmath.operations.asIterable import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer -import space.kscience.kmath.structures.asIterable import kotlin.math.pow /* @@ -32,18 +33,18 @@ import kotlin.math.pow public typealias RealMatrix = Matrix public fun realMatrix(rowNum: Int, colNum: Int, initializer: DoubleField.(i: Int, j: Int) -> Double): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum, initializer) + Double.algebra.linearSpace.buildMatrix(rowNum, colNum, initializer) @OptIn(UnstableKMathAPI::class) public fun realMatrix(rowNum: Int, colNum: Int): MatrixBuilder = - LinearSpace.real.matrix(rowNum, colNum) + Double.algebra.linearSpace.matrix(rowNum, colNum) public fun Array.toMatrix(): RealMatrix { - return LinearSpace.real.buildMatrix(size, this[0].size) { row, col -> this@toMatrix[row][col] } + return Double.algebra.linearSpace.buildMatrix(size, this[0].size) { row, col -> this@toMatrix[row][col] } } public fun Sequence.toMatrix(): RealMatrix = toList().let { - LinearSpace.real.buildMatrix(it.size, it[0].size) { row, col -> it[row][col] } + Double.algebra.linearSpace.buildMatrix(it.size, it[0].size) { row, col -> it[row][col] } } public fun RealMatrix.repeatStackVertical(n: Int): RealMatrix = @@ -56,37 +57,37 @@ public fun RealMatrix.repeatStackVertical(n: Int): RealMatrix = */ public operator fun RealMatrix.times(double: Double): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> + Double.algebra.linearSpace.buildMatrix(rowNum, colNum) { row, col -> get(row, col) * double } public operator fun RealMatrix.plus(double: Double): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> + Double.algebra.linearSpace.buildMatrix(rowNum, colNum) { row, col -> get(row, col) + double } public operator fun RealMatrix.minus(double: Double): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> + Double.algebra.linearSpace.buildMatrix(rowNum, colNum) { row, col -> get(row, col) - double } public operator fun RealMatrix.div(double: Double): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> + Double.algebra.linearSpace.buildMatrix(rowNum, colNum) { row, col -> get(row, col) / double } public operator fun Double.times(matrix: RealMatrix): RealMatrix = - LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> + Double.algebra.linearSpace.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> this@times * matrix[row, col] } public operator fun Double.plus(matrix: RealMatrix): RealMatrix = - LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> + Double.algebra.linearSpace.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> this@plus + matrix[row, col] } public operator fun Double.minus(matrix: RealMatrix): RealMatrix = - LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> + Double.algebra.linearSpace.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> this@minus - matrix[row, col] } @@ -101,20 +102,20 @@ public operator fun Double.minus(matrix: RealMatrix): RealMatrix = @UnstableKMathAPI public operator fun RealMatrix.times(other: RealMatrix): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this@times[row, col] * other[row, col] } + Double.algebra.linearSpace.buildMatrix(rowNum, colNum) { row, col -> this@times[row, col] * other[row, col] } public operator fun RealMatrix.plus(other: RealMatrix): RealMatrix = - LinearSpace.real.run { this@plus + other } + Double.algebra.linearSpace.run { this@plus + other } public operator fun RealMatrix.minus(other: RealMatrix): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this@minus[row, col] - other[row, col] } + Double.algebra.linearSpace.buildMatrix(rowNum, colNum) { row, col -> this@minus[row, col] - other[row, col] } /* * Operations on columns */ public inline fun RealMatrix.appendColumn(crossinline mapper: (Buffer) -> Double): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum + 1) { row, col -> + Double.algebra.linearSpace.buildMatrix(rowNum, colNum + 1) { row, col -> if (col < colNum) get(row, col) else @@ -122,7 +123,7 @@ public inline fun RealMatrix.appendColumn(crossinline mapper: (Buffer) - } public fun RealMatrix.extractColumns(columnRange: IntRange): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, columnRange.count()) { row, col -> + Double.algebra.linearSpace.buildMatrix(rowNum, columnRange.count()) { row, col -> this@extractColumns[row, columnRange.first + col] } @@ -155,14 +156,14 @@ public fun RealMatrix.max(): Double? = elements().map { (_, value) -> value }.ma public fun RealMatrix.average(): Double = elements().map { (_, value) -> value }.average() public inline fun RealMatrix.map(crossinline transform: (Double) -> Double): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { i, j -> + Double.algebra.linearSpace.buildMatrix(rowNum, colNum) { i, j -> transform(get(i, j)) } /** * Inverse a square real matrix using LUP decomposition */ -public fun RealMatrix.inverseWithLup(): RealMatrix = LinearSpace.real.inverseWithLup(this) +public fun RealMatrix.inverseWithLup(): RealMatrix = Double.algebra.linearSpace.lupSolver().inverse(this) //extended operations @@ -184,4 +185,4 @@ public fun tan(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.tan(it) } public fun ln(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.ln(it) } -public fun log10(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.log10(it) } \ No newline at end of file +public fun log10(arg: RealMatrix): RealMatrix = arg.map { kotlin.math.log10(it) } diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealVector.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealVector.kt index d3867ea89..cca1c3551 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealVector.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealVector.kt @@ -1,20 +1,18 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.real import space.kscience.kmath.linear.Point import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.Norm +import space.kscience.kmath.operations.DoubleL2Norm import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableBuffer.Companion.double import space.kscience.kmath.structures.asBuffer -import space.kscience.kmath.structures.fold import space.kscience.kmath.structures.indices import kotlin.math.pow -import kotlin.math.sqrt public typealias DoubleVector = Point @@ -22,7 +20,7 @@ public typealias DoubleVector = Point public fun DoubleVector(vararg doubles: Double): DoubleVector = doubles.asBuffer() /** - * Fill the vector of given [size] with given [value] + * Fill the vector with given [size] with given [value] */ @UnstableKMathAPI public fun Buffer.Companion.same(size: Int, value: Number): DoubleVector = double(size) { value.toDouble() } @@ -105,8 +103,4 @@ public fun DoubleVector.sum(): Double { return res } -public object VectorL2Norm : Norm { - override fun norm(arg: DoubleVector): Double = sqrt(arg.fold(0.0) { acc: Double, d: Double -> acc + d.pow(2) }) -} - -public val DoubleVector.norm: Double get() = VectorL2Norm.norm(this) \ No newline at end of file +public val DoubleVector.norm: Double get() = DoubleL2Norm.norm(this) \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt index b79e5030c..883a63f46 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt @@ -1,17 +1,18 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.real -import space.kscience.kmath.linear.LinearSpace import space.kscience.kmath.linear.Matrix +import space.kscience.kmath.linear.linearSpace +import space.kscience.kmath.operations.algebra /** * Optimized dot product for real matrices */ -public infix fun Matrix.dot(other: Matrix): Matrix = LinearSpace.real.run { +public infix fun Matrix.dot(other: Matrix): Matrix = Double.algebra.linearSpace.run { this@dot dot other } \ No newline at end of file diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/grids.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/grids.kt index c3556216d..fba999e6c 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/grids.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/grids.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.real @@ -45,8 +45,8 @@ public fun Buffer.Companion.withFixedStep(range: ClosedFloatingPointRange.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND { - val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) } - return BufferND(strides, DoubleBuffer(array)) + val array = DoubleArray(indices.linearSize) { offset -> DoubleField.transform(buffer[offset]) } + return BufferND(indices, DoubleBuffer(array)) } /** diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/DoubleMatrixTest.kt b/kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/DoubleMatrixTest.kt similarity index 92% rename from kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/DoubleMatrixTest.kt rename to kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/DoubleMatrixTest.kt index b3e129c2e..3277410c0 100644 --- a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/DoubleMatrixTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/DoubleMatrixTest.kt @@ -1,16 +1,16 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ -package kaceince.kmath.real +package space.kscience.kmath.real -import space.kscience.kmath.linear.LinearSpace +import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.linear.matrix import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.real.* +import space.kscience.kmath.operations.algebra import space.kscience.kmath.structures.contentEquals import kotlin.test.Test import kotlin.test.assertEquals @@ -32,7 +32,7 @@ internal class DoubleMatrixTest { @Test fun testSequenceToMatrix() { - val m = Sequence { + val m = Sequence { listOf( DoubleArray(10) { 10.0 }, DoubleArray(10) { 20.0 }, @@ -59,13 +59,13 @@ internal class DoubleMatrixTest { } @Test - fun testMatrixAndDouble() { + fun testMatrixAndDouble() = Double.algebra.linearSpace.run { val matrix1 = realMatrix(2, 3)( 1.0, 0.0, 3.0, 4.0, 6.0, 2.0 ) val matrix2 = (matrix1 * 2.5 + 1.0 - 2.0) / 2.0 - val expectedResult = LinearSpace.real.matrix(2, 3)( + val expectedResult = matrix(2, 3)( 0.75, -0.5, 3.25, 4.5, 7.0, 2.0 ) @@ -159,8 +159,8 @@ internal class DoubleMatrixTest { } @Test - fun testAllElementOperations() { - val matrix1 = LinearSpace.real.matrix(2, 4)( + fun testAllElementOperations() = Double.algebra.linearSpace.run { + val matrix1 = matrix(2, 4)( -1.0, 0.0, 3.0, 15.0, 4.0, -6.0, 7.0, -11.0 ) diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/DoubleVectorTest.kt b/kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/DoubleVectorTest.kt similarity index 80% rename from kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/DoubleVectorTest.kt rename to kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/DoubleVectorTest.kt index 9de54381c..771981772 100644 --- a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/DoubleVectorTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/DoubleVectorTest.kt @@ -1,14 +1,14 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ -package kaceince.kmath.real +package space.kscience.kmath.real -import space.kscience.kmath.linear.LinearSpace import space.kscience.kmath.linear.asMatrix +import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.linear.transpose -import space.kscience.kmath.real.plus +import space.kscience.kmath.operations.algebra import space.kscience.kmath.structures.DoubleBuffer import kotlin.test.Test import kotlin.test.assertEquals @@ -30,12 +30,12 @@ internal class DoubleVectorTest { } @Test - fun testDot() { + fun testDot() = Double.algebra.linearSpace.run { val vector1 = DoubleBuffer(5) { it.toDouble() } val vector2 = DoubleBuffer(5) { 5 - it.toDouble() } val matrix1 = vector1.asMatrix() val matrix2 = vector2.asMatrix().transpose() - val product = LinearSpace.real.run { matrix1 dot matrix2 } + val product = matrix1 dot matrix2 assertEquals(5.0, product[1, 0]) assertEquals(6.0, product[2, 2]) } diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt b/kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/GridTest.kt similarity index 90% rename from kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt rename to kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/GridTest.kt index 0d3b80336..ec1ed8f50 100644 --- a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/space/kscience/kmath/real/GridTest.kt @@ -1,9 +1,9 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ -package kaceince.kmath.real +package space.kscience.kmath.real import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.real.DoubleVector diff --git a/kmath-functions/README.md b/kmath-functions/README.md index 77f55528a..d0beae2c8 100644 --- a/kmath-functions/README.md +++ b/kmath-functions/README.md @@ -11,7 +11,7 @@ Functions and interpolations. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-14`. **Gradle:** ```gradle @@ -21,7 +21,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-functions:0.3.0-dev-13' + implementation 'space.kscience:kmath-functions:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -32,6 +32,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-functions:0.3.0-dev-13") + implementation("space.kscience:kmath-functions:0.3.0-dev-14") } ``` diff --git a/kmath-functions/build.gradle.kts b/kmath-functions/build.gradle.kts index f77df3833..fadbac091 100644 --- a/kmath-functions/build.gradle.kts +++ b/kmath-functions/build.gradle.kts @@ -1,6 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") + id("ru.mipt.npm.gradle.native") } description = "Functions, integration and interpolation" diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Piecewise.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Piecewise.kt index 73fa57c7b..4225a7572 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Piecewise.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Piecewise.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions @@ -14,7 +14,7 @@ import space.kscience.kmath.operations.Ring * @param T the piece key type. * @param R the sub-function type. */ -public fun interface Piecewise { +public fun interface Piecewise { /** * Returns the appropriate sub-function for given piece key. */ @@ -23,12 +23,14 @@ public fun interface Piecewise { /** * Represents piecewise-defined function where all the sub-functions are polynomials. - * @param pieces An ordered list of range-polynomial pairs. The list does not in general guarantee that there are no "holes" in it. + * + * @property pieces An ordered list of range-polynomial pairs. The list does not in general guarantee that there are no + * "holes" in it. */ public interface PiecewisePolynomial> : Piecewise> { public val pieces: Collection, Polynomial>> - public override fun findPiece(arg: T): Polynomial? + override fun findPiece(arg: T): Polynomial? } /** @@ -44,8 +46,8 @@ public fun > PiecewisePolynomial( } /** - * An optimized piecewise which uses not separate pieces, but a range separated by delimiters. - * The pices search is logarithmic + * An optimized piecewise that uses not separate pieces, but a range separated by delimiters. + * The pieces search is logarithmic. */ private class OrderedPiecewisePolynomial>( override val pieces: List, Polynomial>>, @@ -77,7 +79,7 @@ public class PiecewiseBuilder>(delimiter: T) { /** * Dynamically adds a piece to the right side (beyond maximum argument value of previous piece) * - * @param right new rightmost position. If is less then current rightmost position, an error is thrown. + * @param right new rightmost position. If is less than current rightmost position, an error is thrown. * @param piece the sub-function. */ public fun putRight(right: T, piece: Polynomial) { @@ -89,7 +91,7 @@ public class PiecewiseBuilder>(delimiter: T) { /** * Dynamically adds a piece to the left side (beyond maximum argument value of previous piece) * - * @param left the new leftmost position. If is less then current rightmost position, an error is thrown. + * @param left the new leftmost position. If is less than current rightmost position, an error is thrown. * @param piece the sub-function. */ public fun putLeft(left: T, piece: Polynomial) { @@ -112,7 +114,7 @@ public fun > PiecewisePolynomial( ): PiecewisePolynomial = PiecewiseBuilder(startingPoint).apply(builder).build() /** - * Return a value of polynomial function with given [ring] an given [arg] or null if argument is outside of piecewise + * Return a value of polynomial function with given [ring] a given [arg] or null if argument is outside piecewise * definition. */ public fun , C : Ring> PiecewisePolynomial.value(ring: C, arg: T): T? = diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt index ba77d7b25..e862c0b9d 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions @@ -17,7 +17,7 @@ import kotlin.math.pow * * @param coefficients constant is the leftmost coefficient. */ -public class Polynomial(public val coefficients: List) { +public class Polynomial(public val coefficients: List) { override fun toString(): String = "Polynomial$coefficients" } @@ -69,7 +69,7 @@ public fun Polynomial.differentiate( public fun Polynomial.integrate( algebra: A, ): Polynomial where A : Field, A : NumericAlgebra = algebra { - val integratedCoefficients = buildList(coefficients.size + 1) { + val integratedCoefficients = buildList(coefficients.size + 1) { add(zero) coefficients.forEachIndexed{ index, t -> add(t / (number(index) + one)) } } @@ -98,23 +98,23 @@ public fun > Polynomial.integrate( public class PolynomialSpace( private val ring: C, ) : Group>, ScaleOperations> where C : Ring, C : ScaleOperations { - public override val zero: Polynomial = Polynomial(emptyList()) + override val zero: Polynomial = Polynomial(emptyList()) override fun Polynomial.unaryMinus(): Polynomial = ring { Polynomial(coefficients.map { -it }) } - public override fun add(a: Polynomial, b: Polynomial): Polynomial { - val dim = max(a.coefficients.size, b.coefficients.size) + override fun add(left: Polynomial, right: Polynomial): Polynomial { + val dim = max(left.coefficients.size, right.coefficients.size) return ring { Polynomial(List(dim) { index -> - a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } + left.coefficients.getOrElse(index) { zero } + right.coefficients.getOrElse(index) { zero } }) } } - public override fun scale(a: Polynomial, value: Double): Polynomial = + override fun scale(a: Polynomial, value: Double): Polynomial = ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * value }) } /** diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/functionTypes.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/functionTypes.kt index 88b24c756..52b7e50db 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/functionTypes.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/functionTypes.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt index 283f97557..2b426d204 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration @@ -10,15 +10,17 @@ import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.indices - - /** * A simple one-pass integrator based on Gauss rule * Following integrand features are accepted: - * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GaussLegendreRuleFactory] - * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. - * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 10 points. - * [UnivariateIntegrandRanges] - Set of ranges and number of points per range. Defaults to given [IntegrationRange] and [IntegrandMaxCalls] + * + * * [GaussIntegratorRuleFactory]—a factory for computing the Gauss integration rule. By default, uses + * [GaussLegendreRuleFactory]. + * * [IntegrationRange]—the univariate range of integration. By default, uses `0..1` interval. + * * [IntegrandMaxCalls]—the maximum number of function calls during integration. For non-iterative rules, always + * uses the maximum number of points. By default, uses 10 points. + * * [UnivariateIntegrandRanges]—set of ranges and number of points per range. Defaults to given + * [IntegrationRange] and [IntegrandMaxCalls]. */ public class GaussIntegrator( public val algebra: Field, @@ -51,7 +53,7 @@ public class GaussIntegrator( } } - override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand = with(algebra) { + override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand = with(algebra) { val f = integrand.function val (points, weights) = buildRule(integrand) var res = zero @@ -71,14 +73,15 @@ public class GaussIntegrator( } /** - * Create a Gauss-Legendre integrator for this field + * Create a Gauss integrator for this field. By default, uses Legendre rule to compute points and weights. + * Custom rules could be provided by [GaussIntegratorRuleFactory] feature. * @see [GaussIntegrator] */ -public val Field.gaussIntegrator: GaussIntegrator get() = GaussIntegrator(this) +public val Field.gaussIntegrator: GaussIntegrator get() = GaussIntegrator(this) /** - * Integrate using [intervals] segments with Gauss-Legendre rule of [order] order + * Integrate using [intervals] segments with Gauss-Legendre rule of [order] order. */ @UnstableKMathAPI public fun GaussIntegrator.integrate( @@ -95,7 +98,7 @@ public fun GaussIntegrator.integrate( val ranges = UnivariateIntegrandRanges( (0 until intervals).map { i -> (range.start + rangeSize * i)..(range.start + rangeSize * (i + 1)) to order } ) - return integrate( + return process( UnivariateIntegrand( function, IntegrationRange(range), diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt index 594ca9940..94c73832b 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt @@ -1,14 +1,14 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration +import space.kscience.kmath.operations.map import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.asBuffer -import space.kscience.kmath.structures.map import kotlin.jvm.Synchronized import kotlin.math.ulp import kotlin.native.concurrent.ThreadLocal @@ -72,7 +72,7 @@ public object GaussLegendreRuleFactory : GaussIntegratorRuleFactory { } // Get previous rule. - // If it has not been computed yet it will trigger a recursive call + // If it has not been computed, yet it will trigger a recursive call // to this method. val previousPoints: Buffer = getOrBuildRule(numPoints - 1).first @@ -146,7 +146,7 @@ public object GaussLegendreRuleFactory : GaussIntegratorRuleFactory { } // If "numPoints" is odd, 0 is a root. // Note: as written, the test for oddness will work for negative - // integers too (although it is not necessary here), preventing + // integers too (although it is unnecessary here), preventing // a FindBugs warning. if (numPoints % 2 != 0) { var pmc = 1.0 diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt index f9c26e88b..ca96e80fe 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt @@ -1,24 +1,27 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration +import space.kscience.kmath.misc.Feature +import space.kscience.kmath.misc.FeatureSet +import space.kscience.kmath.misc.Featured import kotlin.reflect.KClass -public interface IntegrandFeature { +public interface IntegrandFeature : Feature { override fun toString(): String } -public interface Integrand { - public val features: Set - public fun getFeature(type: KClass): T? +public interface Integrand : Featured { + public val features: FeatureSet + override fun getFeature(type: KClass): T? = features.getFeature(type) } public inline fun Integrand.getFeature(): T? = getFeature(T::class) -public class IntegrandValue(public val value: T) : IntegrandFeature { +public class IntegrandValue(public val value: T) : IntegrandFeature { override fun toString(): String = "Value($value)" } diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrator.kt index abe6ea5ff..868ecd0fd 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration @@ -12,5 +12,5 @@ public interface Integrator { /** * Runs one integration pass and return a new [Integrand] with a new set of features. */ - public fun integrate(integrand: I): I + public fun process(integrand: I): I } diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt index 5ba411bf9..1546894f5 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt @@ -1,34 +1,26 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration import space.kscience.kmath.linear.Point -import kotlin.reflect.KClass +import space.kscience.kmath.misc.FeatureSet public class MultivariateIntegrand internal constructor( - private val featureMap: Map, IntegrandFeature>, + override val features: FeatureSet, public val function: (Point) -> T, ) : Integrand { - override val features: Set get() = featureMap.values.toSet() - - @Suppress("UNCHECKED_CAST") - override fun getFeature(type: KClass): T? = featureMap[type] as? T - - public operator fun plus(pair: Pair, F>): MultivariateIntegrand = - MultivariateIntegrand(featureMap + pair, function) - public operator fun plus(feature: F): MultivariateIntegrand = - plus(feature::class to feature) + MultivariateIntegrand(features.with(feature), function) } @Suppress("FunctionName") public fun MultivariateIntegrand( vararg features: IntegrandFeature, function: (Point) -> T, -): MultivariateIntegrand = MultivariateIntegrand(features.associateBy { it::class }, function) +): MultivariateIntegrand = MultivariateIntegrand(FeatureSet.of(*features), function) public val MultivariateIntegrand.value: T? get() = getFeature>()?.value diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt index baa9d4af8..f65cc8423 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SimpsonIntegrator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration @@ -13,9 +13,10 @@ import space.kscience.kmath.operations.sum /** * Use double pass Simpson rule integration with a fixed number of points. - * Requires [UnivariateIntegrandRanges] or [IntegrationRange] and [IntegrandMaxCalls] - * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. - * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 10 points. + * Requires [UnivariateIntegrandRanges] or [IntegrationRange] and [IntegrandMaxCalls]. + * * [IntegrationRange]—the univariate range of integration. By default, uses `0..1` interval. + * * [IntegrandMaxCalls]—the maximum number of function calls during integration. For non-iterative rules, always + * uses the maximum number of points. By default, uses 10 points. */ @UnstableKMathAPI public class SimpsonIntegrator( @@ -43,7 +44,7 @@ public class SimpsonIntegrator( return res } - override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand { + override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand { val ranges = integrand.getFeature() return if (ranges != null) { val res = algebra.sum(ranges.ranges.map { integrateRange(integrand, it.first, it.second) }) @@ -63,12 +64,12 @@ public val Field.simpsonIntegrator: SimpsonIntegrator get() = Si /** * Use double pass Simpson rule integration with a fixed number of points. - * Requires [UnivariateIntegrandRanges] or [IntegrationRange] and [IntegrandMaxCalls] - * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. - * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 10 points. + * Requires [UnivariateIntegrandRanges] or [IntegrationRange] and [IntegrandMaxCalls]. + * * [IntegrationRange]—the univariate range of integration. By default, uses `0.0..1.0` interval. + * * [IntegrandMaxCalls]—the maximum number of function calls during integration. For non-iterative rules, always uses + * the maximum number of points. By default, uses 10 points. */ public object DoubleSimpsonIntegrator : UnivariateIntegrator { - private fun integrateRange( integrand: UnivariateIntegrand, range: ClosedRange, numPoints: Int, ): Double { @@ -89,7 +90,7 @@ public object DoubleSimpsonIntegrator : UnivariateIntegrator { return res } - override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand { + override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand { val ranges = integrand.getFeature() return if (ranges != null) { val res = ranges.ranges.sumOf { integrateRange(integrand, it.first, it.second) } diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SplineIntegrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SplineIntegrator.kt index 23d7bdd8d..6abe89aad 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SplineIntegrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/SplineIntegrator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration @@ -10,19 +10,17 @@ import space.kscience.kmath.functions.integrate import space.kscience.kmath.interpolation.PolynomialInterpolator import space.kscience.kmath.interpolation.SplineInterpolator import space.kscience.kmath.interpolation.interpolatePolynomials +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.Field -import space.kscience.kmath.operations.invoke -import space.kscience.kmath.operations.sum +import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.MutableBufferFactory -import space.kscience.kmath.structures.map /** * Compute analytical indefinite integral of this [PiecewisePolynomial], keeping all intervals intact */ +@OptIn(PerformancePitfall::class) @UnstableKMathAPI public fun > PiecewisePolynomial.integrate(algebra: Field): PiecewisePolynomial = PiecewisePolynomial(pieces.map { it.first to it.second.integrate(algebra) }) @@ -45,18 +43,20 @@ public fun > PiecewisePolynomial.integrate( /** * A generic spline-interpolation-based analytic integration - * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. - * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 10 points. + * * [IntegrationRange]—the univariate range of integration. By default, uses `0..1` interval. + * * [IntegrandMaxCalls]—the maximum number of function calls during integration. For non-iterative rules, always uses + * the maximum number of points. By default, uses 10 points. */ @UnstableKMathAPI public class SplineIntegrator>( public val algebra: Field, public val bufferFactory: MutableBufferFactory, ) : UnivariateIntegrator { - override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand = algebra { + override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand = algebra { val range = integrand.getFeature()?.range ?: 0.0..1.0 val interpolator: PolynomialInterpolator = SplineInterpolator(algebra, bufferFactory) + val nodes: Buffer = integrand.getFeature()?.nodes ?: run { val numPoints = integrand.getFeature()?.maxCalls ?: 100 val step = (range.endInclusive - range.start) / (numPoints - 1) @@ -75,15 +75,16 @@ public class SplineIntegrator>( /** * A simplified double-based spline-interpolation-based analytic integration - * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. - * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 10 points. + * * [IntegrationRange]—the univariate range of integration. By default, uses `0.0..1.0` interval. + * * [IntegrandMaxCalls]—the maximum number of function calls during integration. For non-iterative rules, always + * uses the maximum number of points. By default, uses 10 points. */ @UnstableKMathAPI public object DoubleSplineIntegrator : UnivariateIntegrator { - override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand { + override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand { val range = integrand.getFeature()?.range ?: 0.0..1.0 - val interpolator: PolynomialInterpolator = SplineInterpolator(DoubleField, ::DoubleBuffer) + val nodes: Buffer = integrand.getFeature()?.nodes ?: run { val numPoints = integrand.getFeature()?.maxCalls ?: 100 val step = (range.endInclusive - range.start) / (numPoints - 1) diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt index e265f54e8..bd2a20594 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt @@ -1,37 +1,28 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration +import space.kscience.kmath.misc.FeatureSet import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.DoubleBuffer -import kotlin.reflect.KClass public class UnivariateIntegrand internal constructor( - private val featureMap: Map, IntegrandFeature>, + override val features: FeatureSet, public val function: (Double) -> T, ) : Integrand { - - override val features: Set get() = featureMap.values.toSet() - - @Suppress("UNCHECKED_CAST") - override fun getFeature(type: KClass): T? = featureMap[type] as? T - - public operator fun plus(pair: Pair, F>): UnivariateIntegrand = - UnivariateIntegrand(featureMap + pair, function) - public operator fun plus(feature: F): UnivariateIntegrand = - plus(feature::class to feature) + UnivariateIntegrand(features.with(feature), function) } @Suppress("FunctionName") public fun UnivariateIntegrand( function: (Double) -> T, vararg features: IntegrandFeature, -): UnivariateIntegrand = UnivariateIntegrand(features.associateBy { it::class }, function) +): UnivariateIntegrand = UnivariateIntegrand(FeatureSet.of(*features), function) public typealias UnivariateIntegrator = Integrator> @@ -40,8 +31,8 @@ public class IntegrationRange(public val range: ClosedRange) : Integrand } /** - * Set of univariate integration ranges. First components correspond to ranges themselves, second components to number of - * integration nodes per range + * Set of univariate integration ranges. First components correspond to the ranges themselves, second components to + * number of integration nodes per range. */ public class UnivariateIntegrandRanges(public val ranges: List, Int>>) : IntegrandFeature { public constructor(vararg pairs: Pair, Int>) : this(pairs.toList()) @@ -79,7 +70,7 @@ public val UnivariateIntegrand.value: T get() = valueOrNull ?: erro public fun UnivariateIntegrator.integrate( vararg features: IntegrandFeature, function: (Double) -> T, -): UnivariateIntegrand = integrate(UnivariateIntegrand(function, *features)) +): UnivariateIntegrand = process(UnivariateIntegrand(function, *features)) /** * A shortcut method to integrate a [function] in [range] with additional [features]. @@ -90,10 +81,10 @@ public fun UnivariateIntegrator.integrate( range: ClosedRange, vararg features: IntegrandFeature, function: (Double) -> T, -): UnivariateIntegrand = integrate(UnivariateIntegrand(function, IntegrationRange(range), *features)) +): UnivariateIntegrand = process(UnivariateIntegrand(function, IntegrationRange(range), *features)) /** - * A shortcut method to integrate a [function] in [range] with additional [features]. + * A shortcut method to integrate a [function] in [range] with additional features. * The [function] is placed in the end position to allow passing a lambda. */ @UnstableKMathAPI @@ -107,5 +98,5 @@ public fun UnivariateIntegrator.integrate( featureBuilder() add(IntegrationRange(range)) } - return integrate(UnivariateIntegrand(function, *features.toTypedArray())) + return process(UnivariateIntegrand(function, *features.toTypedArray())) } diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/Interpolator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/Interpolator.kt index c9ec0d527..5f89a9619 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/Interpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/Interpolator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ @file:OptIn(UnstableKMathAPI::class) @@ -18,7 +18,7 @@ import space.kscience.kmath.structures.asBuffer /** * And interpolator for data with x column type [X], y column type [Y]. */ -public fun interface Interpolator { +public fun interface Interpolator { public fun interpolate(points: XYColumnarData): (X) -> Y } @@ -42,20 +42,20 @@ public fun > PolynomialInterpolator.interpolatePolynomials( x: Buffer, y: Buffer, ): PiecewisePolynomial { - val pointSet = XYColumnarData(x, y) + val pointSet = XYColumnarData.of(x, y) return interpolatePolynomials(pointSet) } public fun > PolynomialInterpolator.interpolatePolynomials( data: Map, ): PiecewisePolynomial { - val pointSet = XYColumnarData(data.keys.toList().asBuffer(), data.values.toList().asBuffer()) + val pointSet = XYColumnarData.of(data.keys.toList().asBuffer(), data.values.toList().asBuffer()) return interpolatePolynomials(pointSet) } public fun > PolynomialInterpolator.interpolatePolynomials( data: List>, ): PiecewisePolynomial { - val pointSet = XYColumnarData(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer()) + val pointSet = XYColumnarData.of(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer()) return interpolatePolynomials(pointSet) } diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt index 24c049647..eff9cd97d 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/LinearInterpolator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.interpolation @@ -21,9 +21,9 @@ internal fun > insureSorted(points: XYColumnarData<*, T, *>) { /** * Reference JVM implementation: https://github.com/apache/commons-math/blob/master/src/main/java/org/apache/commons/math4/analysis/interpolation/LinearInterpolator.java */ -public class LinearInterpolator>(public override val algebra: Field) : PolynomialInterpolator { +public class LinearInterpolator>(override val algebra: Field) : PolynomialInterpolator { @OptIn(UnstableKMathAPI::class) - public override fun interpolatePolynomials(points: XYColumnarData): PiecewisePolynomial = algebra { + override fun interpolatePolynomials(points: XYColumnarData): PiecewisePolynomial = algebra { require(points.size > 0) { "Point array should not be empty" } insureSorted(points) diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt index bf291c315..ac9708d01 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/interpolation/SplineInterpolator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.interpolation @@ -23,13 +23,13 @@ import space.kscience.kmath.structures.MutableBufferFactory * https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java */ public class SplineInterpolator>( - public override val algebra: Field, + override val algebra: Field, public val bufferFactory: MutableBufferFactory, ) : PolynomialInterpolator { //TODO possibly optimize zeroed buffers @OptIn(UnstableKMathAPI::class) - public override fun interpolatePolynomials(points: XYColumnarData): PiecewisePolynomial = algebra { + override fun interpolatePolynomials(points: XYColumnarData): PiecewisePolynomial = algebra { require(points.size >= 3) { "Can't use spline interpolator with less than 3 points" } insureSorted(points) // Number of intervals. The number of data points is n + 1. diff --git a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/functions/PolynomialTest.kt b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/functions/PolynomialTest.kt index 05c16d17e..21e5473a0 100644 --- a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/functions/PolynomialTest.kt +++ b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/functions/PolynomialTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.functions diff --git a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt index 9f48a15ea..533389a6e 100644 --- a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt +++ b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration diff --git a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SimpsonIntegralTest.kt b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SimpsonIntegralTest.kt index 9f2d71554..eaf7abbfd 100644 --- a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SimpsonIntegralTest.kt +++ b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SimpsonIntegralTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration diff --git a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SplineIntegralTest.kt b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SplineIntegralTest.kt index afeba0be4..4dffb276f 100644 --- a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SplineIntegralTest.kt +++ b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/SplineIntegralTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.integration diff --git a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/LinearInterpolatorTest.kt b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/LinearInterpolatorTest.kt index bec678bae..c3388c265 100644 --- a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/LinearInterpolatorTest.kt +++ b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/LinearInterpolatorTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.interpolation diff --git a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/SplineInterpolatorTest.kt b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/SplineInterpolatorTest.kt index 3adaab2d1..42f41ab80 100644 --- a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/SplineInterpolatorTest.kt +++ b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/interpolation/SplineInterpolatorTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.interpolation diff --git a/kmath-geometry/build.gradle.kts b/kmath-geometry/build.gradle.kts index 9b6e593b2..7eb814683 100644 --- a/kmath-geometry/build.gradle.kts +++ b/kmath-geometry/build.gradle.kts @@ -1,6 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") + id("ru.mipt.npm.gradle.native") } kotlin.sourceSets.commonMain { diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt index 2a4837ee0..5e3cbff83 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean2DSpace.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.geometry @@ -12,22 +12,22 @@ import space.kscience.kmath.operations.invoke import kotlin.math.sqrt @OptIn(UnstableKMathAPI::class) -public interface Vector2D : Point, Vector{ +public interface Vector2D : Point, Vector { public val x: Double public val y: Double - public override val size: Int get() = 2 + override val size: Int get() = 2 - public override operator fun get(index: Int): Double = when (index) { - 1 -> x - 2 -> y + override operator fun get(index: Int): Double = when (index) { + 0 -> x + 1 -> y else -> error("Accessing outside of point bounds") } - public override operator fun iterator(): Iterator = listOf(x, y).iterator() + override operator fun iterator(): Iterator = listOf(x, y).iterator() } public val Vector2D.r: Double - get() = Euclidean2DSpace { sqrt(norm()) } + get() = Euclidean2DSpace { norm() } @Suppress("FunctionName") public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y) @@ -41,13 +41,13 @@ private data class Vector2DImpl( * 2D Euclidean space */ public object Euclidean2DSpace : GeometrySpace, ScaleOperations { - public override val zero: Vector2D by lazy { Vector2D(0.0, 0.0) } + override val zero: Vector2D by lazy { Vector2D(0.0, 0.0) } public fun Vector2D.norm(): Double = sqrt(x * x + y * y) override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y) - public override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm() - public override fun add(a: Vector2D, b: Vector2D): Vector2D = Vector2D(a.x + b.x, a.y + b.y) - public override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value) - public override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y + override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm() + override fun add(left: Vector2D, right: Vector2D): Vector2D = Vector2D(left.x + right.x, left.y + right.y) + override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value) + override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y } diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean3DSpace.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean3DSpace.kt index 37e7d2cb2..96f307ed6 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean3DSpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Euclidean3DSpace.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.geometry @@ -16,22 +16,22 @@ public interface Vector3D : Point, Vector { public val x: Double public val y: Double public val z: Double - public override val size: Int get() = 3 + override val size: Int get() = 3 - public override operator fun get(index: Int): Double = when (index) { - 1 -> x - 2 -> y - 3 -> z + override operator fun get(index: Int): Double = when (index) { + 0 -> x + 1 -> y + 2 -> z else -> error("Accessing outside of point bounds") } - public override operator fun iterator(): Iterator = listOf(x, y, z).iterator() + override operator fun iterator(): Iterator = listOf(x, y, z).iterator() } @Suppress("FunctionName") public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z) -public val Vector3D.r: Double get() = Euclidean3DSpace { sqrt(norm()) } +public val Vector3D.r: Double get() = Euclidean3DSpace { norm() } private data class Vector3DImpl( override val x: Double, @@ -40,19 +40,19 @@ private data class Vector3DImpl( ) : Vector3D public object Euclidean3DSpace : GeometrySpace, ScaleOperations { - public override val zero: Vector3D by lazy { Vector3D(0.0, 0.0, 0.0) } + override val zero: Vector3D by lazy { Vector3D(0.0, 0.0, 0.0) } public fun Vector3D.norm(): Double = sqrt(x * x + y * y + z * z) override fun Vector3D.unaryMinus(): Vector3D = Vector3D(-x, -y, -z) - public override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm() + override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm() - public override fun add(a: Vector3D, b: Vector3D): Vector3D = - Vector3D(a.x + b.x, a.y + b.y, a.z + b.z) + override fun add(left: Vector3D, right: Vector3D): Vector3D = + Vector3D(left.x + right.x, left.y + right.y, left.z + right.z) - public override fun scale(a: Vector3D, value: Double): Vector3D = + override fun scale(a: Vector3D, value: Double): Vector3D = Vector3D(a.x * value, a.y * value, a.z * value) - public override fun Vector3D.dot(other: Vector3D): Double = + override fun Vector3D.dot(other: Vector3D): Double = x * other.x + y * other.y + z * other.z } diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/GeometrySpace.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/GeometrySpace.kt index d4245c744..3d3f8b653 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/GeometrySpace.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/GeometrySpace.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.geometry diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Line.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Line.kt index 5a6d23709..d9dc57ec2 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Line.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Line.kt @@ -1,11 +1,11 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.geometry -public data class Line(val base: V, val direction: V) +public data class Line(val base: V, val direction: V) public typealias Line2D = Line public typealias Line3D = Line diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Projections.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Projections.kt new file mode 100644 index 000000000..205bc17e7 --- /dev/null +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/Projections.kt @@ -0,0 +1,20 @@ +package space.kscience.kmath.geometry + +/** + * Project vector onto a line. + * @param vector to project + * @param line line to which vector should be projected + */ +public fun GeometrySpace.projectToLine(vector: V, line: Line): V = with(line) { + base + (direction dot (vector - base)) / (direction dot direction) * direction +} + +/** + * Project vector onto a hyperplane, which is defined by a normal and base. + * In 2D case it is the projection to a line, in 3d case it is the one to a plane. + * @param vector to project + * @param normal normal (perpendicular) vector to a hyper-plane to which vector should be projected + * @param base point belonging to a hyper-plane to which vector should be projected + */ +public fun GeometrySpace.projectAlong(vector: V, normal: V, base: V): V = + vector + normal * ((base - vector) dot normal) / (normal dot normal) diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/ReferenceFrame.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/ReferenceFrame.kt index a7a28b596..7bb95c009 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/ReferenceFrame.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/ReferenceFrame.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.geometry diff --git a/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Euclidean2DSpaceTest.kt b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Euclidean2DSpaceTest.kt new file mode 100644 index 000000000..5913b2fa9 --- /dev/null +++ b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Euclidean2DSpaceTest.kt @@ -0,0 +1,62 @@ +package space.kscience.kmath.geometry + +import kotlin.math.sqrt +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class Euclidean2DSpaceTest { + @Test + fun zero() { + assertVectorEquals(Vector2D(0.0, 0.0), Euclidean2DSpace.zero) + } + + @Test + fun norm() { + with(Euclidean2DSpace) { + assertEquals(0.0, zero.norm()) + assertEquals(1.0, Vector2D(1.0, 0.0).norm()) + assertEquals(sqrt(2.0), Vector2D(1.0, 1.0).norm()) + assertEquals(sqrt(5.002001), Vector2D(-2.0, 1.001).norm()) + } + } + + @Test + fun dotProduct() { + with(Euclidean2DSpace) { + assertEquals(0.0, zero dot zero) + assertEquals(0.0, zero dot Vector2D(1.0, 0.0)) + assertEquals(0.0, Vector2D(-2.0, 0.001) dot zero) + assertEquals(0.0, Vector2D(1.0, 0.0) dot Vector2D(0.0, 1.0)) + + assertEquals(1.0, Vector2D(1.0, 0.0) dot Vector2D(1.0, 0.0)) + assertEquals(-2.0, Vector2D(0.0, 1.0) dot Vector2D(1.0, -2.0)) + assertEquals(2.0, Vector2D(1.0, 1.0) dot Vector2D(1.0, 1.0)) + assertEquals(4.001001, Vector2D(-2.0, 1.001) dot Vector2D(-2.0, 0.001)) + + assertEquals(-4.998, Vector2D(1.0, 2.0) dot Vector2D(-5.0, 0.001)) + } + } + + @Test + fun add() { + with(Euclidean2DSpace) { + assertVectorEquals( + Vector2D(-2.0, 0.001), + Vector2D(-2.0, 0.001) + zero + ) + assertVectorEquals( + Vector2D(-3.0, 3.001), + Vector2D(2.0, 3.0) + Vector2D(-5.0, 0.001) + ) + } + } + + @Test + fun multiply() { + with(Euclidean2DSpace) { + assertVectorEquals(Vector2D(-4.0, 0.0), Vector2D(-2.0, 0.0) * 2) + assertVectorEquals(Vector2D(4.0, 0.0), Vector2D(-2.0, 0.0) * -2) + assertVectorEquals(Vector2D(300.0, 0.0003), Vector2D(100.0, 0.0001) * 3) + } + } +} diff --git a/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Euclidean3DSpaceTest.kt b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Euclidean3DSpaceTest.kt new file mode 100644 index 000000000..2c74cbd27 --- /dev/null +++ b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Euclidean3DSpaceTest.kt @@ -0,0 +1,74 @@ +package space.kscience.kmath.geometry + +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class Euclidean3DSpaceTest { + @Test + fun zero() { + assertVectorEquals(Vector3D(0.0, 0.0, 0.0), Euclidean3DSpace.zero) + } + + @Test + fun distance() { + with(Euclidean3DSpace) { + assertEquals(0.0, zero.distanceTo(zero)) + assertEquals(1.0, zero.distanceTo(Vector3D(1.0, 0.0, 0.0))) + assertEquals(kotlin.math.sqrt(5.000001), Vector3D(1.0, -2.0, 0.001).distanceTo(zero)) + assertEquals(0.0, Vector3D(1.0, -2.0, 0.001).distanceTo(Vector3D(1.0, -2.0, 0.001))) + assertEquals(0.0, Vector3D(1.0, 0.0, 0.0).distanceTo(Vector3D(1.0, 0.0, 0.0))) + assertEquals(kotlin.math.sqrt(2.0), Vector3D(1.0, 0.0, 0.0).distanceTo(Vector3D(1.0, 1.0, 1.0))) + assertEquals(3.1622778182822584, Vector3D(0.0, 1.0, 0.0).distanceTo(Vector3D(1.0, -2.0, 0.001))) + assertEquals(0.0, Vector3D(1.0, -2.0, 0.001).distanceTo(Vector3D(1.0, -2.0, 0.001))) + assertEquals(9.695050335093676, Vector3D(1.0, 2.0, 3.0).distanceTo(Vector3D(7.0, -5.0, 0.001))) + } + } + + @Test + fun norm() { + with(Euclidean3DSpace) { + assertEquals(0.0, zero.norm()) + assertEquals(1.0, Vector3D(1.0, 0.0, 0.0).norm()) + assertEquals(kotlin.math.sqrt(3.0), Vector3D(1.0, 1.0, 1.0).norm()) + assertEquals(kotlin.math.sqrt(5.000001), Vector3D(1.0, -2.0, 0.001).norm()) + } + } + + @Test + fun dotProduct() { + with(Euclidean3DSpace) { + assertEquals(0.0, zero dot zero) + assertEquals(0.0, zero dot Vector3D(1.0, 0.0, 0.0)) + assertEquals(0.0, Vector3D(1.0, -2.0, 0.001) dot zero) + + assertEquals(1.0, Vector3D(1.0, 0.0, 0.0) dot Vector3D(1.0, 0.0, 0.0)) + assertEquals(1.0, Vector3D(1.0, 0.0, 0.0) dot Vector3D(1.0, 1.0, 1.0)) + assertEquals(-2.0, Vector3D(0.0, 1.0, 0.0) dot Vector3D(1.0, -2.0, 0.001)) + assertEquals(3.0, Vector3D(1.0, 1.0, 1.0) dot Vector3D(1.0, 1.0, 1.0)) + assertEquals(5.000001, Vector3D(1.0, -2.0, 0.001) dot Vector3D(1.0, -2.0, 0.001)) + + assertEquals(-2.997, Vector3D(1.0, 2.0, 3.0) dot Vector3D(7.0, -5.0, 0.001)) + } + } + + @Test + fun add() { + with(Euclidean3DSpace) { + assertVectorEquals( + Vector3D(1.0, -2.0, 0.001), + Vector3D(1.0, -2.0, 0.001) + zero + ) + assertVectorEquals( + Vector3D(8.0, -3.0, 3.001), + Vector3D(1.0, 2.0, 3.0) + Vector3D(7.0, -5.0, 0.001) + ) + } + } + + @Test + fun multiply() { + with(Euclidean3DSpace) { + assertVectorEquals(Vector3D(2.0, -4.0, 0.0), Vector3D(1.0, -2.0, 0.0) * 2) + } + } +} diff --git a/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/ProjectionAlongTest.kt b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/ProjectionAlongTest.kt new file mode 100644 index 000000000..55fc39aad --- /dev/null +++ b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/ProjectionAlongTest.kt @@ -0,0 +1,61 @@ +package space.kscience.kmath.geometry + +import kotlin.test.Test +import kotlin.test.assertTrue + +internal class ProjectionAlongTest { + @Test + fun projectionIntoYEqualsX() { + with(Euclidean2DSpace) { + val normal = Vector2D(-2.0, 2.0) + val base = Vector2D(2.3, 2.3) + + assertVectorEquals(zero, projectAlong(zero, normal, base)) + + grid(-10.0..10.0, -10.0..10.0, 0.15).forEach { (x, y) -> + val d = (y - x) / 2.0 + assertVectorEquals(Vector2D(x + d, y - d), projectAlong(Vector2D(x, y), normal, base)) + } + } + } + + @Test + fun projectionOntoLine() { + with(Euclidean2DSpace) { + val a = 5.0 + val b = -3.0 + val c = -15.0 + val normal = Vector2D(-5.0, 3.0) + val base = Vector2D(3.0, 0.0) + + grid(-10.0..10.0, -10.0..10.0, 0.15).forEach { (x, y) -> + val xProj = (b * (b * x - a * y) - a * c) / (a * a + b * b) + val yProj = (a * (-b * x + a * y) - b * c) / (a * a + b * b) + assertVectorEquals(Vector2D(xProj, yProj), projectAlong(Vector2D(x, y), normal, base)) + } + } + } + + @Test + fun projectOntoPlane() { + val normal = Vector3D(1.0, 3.5, 0.07) + val base = Vector3D(2.0, -0.0037, 11.1111) + + with(Euclidean3DSpace) { + val testDomain = (-10.0..10.0).generateList(0.43) + for (x in testDomain) { + for (y in testDomain) { + for (z in testDomain) { + val v = Vector3D(x, y, z) + val result = projectAlong(v, normal, base) + + // assert that result is on plane + assertTrue(isOrthogonal(result - base, normal)) + // assert that PV vector is collinear to normal vector + assertTrue(isCollinear(v - result, normal)) + } + } + } + } + } +} diff --git a/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/ProjectionOntoLineTest.kt b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/ProjectionOntoLineTest.kt new file mode 100644 index 000000000..ab6ef3628 --- /dev/null +++ b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/ProjectionOntoLineTest.kt @@ -0,0 +1,83 @@ +package space.kscience.kmath.geometry + +import kotlin.test.Test +import kotlin.test.assertTrue + +internal class ProjectionOntoLineTest { + @Test + fun projectionIntoOx() { + with(Euclidean2DSpace) { + val ox = Line(zero, Vector2D(1.0, 0.0)) + + grid(-10.0..10.0, -10.0..10.0, 0.15).forEach { (x, y) -> + assertVectorEquals(Vector2D(x, 0.0), projectToLine(Vector2D(x, y), ox)) + } + } + } + + @Test + fun projectionIntoOy() { + with(Euclidean2DSpace) { + val line = Line(zero, Vector2D(0.0, 1.0)) + + grid(-10.0..10.0, -10.0..10.0, 0.15).forEach { (x, y) -> + assertVectorEquals(Vector2D(0.0, y), projectToLine(Vector2D(x, y), line)) + } + } + } + + @Test + fun projectionIntoYEqualsX() { + with(Euclidean2DSpace) { + val line = Line(zero, Vector2D(1.0, 1.0)) + + assertVectorEquals(zero, projectToLine(zero, line)) + + grid(-10.0..10.0, -10.0..10.0, 0.15).forEach { (x, y) -> + val d = (y - x) / 2.0 + assertVectorEquals(Vector2D(x + d, y - d), projectToLine(Vector2D(x, y), line)) + } + } + } + + @Test + fun projectionOntoLine2d() { + with(Euclidean2DSpace) { + val a = 5.0 + val b = -3.0 + val c = -15.0 + val line = Line(Vector2D(3.0, 0.0), Vector2D(3.0, 5.0)) + + grid(-10.0..10.0, -10.0..10.0, 0.15).forEach { (x, y) -> + val xProj = (b * (b * x - a * y) - a * c) / (a * a + b * b) + val yProj = (a * (-b * x + a * y) - b * c) / (a * a + b * b) + assertVectorEquals(Vector2D(xProj, yProj), projectToLine(Vector2D(x, y), line)) + } + } + } + + @Test + fun projectionOntoLine3d() { + val line = Line3D( + base = Vector3D(1.0, 3.5, 0.07), + direction = Vector3D(2.0, -0.0037, 11.1111) + ) + + with(Euclidean3DSpace) { + val testDomain = (-10.0..10.0).generateList(0.43) + for (x in testDomain) { + for (y in testDomain) { + for (z in testDomain) { + val v = Vector3D(x, y, z) + val result = projectToLine(v, line) + + // assert that result is on line + assertTrue(isCollinear(result - line.base, line.direction)) + // assert that PV vector is orthogonal to direction vector + assertTrue(isOrthogonal(v - result, line.direction)) + } + } + } + } + } +} diff --git a/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Vector2DTest.kt b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Vector2DTest.kt new file mode 100644 index 000000000..89ee23354 --- /dev/null +++ b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Vector2DTest.kt @@ -0,0 +1,35 @@ +package space.kscience.kmath.geometry + +import space.kscience.kmath.operations.toList +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class Vector2DTest { + private val vector = Vector2D(1.0, -7.999) + + @Test + fun size() { + assertEquals(2, vector.size) + } + + @Test + fun get() { + assertEquals(1.0, vector[0]) + assertEquals(-7.999, vector[1]) + } + + @Test + fun iterator() { + assertEquals(listOf(1.0, -7.999), vector.toList()) + } + + @Test + fun x() { + assertEquals(1.0, vector.x) + } + + @Test + fun y() { + assertEquals(-7.999, vector.y) + } +} diff --git a/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Vector3DTest.kt b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Vector3DTest.kt new file mode 100644 index 000000000..70f8f4ebd --- /dev/null +++ b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/Vector3DTest.kt @@ -0,0 +1,41 @@ +package space.kscience.kmath.geometry + +import space.kscience.kmath.operations.toList +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class Vector3DTest { + private val vector = Vector3D(1.0, -7.999, 0.001) + + @Test + fun size() { + assertEquals(3, vector.size) + } + + @Test + fun get() { + assertEquals(1.0, vector[0]) + assertEquals(-7.999, vector[1]) + assertEquals(0.001, vector[2]) + } + + @Test + fun iterator() { + assertEquals(listOf(1.0, -7.999, 0.001), vector.toList()) + } + + @Test + fun x() { + assertEquals(1.0, vector.x) + } + + @Test + fun y() { + assertEquals(-7.999, vector.y) + } + + @Test + fun z() { + assertEquals(0.001, vector.z) + } +} diff --git a/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/testUtils.kt b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/testUtils.kt new file mode 100644 index 000000000..1277c0130 --- /dev/null +++ b/kmath-geometry/src/commonTest/kotlin/space/kscience/kmath/geometry/testUtils.kt @@ -0,0 +1,41 @@ +package space.kscience.kmath.geometry + +import kotlin.math.abs +import kotlin.test.assertEquals + +fun ClosedRange.generateList(step: Double): List = generateSequence(start) { previous -> + if (previous == Double.POSITIVE_INFINITY) return@generateSequence null + val next = previous + step + if (next > endInclusive) null else next +}.toList() + +fun grid( + xRange: ClosedRange, + yRange: ClosedRange, + step: Double +): List> { + val xs = xRange.generateList(step) + val ys = yRange.generateList(step) + + return xs.flatMap { x -> ys.map { y -> x to y } } +} + +fun assertVectorEquals(expected: Vector2D, actual: Vector2D, absoluteTolerance: Double = 1e-6) { + assertEquals(expected.x, actual.x, absoluteTolerance) + assertEquals(expected.y, actual.y, absoluteTolerance) +} + +fun assertVectorEquals(expected: Vector3D, actual: Vector3D, absoluteTolerance: Double = 1e-6) { + assertEquals(expected.x, actual.x, absoluteTolerance) + assertEquals(expected.y, actual.y, absoluteTolerance) + assertEquals(expected.z, actual.z, absoluteTolerance) +} + +fun GeometrySpace.isCollinear(a: V, b: V, absoluteTolerance: Double = 1e-6): Boolean { + val aDist = a.distanceTo(zero) + val bDist = b.distanceTo(zero) + return abs(aDist) < absoluteTolerance || abs(bDist) < absoluteTolerance || abs(abs((a dot b) / (aDist * bDist)) - 1) < absoluteTolerance +} + +fun GeometrySpace.isOrthogonal(a: V, b: V, absoluteTolerance: Double = 1e-6): Boolean = + abs(a dot b) < absoluteTolerance diff --git a/kmath-histograms/build.gradle.kts b/kmath-histograms/build.gradle.kts index 2167726c0..7e511faa0 100644 --- a/kmath-histograms/build.gradle.kts +++ b/kmath-histograms/build.gradle.kts @@ -1,6 +1,7 @@ plugins { kotlin("multiplatform") id("ru.mipt.npm.gradle.common") + id("ru.mipt.npm.gradle.native") } kscience { diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt index 3e5d93768..4f5a1ceba 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Counter.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/DoubleHistogramSpace.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/DoubleHistogramSpace.kt index e792ef767..c452edc9c 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/DoubleHistogramSpace.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/DoubleHistogramSpace.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram @@ -9,6 +9,7 @@ import space.kscience.kmath.domains.Domain import space.kscience.kmath.domains.HyperSquareDomain import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.* +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.structures.* import kotlin.math.floor @@ -27,10 +28,9 @@ public class DoubleHistogramSpace( public val dimension: Int get() = lower.size - private val shape = IntArray(binNums.size) { binNums[it] + 2 } - override val histogramValueSpace: DoubleFieldND = AlgebraND.real(*shape) + override val shape: IntArray = IntArray(binNums.size) { binNums[it] + 2 } + override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape) - override val strides: Strides get() = histogramValueSpace.strides private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } /** @@ -51,7 +51,7 @@ public class DoubleHistogramSpace( val lowerBoundary = index.mapIndexed { axis, i -> when (i) { 0 -> Double.NEGATIVE_INFINITY - strides.shape[axis] - 1 -> upper[axis] + shape[axis] - 1 -> upper[axis] else -> lower[axis] + (i.toDouble()) * binSize[axis] } }.asBuffer() @@ -59,7 +59,7 @@ public class DoubleHistogramSpace( val upperBoundary = index.mapIndexed { axis, i -> when (i) { 0 -> lower[axis] - strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY + shape[axis] - 1 -> Double.POSITIVE_INFINITY else -> lower[axis] + (i.toDouble() + 1) * binSize[axis] } }.asBuffer() @@ -74,7 +74,7 @@ public class DoubleHistogramSpace( } override fun produce(builder: HistogramBuilder.() -> Unit): IndexedHistogram { - val ndCounter = StructureND.auto(strides) { Counter.real() } + val ndCounter = StructureND.auto(shape) { Counter.real() } val hBuilder = HistogramBuilder { point, value -> val index = getIndex(point) ndCounter[index].add(value.toDouble()) diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Histogram.kt index fcb5e96dc..946aa814b 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/Histogram.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram @@ -11,16 +11,16 @@ import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.asBuffer /** - * The binned data element. Could be a histogram bin with a number of counts or an artificial construct + * The binned data element. Could be a histogram bin with a number of counts or an artificial construct. */ -public interface Bin : Domain { +public interface Bin : Domain { /** * The value of this bin. */ public val value: Number } -public interface Histogram> { +public interface Histogram> { /** * Find existing bin, corresponding to given coordinates */ @@ -34,7 +34,7 @@ public interface Histogram> { public val bins: Iterable } -public fun interface HistogramBuilder { +public fun interface HistogramBuilder { /** * Increment appropriate bin diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt index e5f6830c5..f36f45389 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/IndexedHistogramSpace.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram @@ -8,8 +8,9 @@ package space.kscience.kmath.histogram import space.kscience.kmath.domains.Domain import space.kscience.kmath.linear.Point import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.nd.FieldND -import space.kscience.kmath.nd.Strides +import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND import space.kscience.kmath.operations.Group import space.kscience.kmath.operations.ScaleOperations @@ -18,9 +19,9 @@ import space.kscience.kmath.operations.invoke /** * A simple histogram bin based on domain */ -public data class DomainBin>( +public data class DomainBin>( public val domain: Domain, - public override val value: Number, + override val value: Number, ) : Bin, Domain by domain @OptIn(UnstableKMathAPI::class) @@ -34,10 +35,10 @@ public class IndexedHistogram, V : Any>( return context.produceBin(index, values[index]) } - override val dimension: Int get() = context.strides.shape.size + override val dimension: Int get() = context.shape.size override val bins: Iterable> - get() = context.strides.indices().map { + get() = DefaultStrides(context.shape).asSequence().map { context.produceBin(it, values[it]) }.asIterable() @@ -49,7 +50,7 @@ public class IndexedHistogram, V : Any>( public interface IndexedHistogramSpace, V : Any> : Group>, ScaleOperations> { //public val valueSpace: Space - public val strides: Strides + public val shape: Shape public val histogramValueSpace: FieldND //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape), /** @@ -66,10 +67,10 @@ public interface IndexedHistogramSpace, V : Any> public fun produce(builder: HistogramBuilder.() -> Unit): IndexedHistogram - override fun add(a: IndexedHistogram, b: IndexedHistogram): IndexedHistogram { - require(a.context == this) { "Can't operate on a histogram produced by external space" } - require(b.context == this) { "Can't operate on a histogram produced by external space" } - return IndexedHistogram(this, histogramValueSpace { a.values + b.values }) + override fun add(left: IndexedHistogram, right: IndexedHistogram): IndexedHistogram { + require(left.context == this) { "Can't operate on a histogram produced by external space" } + require(right.context == this) { "Can't operate on a histogram produced by external space" } + return IndexedHistogram(this, histogramValueSpace { left.values + right.values }) } override fun scale(a: IndexedHistogram, value: Double): IndexedHistogram { diff --git a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt index 51f9dabc5..e07488741 100644 --- a/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/space/kscience/kmath/histogram/MultivariateHistogramTest.kt @@ -1,10 +1,11 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram +import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.operations.invoke import space.kscience.kmath.real.DoubleVector import kotlin.random.Random @@ -69,7 +70,7 @@ internal class MultivariateHistogramTest { } val res = histogram1 - histogram2 assertTrue { - strides.indices().all { index -> + DefaultStrides(shape).asSequence().all { index -> res.values[index] <= histogram1.values[index] } } diff --git a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt index 8d05df68a..cc54d7e1a 100644 --- a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt +++ b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/TreeHistogramSpace.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram @@ -36,9 +36,10 @@ public class TreeHistogram( } @OptIn(UnstableKMathAPI::class) -private class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder { +@PublishedApi +internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder { - private class BinCounter(val domain: UnivariateDomain, val counter: Counter = Counter.real()) : + internal class BinCounter(val domain: UnivariateDomain, val counter: Counter = Counter.real()) : ClosedFloatingPointRange by domain.range private val bins: TreeMap = TreeMap() @@ -80,27 +81,27 @@ private class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) */ @UnstableKMathAPI public class TreeHistogramSpace( - private val binFactory: (Double) -> UnivariateDomain, + @PublishedApi internal val binFactory: (Double) -> UnivariateDomain, ) : Group, ScaleOperations { - public fun fill(block: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram = + public inline fun fill(block: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram = TreeHistogramBuilder(binFactory).apply(block).build() override fun add( - a: UnivariateHistogram, - b: UnivariateHistogram, + left: UnivariateHistogram, + right: UnivariateHistogram, ): UnivariateHistogram { // require(a.context == this) { "Histogram $a does not belong to this context" } // require(b.context == this) { "Histogram $b does not belong to this context" } val bins = TreeMap().apply { - (a.bins.map { it.domain } union b.bins.map { it.domain }).forEach { def -> + (left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def -> put( def.center, UnivariateBin( def, - value = (a[def.center]?.value ?: 0.0) + (b[def.center]?.value ?: 0.0), - standardDeviation = (a[def.center]?.standardDeviation - ?: 0.0) + (b[def.center]?.standardDeviation ?: 0.0) + value = (left[def.center]?.value ?: 0.0) + (right[def.center]?.value ?: 0.0), + standardDeviation = (left[def.center]?.standardDeviation + ?: 0.0) + (right[def.center]?.standardDeviation ?: 0.0) ) ) } @@ -115,8 +116,8 @@ public class TreeHistogramSpace( bin.domain.center, UnivariateBin( bin.domain, - value = bin.value * value.toDouble(), - standardDeviation = abs(bin.standardDeviation * value.toDouble()) + value = bin.value * value, + standardDeviation = abs(bin.standardDeviation * value) ) ) } diff --git a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/UnivariateHistogram.kt index 0ad96ad46..d5b74fb9b 100644 --- a/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/space/kscience/kmath/histogram/UnivariateHistogram.kt @@ -1,14 +1,14 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram import space.kscience.kmath.domains.UnivariateDomain import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.asSequence import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.asSequence @UnstableKMathAPI @@ -16,9 +16,10 @@ public val UnivariateDomain.center: Double get() = (range.endInclusive + range.start) / 2 /** - * A univariate bin based an a range - * @param value The value of histogram including weighting - * @param standardDeviation Standard deviation of the bin value. Zero or negative if not applicable + * A univariate bin based on a range + * + * @property value The value of histogram including weighting + * @property standardDeviation Standard deviation of the bin value. Zero or negative if not applicable */ @UnstableKMathAPI public class UnivariateBin( @@ -27,15 +28,15 @@ public class UnivariateBin( public val standardDeviation: Double, ) : Bin, ClosedFloatingPointRange by domain.range { - public override val dimension: Int get() = 1 + override val dimension: Int get() = 1 - public override fun contains(point: Buffer): Boolean = point.size == 1 && contains(point[0]) + override fun contains(point: Buffer): Boolean = point.size == 1 && contains(point[0]) } @OptIn(UnstableKMathAPI::class) -public interface UnivariateHistogram : Histogram{ +public interface UnivariateHistogram : Histogram { public operator fun get(value: Double): UnivariateBin? - public override operator fun get(point: Buffer): UnivariateBin? = get(point[0]) + override operator fun get(point: Buffer): UnivariateBin? = get(point[0]) public companion object { /** diff --git a/kmath-histograms/src/jvmTest/kotlin/space/kscience/kmath/histogram/TreeHistogramTest.kt b/kmath-histograms/src/jvmTest/kotlin/space/kscience/kmath/histogram/TreeHistogramTest.kt index 28a1b03cb..e71602c7b 100644 --- a/kmath-histograms/src/jvmTest/kotlin/space/kscience/kmath/histogram/TreeHistogramTest.kt +++ b/kmath-histograms/src/jvmTest/kotlin/space/kscience/kmath/histogram/TreeHistogramTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.histogram diff --git a/kmath-jafama/README.md b/kmath-jafama/README.md index ef8fcd352..3c5d4e19d 100644 --- a/kmath-jafama/README.md +++ b/kmath-jafama/README.md @@ -7,7 +7,7 @@ Integration with [Jafama](https://github.com/jeffhain/jafama). ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-jafama:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-jafama:0.3.0-dev-14`. **Gradle:** ```gradle @@ -17,7 +17,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-jafama:0.3.0-dev-13' + implementation 'space.kscience:kmath-jafama:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -28,7 +28,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-jafama:0.3.0-dev-13") + implementation("space.kscience:kmath-jafama:0.3.0-dev-14") } ``` @@ -52,22 +52,4 @@ fun main() { According to KMath benchmarks on GraalVM, Jafama functions are slower than JDK math; however, there are indications that on Hotspot Jafama is a bit faster. -
- -Report for benchmark configuration jafamaDouble - - -* Run on OpenJDK 64-Bit Server VM (build 11.0.11+8-jvmci-21.1-b05) with Java process: - -``` -/home/commandertvis/graalvm-ce-java11/bin/java -XX:+UnlockExperimentalVMOptions -XX:+EnableJVMCIProduct -XX:-UnlockExperimentalVMOptions -XX:ThreadPriorityPolicy=1 -javaagent:/home/commandertvis/.gradle/caches/modules-2/files-2.1/org.jetbrains.kotlinx/kotlinx-coroutines-core-jvm/1.5.0/d8cebccdcddd029022aa8646a5a953ff88b13ac8/kotlinx-coroutines-core-jvm-1.5.0.jar -Dfile.encoding=UTF-8 -Duser.country=US -Duser.language=en -Duser.variant -ea -``` -* JMH 1.21 was used in `thrpt` mode with 1 warmup iteration by 1000 ms and 5 measurement iterations by 1000 ms. - -| Benchmark | Score | -|:---------:|:-----:| -|`space.kscience.kmath.benchmarks.JafamaBenchmark.core`|14.296120859512893 ± 0.36462633435888736 ops/s| -|`space.kscience.kmath.benchmarks.JafamaBenchmark.jafama`|11.431566395649781 ± 2.570896777898243 ops/s| -|`space.kscience.kmath.benchmarks.JafamaBenchmark.strictJafama`|11.746020495694117 ± 6.205909559197869 ops/s| -
- +> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**. diff --git a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt index cf6f9471d..64a935705 100644 --- a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt +++ b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.jafama import net.jafama.FastMath @@ -12,50 +17,50 @@ import space.kscience.kmath.operations.ScaleOperations */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object JafamaDoubleField : ExtendedField, Norm, ScaleOperations { - public override inline val zero: Double get() = 0.0 - public override inline val one: Double get() = 1.0 + override inline val zero: Double get() = 0.0 + override inline val one: Double get() = 1.0 - public override inline fun number(value: Number): Double = value.toDouble() + override inline fun number(value: Number): Double = value.toDouble() - public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = + override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { PowerOperations.POW_OPERATION -> ::power else -> super.binaryOperationFunction(operation) } - public override inline fun add(a: Double, b: Double): Double = a + b + override inline fun add(left: Double, right: Double): Double = left + right - public override inline fun multiply(a: Double, b: Double): Double = a * b - public override inline fun divide(a: Double, b: Double): Double = a / b + override inline fun multiply(left: Double, right: Double): Double = left * right + override inline fun divide(left: Double, right: Double): Double = left / right - public override inline fun scale(a: Double, value: Double): Double = a * value + override inline fun scale(a: Double, value: Double): Double = a * value - public override inline fun sin(arg: Double): Double = FastMath.sin(arg) - public override inline fun cos(arg: Double): Double = FastMath.cos(arg) - public override inline fun tan(arg: Double): Double = FastMath.tan(arg) - public override inline fun acos(arg: Double): Double = FastMath.acos(arg) - public override inline fun asin(arg: Double): Double = FastMath.asin(arg) - public override inline fun atan(arg: Double): Double = FastMath.atan(arg) + override inline fun sin(arg: Double): Double = FastMath.sin(arg) + override inline fun cos(arg: Double): Double = FastMath.cos(arg) + override inline fun tan(arg: Double): Double = FastMath.tan(arg) + override inline fun acos(arg: Double): Double = FastMath.acos(arg) + override inline fun asin(arg: Double): Double = FastMath.asin(arg) + override inline fun atan(arg: Double): Double = FastMath.atan(arg) - public override inline fun sinh(arg: Double): Double = FastMath.sinh(arg) - public override inline fun cosh(arg: Double): Double = FastMath.cosh(arg) - public override inline fun tanh(arg: Double): Double = FastMath.tanh(arg) - public override inline fun asinh(arg: Double): Double = FastMath.asinh(arg) - public override inline fun acosh(arg: Double): Double = FastMath.acosh(arg) - public override inline fun atanh(arg: Double): Double = FastMath.atanh(arg) + override inline fun sinh(arg: Double): Double = FastMath.sinh(arg) + override inline fun cosh(arg: Double): Double = FastMath.cosh(arg) + override inline fun tanh(arg: Double): Double = FastMath.tanh(arg) + override inline fun asinh(arg: Double): Double = FastMath.asinh(arg) + override inline fun acosh(arg: Double): Double = FastMath.acosh(arg) + override inline fun atanh(arg: Double): Double = FastMath.atanh(arg) - public override inline fun sqrt(arg: Double): Double = FastMath.sqrt(arg) - public override inline fun power(arg: Double, pow: Number): Double = FastMath.pow(arg, pow.toDouble()) - public override inline fun exp(arg: Double): Double = FastMath.exp(arg) - public override inline fun ln(arg: Double): Double = FastMath.log(arg) + override inline fun sqrt(arg: Double): Double = FastMath.sqrt(arg) + override inline fun power(arg: Double, pow: Number): Double = FastMath.pow(arg, pow.toDouble()) + override inline fun exp(arg: Double): Double = FastMath.exp(arg) + override inline fun ln(arg: Double): Double = FastMath.log(arg) - public override inline fun norm(arg: Double): Double = FastMath.abs(arg) + override inline fun norm(arg: Double): Double = FastMath.abs(arg) - public override inline fun Double.unaryMinus(): Double = -this - public override inline fun Double.plus(b: Double): Double = this + b - public override inline fun Double.minus(b: Double): Double = this - b - public override inline fun Double.times(b: Double): Double = this * b - public override inline fun Double.div(b: Double): Double = this / b + override inline fun Double.unaryMinus(): Double = -this + override inline fun Double.plus(arg: Double): Double = this + arg + override inline fun Double.minus(arg: Double): Double = this - arg + override inline fun Double.times(arg: Double): Double = this * arg + override inline fun Double.div(arg: Double): Double = this / arg } /** @@ -63,48 +68,48 @@ public object JafamaDoubleField : ExtendedField, Norm, S */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object StrictJafamaDoubleField : ExtendedField, Norm, ScaleOperations { - public override inline val zero: Double get() = 0.0 - public override inline val one: Double get() = 1.0 + override inline val zero: Double get() = 0.0 + override inline val one: Double get() = 1.0 - public override inline fun number(value: Number): Double = value.toDouble() + override inline fun number(value: Number): Double = value.toDouble() - public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = + override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { PowerOperations.POW_OPERATION -> ::power else -> super.binaryOperationFunction(operation) } - public override inline fun add(a: Double, b: Double): Double = a + b + override inline fun add(left: Double, right: Double): Double = left + right - public override inline fun multiply(a: Double, b: Double): Double = a * b - public override inline fun divide(a: Double, b: Double): Double = a / b + override inline fun multiply(left: Double, right: Double): Double = left * right + override inline fun divide(left: Double, right: Double): Double = left / right - public override inline fun scale(a: Double, value: Double): Double = a * value + override inline fun scale(a: Double, value: Double): Double = a * value - public override inline fun sin(arg: Double): Double = StrictFastMath.sin(arg) - public override inline fun cos(arg: Double): Double = StrictFastMath.cos(arg) - public override inline fun tan(arg: Double): Double = StrictFastMath.tan(arg) - public override inline fun acos(arg: Double): Double = StrictFastMath.acos(arg) - public override inline fun asin(arg: Double): Double = StrictFastMath.asin(arg) - public override inline fun atan(arg: Double): Double = StrictFastMath.atan(arg) + override inline fun sin(arg: Double): Double = StrictFastMath.sin(arg) + override inline fun cos(arg: Double): Double = StrictFastMath.cos(arg) + override inline fun tan(arg: Double): Double = StrictFastMath.tan(arg) + override inline fun acos(arg: Double): Double = StrictFastMath.acos(arg) + override inline fun asin(arg: Double): Double = StrictFastMath.asin(arg) + override inline fun atan(arg: Double): Double = StrictFastMath.atan(arg) - public override inline fun sinh(arg: Double): Double = StrictFastMath.sinh(arg) - public override inline fun cosh(arg: Double): Double = StrictFastMath.cosh(arg) - public override inline fun tanh(arg: Double): Double = StrictFastMath.tanh(arg) - public override inline fun asinh(arg: Double): Double = StrictFastMath.asinh(arg) - public override inline fun acosh(arg: Double): Double = StrictFastMath.acosh(arg) - public override inline fun atanh(arg: Double): Double = StrictFastMath.atanh(arg) + override inline fun sinh(arg: Double): Double = StrictFastMath.sinh(arg) + override inline fun cosh(arg: Double): Double = StrictFastMath.cosh(arg) + override inline fun tanh(arg: Double): Double = StrictFastMath.tanh(arg) + override inline fun asinh(arg: Double): Double = StrictFastMath.asinh(arg) + override inline fun acosh(arg: Double): Double = StrictFastMath.acosh(arg) + override inline fun atanh(arg: Double): Double = StrictFastMath.atanh(arg) - public override inline fun sqrt(arg: Double): Double = StrictFastMath.sqrt(arg) - public override inline fun power(arg: Double, pow: Number): Double = StrictFastMath.pow(arg, pow.toDouble()) - public override inline fun exp(arg: Double): Double = StrictFastMath.exp(arg) - public override inline fun ln(arg: Double): Double = StrictFastMath.log(arg) + override inline fun sqrt(arg: Double): Double = StrictFastMath.sqrt(arg) + override inline fun power(arg: Double, pow: Number): Double = StrictFastMath.pow(arg, pow.toDouble()) + override inline fun exp(arg: Double): Double = StrictFastMath.exp(arg) + override inline fun ln(arg: Double): Double = StrictFastMath.log(arg) - public override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg) + override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg) - public override inline fun Double.unaryMinus(): Double = -this - public override inline fun Double.plus(b: Double): Double = this + b - public override inline fun Double.minus(b: Double): Double = this - b - public override inline fun Double.times(b: Double): Double = this * b - public override inline fun Double.div(b: Double): Double = this / b + override inline fun Double.unaryMinus(): Double = -this + override inline fun Double.plus(arg: Double): Double = this + arg + override inline fun Double.minus(arg: Double): Double = this - arg + override inline fun Double.times(arg: Double): Double = this * arg + override inline fun Double.div(arg: Double): Double = this / arg } diff --git a/kmath-jupyter/build.gradle.kts b/kmath-jupyter/build.gradle.kts index 83a6a771a..5bd08c485 100644 --- a/kmath-jupyter/build.gradle.kts +++ b/kmath-jupyter/build.gradle.kts @@ -20,3 +20,7 @@ readme { kotlin.sourceSets.all { languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI") } + +tasks.processJupyterApiResources { + libraryProducers = listOf("space.kscience.kmath.jupyter.KMathJupyter") +} diff --git a/kmath-jupyter/src/main/kotlin/space/kscience/kmath/jupyter/KMathJupyter.kt b/kmath-jupyter/src/main/kotlin/space/kscience/kmath/jupyter/KMathJupyter.kt index e3767e13c..9731908b3 100644 --- a/kmath-jupyter/src/main/kotlin/space/kscience/kmath/jupyter/KMathJupyter.kt +++ b/kmath-jupyter/src/main/kotlin/space/kscience/kmath/jupyter/KMathJupyter.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.jupyter import kotlinx.html.Unsafe @@ -6,7 +11,6 @@ import kotlinx.html.stream.createHTML import kotlinx.html.unsafe import org.jetbrains.kotlinx.jupyter.api.DisplayResult import org.jetbrains.kotlinx.jupyter.api.HTML -import org.jetbrains.kotlinx.jupyter.api.annotations.JupyterLibrary import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer @@ -17,16 +21,15 @@ import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MstRing import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.Structure2D +import space.kscience.kmath.operations.asSequence import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.asSequence /** * A function for conversion of number to MST for pretty print */ public fun Number.toMst(): MST.Numeric = MST.Numeric(this) -@JupyterLibrary internal class KMathJupyter : JupyterIntegration() { private val mathRender = FeaturedMathRendererWithPostProcess.Default private val syntaxRender = MathMLSyntaxRenderer diff --git a/kmath-kotlingrad/README.md b/kmath-kotlingrad/README.md index 31c7bb819..aeb44ea13 100644 --- a/kmath-kotlingrad/README.md +++ b/kmath-kotlingrad/README.md @@ -1,14 +1,14 @@ # Module kmath-kotlingrad -[Kotlin∇](https://www.htmlsymbols.xyz/unicode/U+2207) integration module. +[Kotlin∇](https://github.com/breandan/kotlingrad) integration module. - - [differentiable-mst-expression](src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt) : MST based DifferentiableExpression. - - [differentiable-mst-expression](src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt) : Conversions between Kotlin∇'s SFun and MST + - [differentiable-mst-expression](src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt) : MST based DifferentiableExpression. + - [scalars-adapters](src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt) : Conversions between Kotlin∇'s SFun and MST ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-kotlingrad:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-kotlingrad:0.3.0-dev-14`. **Gradle:** ```gradle @@ -18,7 +18,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-kotlingrad:0.3.0-dev-13' + implementation 'space.kscience:kmath-kotlingrad:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -29,6 +29,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-kotlingrad:0.3.0-dev-13") + implementation("space.kscience:kmath-kotlingrad:0.3.0-dev-14") } ``` diff --git a/kmath-kotlingrad/build.gradle.kts b/kmath-kotlingrad/build.gradle.kts index 01b42d7ba..d222ed7d6 100644 --- a/kmath-kotlingrad/build.gradle.kts +++ b/kmath-kotlingrad/build.gradle.kts @@ -18,14 +18,14 @@ readme { feature( "differentiable-mst-expression", - "src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt", + "src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt", ) { "MST based DifferentiableExpression." } feature( - "differentiable-mst-expression", - "src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt", + "scalars-adapters", + "src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt", ) { "Conversions between Kotlin∇'s SFun and MST" } diff --git a/kmath-kotlingrad/docs/README-TEMPLATE.md b/kmath-kotlingrad/docs/README-TEMPLATE.md index ac38c849b..bc99bdf5f 100644 --- a/kmath-kotlingrad/docs/README-TEMPLATE.md +++ b/kmath-kotlingrad/docs/README-TEMPLATE.md @@ -1,6 +1,6 @@ # Module kmath-kotlingrad -[Kotlin∇](https://www.htmlsymbols.xyz/unicode/U+2207) integration module. +[Kotlin∇](https://github.com/breandan/kotlingrad) integration module. ${features} diff --git a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KMathNumber.kt b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KMathNumber.kt index 9c9d07b81..0f10c6cdd 100644 --- a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KMathNumber.kt +++ b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KMathNumber.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.kotlingrad @@ -16,15 +16,15 @@ import space.kscience.kmath.operations.NumericAlgebra * @property algebra The algebra. * @property value The value of this number. */ -public class KMathNumber(public val algebra: A, public override val value: T) : +public class KMathNumber(public val algebra: A, override val value: T) : SConst>(value) where T : Number, A : NumericAlgebra { /** * Returns a string representation of the [value]. */ - public override fun toString(): String = value.toString() + override fun toString(): String = value.toString() /** * Wraps [Number] to [KMathNumber]. */ - public override fun wrap(number: Number): KMathNumber = KMathNumber(algebra, algebra.number(number)) + override fun wrap(number: Number): KMathNumber = KMathNumber(algebra, algebra.number(number)) } diff --git a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt index 28f6cd59e..84171101f 100644 --- a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt +++ b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.kotlingrad @@ -25,17 +25,28 @@ public class KotlingradExpression>( public val algebra: A, public val mst: MST, ) : SpecialDifferentiableExpression> { - public override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments) + override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments) - public override fun derivativeOrNull(symbols: List): KotlingradExpression = - KotlingradExpression( - algebra, - symbols.map(Symbol::identity) - .map(MstNumericAlgebra::bindSymbol) - .map>>(Symbol::toSVar) - .fold(mst.toSFun(), SFun>::d) - .toMst(), - ) + override fun derivativeOrNull( + symbols: List, + ): KotlingradExpression = KotlingradExpression( + algebra, + symbols.map(Symbol::identity) + .map(MstNumericAlgebra::bindSymbol) + .map>>(Symbol::toSVar) + .fold(mst.toSFun(), SFun>::d) + .toMst(), + ) +} + +/** + * A diff processor using [MST] to Kotlingrad converter + */ +public class KotlingradProcessor>( + public val algebra: A, +) : AutoDiffProcessor { + override fun differentiate(function: MstExtendedField.() -> MST): DifferentiableExpression = + MstExtendedField.function().toKotlingradExpression(algebra) } /** diff --git a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt index 6c0b98c59..11e5853a8 100644 --- a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt +++ b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/scalarsAdapters.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.kotlingrad @@ -106,8 +106,8 @@ public fun > MST.toSFun(): SFun = when (this) { is Symbol -> toSVar() is MST.Unary -> when (operation) { - GroupOperations.PLUS_OPERATION -> +value.toSFun() - GroupOperations.MINUS_OPERATION -> -value.toSFun() + GroupOps.PLUS_OPERATION -> +value.toSFun() + GroupOps.MINUS_OPERATION -> -value.toSFun() TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) @@ -124,10 +124,10 @@ public fun > MST.toSFun(): SFun = when (this) { } is MST.Binary -> when (operation) { - GroupOperations.PLUS_OPERATION -> left.toSFun() + right.toSFun() - GroupOperations.MINUS_OPERATION -> left.toSFun() - right.toSFun() - RingOperations.TIMES_OPERATION -> left.toSFun() * right.toSFun() - FieldOperations.DIV_OPERATION -> left.toSFun() / right.toSFun() + GroupOps.PLUS_OPERATION -> left.toSFun() + right.toSFun() + GroupOps.MINUS_OPERATION -> left.toSFun() - right.toSFun() + RingOps.TIMES_OPERATION -> left.toSFun() * right.toSFun() + FieldOps.DIV_OPERATION -> left.toSFun() / right.toSFun() PowerOperations.POW_OPERATION -> left.toSFun() pow (right as MST.Numeric).toSConst() else -> error("Binary operation $operation not defined in $this") } diff --git a/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt b/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt index 9378adfea..67332a680 100644 --- a/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt +++ b/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.kotlingrad diff --git a/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/Memory.kt b/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/Memory.kt index 930b21095..9f73ae2f3 100644 --- a/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/Memory.kt +++ b/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/Memory.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.memory @@ -156,6 +156,6 @@ public expect fun Memory.Companion.allocate(length: Int): Memory /** * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied - * and could be mutated independently from the resulting [Memory]. + * and could be mutated independently of the resulting [Memory]. */ public expect fun Memory.Companion.wrap(array: ByteArray): Memory diff --git a/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/MemorySpec.kt b/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/MemorySpec.kt index 1ee1cf4e2..2f2af4d9c 100644 --- a/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/MemorySpec.kt +++ b/kmath-memory/src/commonMain/kotlin/space/kscience/kmath/memory/MemorySpec.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.memory diff --git a/kmath-memory/src/jsMain/kotlin/space/kscience/kmath/memory/DataViewMemory.kt b/kmath-memory/src/jsMain/kotlin/space/kscience/kmath/memory/DataViewMemory.kt index 9a622ea36..db5eb556e 100644 --- a/kmath-memory/src/jsMain/kotlin/space/kscience/kmath/memory/DataViewMemory.kt +++ b/kmath-memory/src/jsMain/kotlin/space/kscience/kmath/memory/DataViewMemory.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.memory @@ -95,7 +95,7 @@ public actual fun Memory.Companion.allocate(length: Int): Memory { /** * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied - * and could be mutated independently from the resulting [Memory]. + * and could be mutated independently of the resulting [Memory]. */ public actual fun Memory.Companion.wrap(array: ByteArray): Memory { @Suppress("CAST_NEVER_SUCCEEDS") val int8Array = array as Int8Array diff --git a/kmath-memory/src/jvmMain/kotlin/space/kscience/kmath/memory/ByteBufferMemory.kt b/kmath-memory/src/jvmMain/kotlin/space/kscience/kmath/memory/ByteBufferMemory.kt index 944e8455b..6e60514f8 100644 --- a/kmath-memory/src/jvmMain/kotlin/space/kscience/kmath/memory/ByteBufferMemory.kt +++ b/kmath-memory/src/jvmMain/kotlin/space/kscience/kmath/memory/ByteBufferMemory.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.memory @@ -103,7 +103,7 @@ public actual fun Memory.Companion.allocate(length: Int): Memory = /** * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied - * and could be mutated independently from the resulting [Memory]. + * and could be mutated independently of the resulting [Memory]. */ public actual fun Memory.Companion.wrap(array: ByteArray): Memory = ByteBufferMemory(checkNotNull(ByteBuffer.wrap(array))) diff --git a/kmath-memory/src/nativeMain/kotlin/space/kscience/kmath/memory/NativeMemory.kt b/kmath-memory/src/nativeMain/kotlin/space/kscience/kmath/memory/NativeMemory.kt index d31c9e8f4..d13da1191 100644 --- a/kmath-memory/src/nativeMain/kotlin/space/kscience/kmath/memory/NativeMemory.kt +++ b/kmath-memory/src/nativeMain/kotlin/space/kscience/kmath/memory/NativeMemory.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.memory @@ -60,7 +60,7 @@ internal class NativeMemory( } override fun writeByte(offset: Int, value: Byte) { - array.set(position(offset), value) + array[position(offset)] = value } override fun writeShort(offset: Int, value: Short) { @@ -85,7 +85,7 @@ internal class NativeMemory( /** * Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied - * and could be mutated independently from the resulting [Memory]. + * and could be mutated independently of the resulting [Memory]. */ public actual fun Memory.Companion.wrap(array: ByteArray): Memory = NativeMemory(array) diff --git a/kmath-multik/build.gradle.kts b/kmath-multik/build.gradle.kts new file mode 100644 index 000000000..df2292f2e --- /dev/null +++ b/kmath-multik/build.gradle.kts @@ -0,0 +1,14 @@ +plugins { + id("ru.mipt.npm.gradle.jvm") +} + +description = "JetBrains Multik connector" + +dependencies { + api(project(":kmath-tensors")) + api("org.jetbrains.kotlinx:multik-default:0.1.0") +} + +readme { + maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE +} \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt new file mode 100644 index 000000000..70cfdeabd --- /dev/null +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -0,0 +1,344 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +@file:Suppress("unused") + +package space.kscience.kmath.multik + +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.linalg.dot +import org.jetbrains.kotlinx.multik.api.mk +import org.jetbrains.kotlinx.multik.api.ndarrayOf +import org.jetbrains.kotlinx.multik.api.zeros +import org.jetbrains.kotlinx.multik.ndarray.data.* +import org.jetbrains.kotlinx.multik.ndarray.operations.* +import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.DefaultStrides +import space.kscience.kmath.nd.Shape +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.mapInPlace +import space.kscience.kmath.operations.* +import space.kscience.kmath.tensors.api.Tensor +import space.kscience.kmath.tensors.api.TensorAlgebra +import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra + +@JvmInline +public value class MultikTensor(public val array: MutableMultiArray) : Tensor { + override val shape: Shape get() = array.shape + + override fun get(index: IntArray): T = array[index] + + @PerformancePitfall + override fun elements(): Sequence> = + array.multiIndices.iterator().asSequence().map { it to get(it) } + + override fun set(index: IntArray, value: T) { + array[index] = value + } +} + +private fun MultiArray.asD1Array(): D1Array { + if (this is NDArray) + return this.asD1Array() + else throw ClassCastException("Cannot cast MultiArray to NDArray.") +} + + +private fun MultiArray.asD2Array(): D2Array { + if (this is NDArray) + return this.asD2Array() + else throw ClassCastException("Cannot cast MultiArray to NDArray.") +} + +public abstract class MultikTensorAlgebra> : TensorAlgebra where T : Number, T : Comparable { + + public abstract val type: DataType + + override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor { + val strides = DefaultStrides(shape) + val memoryView = initMemoryView(strides.linearSize, type) + strides.asSequence().forEachIndexed { linearIndex, tensorIndex -> + memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex) + } + return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size))) + } + + override fun StructureND.map(transform: A.(T) -> T): MultikTensor = if (this is MultikTensor) { + val data = initMemoryView(array.size, type) + var count = 0 + for (el in array) data[count++] = elementAlgebra.transform(el) + NDArray(data, shape = shape, dim = array.dim).wrap() + } else { + structureND(shape) { index -> + transform(get(index)) + } + } + + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor = + if (this is MultikTensor) { + val array = asMultik().array + val data = initMemoryView(array.size, type) + val indexIter = array.multiIndices.iterator() + var index = 0 + for (item in array) { + if (indexIter.hasNext()) { + data[index++] = elementAlgebra.transform(indexIter.next(), item) + } else { + throw ArithmeticException("Index overflow has happened.") + } + } + NDArray(data, shape = array.shape, dim = array.dim).wrap() + } else { + structureND(shape) { index -> + transform(index, get(index)) + } + } + + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): MultikTensor { + require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException + val leftArray = left.asMultik().array + val rightArray = right.asMultik().array + val data = initMemoryView(leftArray.size, type) + var counter = 0 + val leftIterator = leftArray.iterator() + val rightIterator = rightArray.iterator() + //iterating them together + while (leftIterator.hasNext()) { + data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next()) + } + return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap() + } + + /** + * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor + * are not reflected back onto the source + */ + public fun StructureND.asMultik(): MultikTensor = if (this is MultikTensor) { + this + } else { + val res = mk.zeros(shape, type).asDNArray() + for (index in res.multiIndices) { + res[index] = this[index] + } + res.wrap() + } + + public fun MutableMultiArray.wrap(): MultikTensor = MultikTensor(this.asDNArray()) + + override fun StructureND.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { + get(intArrayOf(0)) + } else null + + override fun T.plus(arg: StructureND): MultikTensor = + arg.plus(this) + + override fun StructureND.plus(arg: T): MultikTensor = + asMultik().array.deepCopy().apply { plusAssign(arg) }.wrap() + + override fun StructureND.plus(arg: StructureND): MultikTensor = + asMultik().array.plus(arg.asMultik().array).wrap() + + override fun Tensor.plusAssign(value: T) { + if (this is MultikTensor) { + array.plusAssign(value) + } else { + mapInPlace { _, t -> elementAlgebra.add(t, value) } + } + } + + override fun Tensor.plusAssign(arg: StructureND) { + if (this is MultikTensor) { + array.plusAssign(arg.asMultik().array) + } else { + mapInPlace { index, t -> elementAlgebra.add(t, arg[index]) } + } + } + + override fun T.minus(arg: StructureND): MultikTensor = (-(arg.asMultik().array - this)).wrap() + + override fun StructureND.minus(arg: T): MultikTensor = + asMultik().array.deepCopy().apply { minusAssign(arg) }.wrap() + + override fun StructureND.minus(arg: StructureND): MultikTensor = + asMultik().array.minus(arg.asMultik().array).wrap() + + override fun Tensor.minusAssign(value: T) { + if (this is MultikTensor) { + array.minusAssign(value) + } else { + mapInPlace { _, t -> elementAlgebra.run { t - value } } + } + } + + override fun Tensor.minusAssign(arg: StructureND) { + if (this is MultikTensor) { + array.minusAssign(arg.asMultik().array) + } else { + mapInPlace { index, t -> elementAlgebra.run { t - arg[index] } } + } + } + + override fun T.times(arg: StructureND): MultikTensor = + arg.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap() + + override fun StructureND.times(arg: T): Tensor = + asMultik().array.deepCopy().apply { timesAssign(arg) }.wrap() + + override fun StructureND.times(arg: StructureND): MultikTensor = + asMultik().array.times(arg.asMultik().array).wrap() + + override fun Tensor.timesAssign(value: T) { + if (this is MultikTensor) { + array.timesAssign(value) + } else { + mapInPlace { _, t -> elementAlgebra.multiply(t, value) } + } + } + + override fun Tensor.timesAssign(arg: StructureND) { + if (this is MultikTensor) { + array.timesAssign(arg.asMultik().array) + } else { + mapInPlace { index, t -> elementAlgebra.multiply(t, arg[index]) } + } + } + + override fun StructureND.unaryMinus(): MultikTensor = + asMultik().array.unaryMinus().wrap() + + override fun Tensor.get(i: Int): MultikTensor = asMultik().array.mutableView(i).wrap() + + override fun Tensor.transpose(i: Int, j: Int): MultikTensor = asMultik().array.transpose(i, j).wrap() + + override fun Tensor.view(shape: IntArray): MultikTensor { + require(shape.all { it > 0 }) + require(shape.fold(1, Int::times) == this.shape.size) { + "Cannot reshape array of size ${this.shape.size} into a new shape ${ + shape.joinToString( + prefix = "(", + postfix = ")" + ) + }" + } + + val mt = asMultik().array + return if (mt.shape.contentEquals(shape)) { + mt + } else { + NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt) + }.wrap() + } + + override fun Tensor.viewAs(other: StructureND): MultikTensor = view(other.shape) + + override fun StructureND.dot(other: StructureND): MultikTensor = + if (this.shape.size == 1 && other.shape.size == 1) { + Multik.ndarrayOf( + asMultik().array.asD1Array() dot other.asMultik().array.asD1Array() + ).asDNArray().wrap() + } else if (this.shape.size == 2 && other.shape.size == 2) { + (asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap() + } else if (this.shape.size == 2 && other.shape.size == 1) { + (asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap() + } else { + TODO("Not implemented for broadcasting") + } + + override fun diagonalEmbedding(diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int): MultikTensor { + TODO("Diagonal embedding not implemented") + } + + override fun StructureND.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T -> + elementAlgebra.add(acc, t) + } + + override fun StructureND.sum(dim: Int, keepDim: Boolean): MultikTensor { + TODO("Not yet implemented") + } + + override fun StructureND.min(): T? = asMultik().array.min() + + override fun StructureND.min(dim: Int, keepDim: Boolean): Tensor { + TODO("Not yet implemented") + } + + override fun StructureND.max(): T? = asMultik().array.max() + + override fun StructureND.max(dim: Int, keepDim: Boolean): Tensor { + TODO("Not yet implemented") + } + + override fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor { + TODO("Not yet implemented") + } +} + +public abstract class MultikDivisionTensorAlgebra> + : MultikTensorAlgebra(), TensorPartialDivisionAlgebra where T : Number, T : Comparable { + + override fun T.div(arg: StructureND): MultikTensor = arg.map { elementAlgebra.divide(this@div, it) } + + override fun StructureND.div(arg: T): MultikTensor = + asMultik().array.deepCopy().apply { divAssign(arg) }.wrap() + + override fun StructureND.div(arg: StructureND): MultikTensor = + asMultik().array.div(arg.asMultik().array).wrap() + + override fun Tensor.divAssign(value: T) { + if (this is MultikTensor) { + array.divAssign(value) + } else { + mapInPlace { _, t -> elementAlgebra.divide(t, value) } + } + } + + override fun Tensor.divAssign(arg: StructureND) { + if (this is MultikTensor) { + array.divAssign(arg.asMultik().array) + } else { + mapInPlace { index, t -> elementAlgebra.divide(t, arg[index]) } + } + } +} + +public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra() { + override val elementAlgebra: DoubleField get() = DoubleField + override val type: DataType get() = DataType.DoubleDataType +} + +public val Double.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra +public val DoubleField.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra + +public object MultikFloatAlgebra : MultikDivisionTensorAlgebra() { + override val elementAlgebra: FloatField get() = FloatField + override val type: DataType get() = DataType.FloatDataType +} + +public val Float.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra +public val FloatField.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra + +public object MultikShortAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: ShortRing get() = ShortRing + override val type: DataType get() = DataType.ShortDataType +} + +public val Short.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra +public val ShortRing.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra + +public object MultikIntAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: IntRing get() = IntRing + override val type: DataType get() = DataType.IntDataType +} + +public val Int.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra +public val IntRing.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra + +public object MultikLongAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: LongRing get() = LongRing + override val type: DataType get() = DataType.LongDataType +} + +public val Long.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra +public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra \ No newline at end of file diff --git a/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt b/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt new file mode 100644 index 000000000..66ba7db2d --- /dev/null +++ b/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt @@ -0,0 +1,13 @@ +package space.kscience.kmath.multik + +import org.junit.jupiter.api.Test +import space.kscience.kmath.nd.one +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke + +internal class MultikNDTest { + @Test + fun basicAlgebra(): Unit = DoubleField.multikAlgebra{ + one(2,2) + 1.0 + } +} \ No newline at end of file diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index 1c945f06a..5cbb31d5a 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -9,7 +9,7 @@ ND4J based implementations of KMath abstractions. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-13`. +The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-14`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-nd4j:0.3.0-dev-13' + implementation 'space.kscience:kmath-nd4j:0.3.0-dev-14' } ``` **Gradle Kotlin DSL:** @@ -30,7 +30,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-nd4j:0.3.0-dev-13") + implementation("space.kscience:kmath-nd4j:0.3.0-dev-14") } ``` diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts index abcc02962..09264501f 100644 --- a/kmath-nd4j/build.gradle.kts +++ b/kmath-nd4j/build.gradle.kts @@ -7,10 +7,9 @@ description = "ND4J NDStructure implementation and according NDAlgebra classes" dependencies { api(project(":kmath-tensors")) - api("org.nd4j:nd4j-api:1.0.0-beta7") - testImplementation("org.nd4j:nd4j-native:1.0.0-beta7") - testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") - testImplementation("org.slf4j:slf4j-simple:1.7.30") + api("org.nd4j:nd4j-api:1.0.0-M1") + testImplementation("org.nd4j:nd4j-native-platform:1.0.0-M1") + testImplementation("org.slf4j:slf4j-simple:1.7.32") } readme { diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt index e94bda12a..b1cc1f834 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd4j @@ -15,13 +15,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.* import space.kscience.kmath.operations.* -internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray { - val arrayShape = array.shape().toIntArray() - if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape) - return array -} - - /** * Represents [AlgebraND] over [Nd4jArrayAlgebra]. * @@ -35,38 +28,39 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND /** - * Unwraps to or acquires [INDArray] from [StructureND]. + * Unwraps to or get [INDArray] from [StructureND]. */ public val StructureND.ndArray: INDArray - public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure { + override fun structureND(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure { val struct = Nd4j.create(*shape)!!.wrap() - struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } + struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) } return struct } - @PerformancePitfall - public override fun StructureND.map(transform: C.(T) -> T): Nd4jArrayStructure { + @OptIn(PerformancePitfall::class) + override fun StructureND.map(transform: C.(T) -> T): Nd4jArrayStructure { val newStruct = ndArray.dup().wrap() - newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } + newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) } return newStruct } - public override fun StructureND.mapIndexed( + override fun StructureND.mapIndexed( transform: C.(index: IntArray, T) -> T, ): Nd4jArrayStructure { - val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) } + val new = Nd4j.create(*shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(idx, this[idx]) } return new } - public override fun combine( - a: StructureND, - b: StructureND, + override fun zip( + left: StructureND, + right: StructureND, transform: C.(T, T) -> T, ): Nd4jArrayStructure { - val new = Nd4j.create(*shape).wrap() - new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } + require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" } + val new = Nd4j.create(*left.shape).wrap() + new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) } return new } } @@ -77,18 +71,15 @@ public sealed interface Nd4jArrayAlgebra> : AlgebraND> : GroupND, Nd4jArrayAlgebra { +public sealed interface Nd4jArrayGroupOps> : GroupOpsND, Nd4jArrayAlgebra { - public override val zero: Nd4jArrayStructure - get() = Nd4j.zeros(*shape).wrap() + override fun add(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.add(right.ndArray).wrap() - public override fun add(a: StructureND, b: StructureND): Nd4jArrayStructure = - a.ndArray.add(b.ndArray).wrap() + override operator fun StructureND.minus(arg: StructureND): Nd4jArrayStructure = + ndArray.sub(arg.ndArray).wrap() - public override operator fun StructureND.minus(b: StructureND): Nd4jArrayStructure = - ndArray.sub(b.ndArray).wrap() - - public override operator fun StructureND.unaryMinus(): Nd4jArrayStructure = + override operator fun StructureND.unaryMinus(): Nd4jArrayStructure = ndArray.neg().wrap() public fun multiply(a: StructureND, k: Number): Nd4jArrayStructure = @@ -102,45 +93,33 @@ public sealed interface Nd4jArrayGroup> : GroupND, Nd4j * @param R the type of ring of structure elements. */ @OptIn(UnstableKMathAPI::class) -public sealed interface Nd4jArrayRing> : RingND, Nd4jArrayGroup { +public sealed interface Nd4jArrayRingOps> : RingOpsND, Nd4jArrayGroupOps { - public override val one: Nd4jArrayStructure - get() = Nd4j.ones(*shape).wrap() - - public override fun multiply(a: StructureND, b: StructureND): Nd4jArrayStructure = - a.ndArray.mul(b.ndArray).wrap() + override fun multiply(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.mul(right.ndArray).wrap() // -// public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { +// override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { // check(this) // return ndArray.sub(b).wrap() // } // -// public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { +// override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { // check(this) // return ndArray.add(b).wrap() // } // -// public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { +// override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { // check(b) // return b.ndArray.rsub(this).wrap() // } public companion object { - private val intNd4jArrayRingCache: ThreadLocal> = - ThreadLocal.withInitial(::HashMap) - - /** - * Creates an [RingND] for [Int] values or pull it from cache if it was created previously. - */ - public fun int(vararg shape: Int): Nd4jArrayRing = - intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) } - /** * Creates a most suitable implementation of [RingND] using reified class. */ @Suppress("UNCHECKED_CAST") - public inline fun auto(vararg shape: Int): Nd4jArrayRing> = when { - T::class == Int::class -> int(*shape) as Nd4jArrayRing> + public inline fun auto(): Nd4jArrayRingOps> = when { + T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps> else -> throw UnsupportedOperationException("This factory method only supports Long type.") } } @@ -152,38 +131,21 @@ public sealed interface Nd4jArrayRing> : RingND, Nd4jAr * @param T the type of the element contained in ND structure. * @param F the type field of structure elements. */ -public sealed interface Nd4jArrayField> : FieldND, Nd4jArrayRing { - public override fun divide(a: StructureND, b: StructureND): Nd4jArrayStructure = - a.ndArray.div(b.ndArray).wrap() +public sealed interface Nd4jArrayField> : FieldOpsND, Nd4jArrayRingOps { + + override fun divide(left: StructureND, right: StructureND): Nd4jArrayStructure = + left.ndArray.div(right.ndArray).wrap() public operator fun Number.div(b: StructureND): Nd4jArrayStructure = b.ndArray.rdiv(this).wrap() public companion object { - private val floatNd4jArrayFieldCache: ThreadLocal> = - ThreadLocal.withInitial(::HashMap) - - private val doubleNd4JArrayFieldCache: ThreadLocal> = - ThreadLocal.withInitial(::HashMap) - - /** - * Creates an [FieldND] for [Float] values or pull it from cache if it was created previously. - */ - public fun float(vararg shape: Int): Nd4jArrayRing = - floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) } - - /** - * Creates an [FieldND] for [Double] values or pull it from cache if it was created previously. - */ - public fun real(vararg shape: Int): Nd4jArrayRing = - doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) } - /** * Creates a most suitable implementation of [FieldND] using reified class. */ @Suppress("UNCHECKED_CAST") - public inline fun auto(vararg shape: Int): Nd4jArrayField> = when { - T::class == Float::class -> float(*shape) as Nd4jArrayField> - T::class == Double::class -> real(*shape) as Nd4jArrayField> + public inline fun auto(): Nd4jArrayField> = when { + T::class == Float::class -> FloatField.nd4j as Nd4jArrayField> + T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField> else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.") } } @@ -192,91 +154,90 @@ public sealed interface Nd4jArrayField> : FieldND, Nd4 /** * Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure]. */ -public sealed interface Nd4jArrayExtendedField> : ExtendedField>, - Nd4jArrayField { - public override fun sin(arg: StructureND): StructureND = Transforms.sin(arg.ndArray).wrap() - public override fun cos(arg: StructureND): StructureND = Transforms.cos(arg.ndArray).wrap() - public override fun asin(arg: StructureND): StructureND = Transforms.asin(arg.ndArray).wrap() - public override fun acos(arg: StructureND): StructureND = Transforms.acos(arg.ndArray).wrap() - public override fun atan(arg: StructureND): StructureND = Transforms.atan(arg.ndArray).wrap() +public sealed interface Nd4jArrayExtendedFieldOps> : + ExtendedFieldOps>, Nd4jArrayField { - public override fun power(arg: StructureND, pow: Number): StructureND = + override fun sin(arg: StructureND): StructureND = Transforms.sin(arg.ndArray).wrap() + override fun cos(arg: StructureND): StructureND = Transforms.cos(arg.ndArray).wrap() + override fun asin(arg: StructureND): StructureND = Transforms.asin(arg.ndArray).wrap() + override fun acos(arg: StructureND): StructureND = Transforms.acos(arg.ndArray).wrap() + override fun atan(arg: StructureND): StructureND = Transforms.atan(arg.ndArray).wrap() + + override fun power(arg: StructureND, pow: Number): StructureND = Transforms.pow(arg.ndArray, pow).wrap() - public override fun exp(arg: StructureND): StructureND = Transforms.exp(arg.ndArray).wrap() - public override fun ln(arg: StructureND): StructureND = Transforms.log(arg.ndArray).wrap() - public override fun sqrt(arg: StructureND): StructureND = Transforms.sqrt(arg.ndArray).wrap() - public override fun sinh(arg: StructureND): StructureND = Transforms.sinh(arg.ndArray).wrap() - public override fun cosh(arg: StructureND): StructureND = Transforms.cosh(arg.ndArray).wrap() - public override fun tanh(arg: StructureND): StructureND = Transforms.tanh(arg.ndArray).wrap() + override fun exp(arg: StructureND): StructureND = Transforms.exp(arg.ndArray).wrap() + override fun ln(arg: StructureND): StructureND = Transforms.log(arg.ndArray).wrap() + override fun sqrt(arg: StructureND): StructureND = Transforms.sqrt(arg.ndArray).wrap() + override fun sinh(arg: StructureND): StructureND = Transforms.sinh(arg.ndArray).wrap() + override fun cosh(arg: StructureND): StructureND = Transforms.cosh(arg.ndArray).wrap() + override fun tanh(arg: StructureND): StructureND = Transforms.tanh(arg.ndArray).wrap() - public override fun asinh(arg: StructureND): StructureND = + override fun asinh(arg: StructureND): StructureND = Nd4j.getExecutioner().exec(ASinh(arg.ndArray, arg.ndArray.ulike())).wrap() - public override fun acosh(arg: StructureND): StructureND = + override fun acosh(arg: StructureND): StructureND = Nd4j.getExecutioner().exec(ACosh(arg.ndArray, arg.ndArray.ulike())).wrap() - public override fun atanh(arg: StructureND): StructureND = Transforms.atanh(arg.ndArray).wrap() + override fun atanh(arg: StructureND): StructureND = Transforms.atanh(arg.ndArray).wrap() } /** * Represents [FieldND] over [Nd4jArrayDoubleStructure]. */ -public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField { - public override val elementContext: DoubleField get() = DoubleField +public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps { + override val elementAlgebra: DoubleField get() = DoubleField - public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asDoubleStructure() + override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() @OptIn(PerformancePitfall::class) override val StructureND.ndArray: INDArray get() = when (this) { - is Nd4jArrayStructure -> checkShape(ndArray) + is Nd4jArrayStructure -> ndArray else -> Nd4j.zeros(*shape).also { elements().forEach { (idx, value) -> it.putScalar(idx, value) } } } - override fun scale(a: StructureND, value: Double): Nd4jArrayStructure { - return a.ndArray.mul(value).wrap() - } + override fun scale(a: StructureND, value: Double): Nd4jArrayStructure = a.ndArray.mul(value).wrap() - public override operator fun StructureND.div(arg: Double): Nd4jArrayStructure { - return ndArray.div(arg).wrap() - } + override operator fun StructureND.div(arg: Double): Nd4jArrayStructure = ndArray.div(arg).wrap() - public override operator fun StructureND.plus(arg: Double): Nd4jArrayStructure { - return ndArray.add(arg).wrap() - } + override operator fun StructureND.plus(arg: Double): Nd4jArrayStructure = ndArray.add(arg).wrap() - public override operator fun StructureND.minus(arg: Double): Nd4jArrayStructure { - return ndArray.sub(arg).wrap() - } + override operator fun StructureND.minus(arg: Double): Nd4jArrayStructure = ndArray.sub(arg).wrap() - public override operator fun StructureND.times(arg: Double): Nd4jArrayStructure { - return ndArray.mul(arg).wrap() - } + override operator fun StructureND.times(arg: Double): Nd4jArrayStructure = ndArray.mul(arg).wrap() - public override operator fun Double.div(arg: StructureND): Nd4jArrayStructure { - return arg.ndArray.rdiv(this).wrap() - } + override operator fun Double.div(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rdiv(this).wrap() - public override operator fun Double.minus(arg: StructureND): Nd4jArrayStructure { - return arg.ndArray.rsub(this).wrap() - } + override operator fun Double.minus(arg: StructureND): Nd4jArrayStructure = + arg.ndArray.rsub(this).wrap() + + public companion object : DoubleNd4jArrayFieldOps() } +public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps + +public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND + +public fun DoubleField.nd4j(shapeFirst: Int, vararg shapeRest: Int): DoubleNd4jArrayField = + DoubleNd4jArrayField(intArrayOf(shapeFirst, * shapeRest)) + + /** * Represents [FieldND] over [Nd4jArrayStructure] of [Float]. */ -public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField { - public override val elementContext: FloatField get() = FloatField +public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps { + override val elementAlgebra: FloatField get() = FloatField - public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asFloatStructure() + override fun INDArray.wrap(): Nd4jArrayStructure = asFloatStructure() @OptIn(PerformancePitfall::class) - public override val StructureND.ndArray: INDArray + override val StructureND.ndArray: INDArray get() = when (this) { - is Nd4jArrayStructure -> checkShape(ndArray) + is Nd4jArrayStructure -> ndArray else -> Nd4j.zeros(*shape).also { elements().forEach { (idx, value) -> it.putScalar(idx, value) } } @@ -285,52 +246,69 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra override fun scale(a: StructureND, value: Double): StructureND = a.ndArray.mul(value).wrap() - public override operator fun StructureND.div(arg: Float): Nd4jArrayStructure = + override operator fun StructureND.div(arg: Float): Nd4jArrayStructure = ndArray.div(arg).wrap() - public override operator fun StructureND.plus(arg: Float): Nd4jArrayStructure = + override operator fun StructureND.plus(arg: Float): Nd4jArrayStructure = ndArray.add(arg).wrap() - public override operator fun StructureND.minus(arg: Float): Nd4jArrayStructure = + override operator fun StructureND.minus(arg: Float): Nd4jArrayStructure = ndArray.sub(arg).wrap() - public override operator fun StructureND.times(arg: Float): Nd4jArrayStructure = + override operator fun StructureND.times(arg: Float): Nd4jArrayStructure = ndArray.mul(arg).wrap() - public override operator fun Float.div(arg: StructureND): Nd4jArrayStructure = + override operator fun Float.div(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rdiv(this).wrap() - public override operator fun Float.minus(arg: StructureND): Nd4jArrayStructure = + override operator fun Float.minus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() + + public companion object : FloatNd4jArrayFieldOps() } +public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND + +public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps + +public fun FloatField.nd4j(shapeFirst: Int, vararg shapeRest: Int): FloatNd4jArrayField = + FloatNd4jArrayField(intArrayOf(shapeFirst, * shapeRest)) + /** * Represents [RingND] over [Nd4jArrayIntStructure]. */ -public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing { - public override val elementContext: IntRing - get() = IntRing +public open class IntNd4jArrayRingOps : Nd4jArrayRingOps { + override val elementAlgebra: IntRing get() = IntRing - public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asIntStructure() + override fun INDArray.wrap(): Nd4jArrayStructure = asIntStructure() @OptIn(PerformancePitfall::class) - public override val StructureND.ndArray: INDArray + override val StructureND.ndArray: INDArray get() = when (this) { - is Nd4jArrayStructure -> checkShape(ndArray) + is Nd4jArrayStructure -> ndArray else -> Nd4j.zeros(*shape).also { elements().forEach { (idx, value) -> it.putScalar(idx, value) } } } - public override operator fun StructureND.plus(arg: Int): Nd4jArrayStructure = + override operator fun StructureND.plus(arg: Int): Nd4jArrayStructure = ndArray.add(arg).wrap() - public override operator fun StructureND.minus(arg: Int): Nd4jArrayStructure = + override operator fun StructureND.minus(arg: Int): Nd4jArrayStructure = ndArray.sub(arg).wrap() - public override operator fun StructureND.times(arg: Int): Nd4jArrayStructure = + override operator fun StructureND.times(arg: Int): Nd4jArrayStructure = ndArray.mul(arg).wrap() - public override operator fun Int.minus(arg: StructureND): Nd4jArrayStructure = + override operator fun Int.minus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() + + public companion object : IntNd4jArrayRingOps() } + +public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps + +public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND + +public fun IntRing.nd4j(shapeFirst: Int, vararg shapeRest: Int): IntNd4jArrayRing = + IntNd4jArrayRing(intArrayOf(shapeFirst, * shapeRest)) \ No newline at end of file diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt index 140a212f8..5ae6f6b01 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd4j @@ -25,7 +25,7 @@ private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iter internal fun INDArray.indicesIterator(): Iterator = Nd4jArrayIndicesIterator(this) -private sealed class Nd4jArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { +private sealed class Nd4jArrayIteratorBase(protected val iterateOver: INDArray) : Iterator> { private var i: Int = 0 final override fun hasNext(): Boolean = i < iterateOver.length() diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt index ffddcef90..82f560fdb 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd4j @@ -17,18 +17,18 @@ import space.kscience.kmath.nd.StructureND */ public sealed class Nd4jArrayStructure : MutableStructureND { /** - * The wrapped [INDArray]. Since KMath uses [Int] indexes, assuming that the size of [INDArray] is less or equal to + * The wrapped [INDArray]. Since KMath uses [Int] indexes, assuming the size of [INDArray] is less or equal to * [Int.MAX_VALUE]. */ public abstract val ndArray: INDArray - public override val shape: IntArray get() = ndArray.shape().toIntArray() + override val shape: IntArray get() = ndArray.shape().toIntArray() internal abstract fun elementsIterator(): Iterator> internal fun indicesIterator(): Iterator = ndArray.indicesIterator() @PerformancePitfall - public override fun elements(): Sequence> = Sequence(::elementsIterator) + override fun elements(): Sequence> = Sequence(::elementsIterator) } private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index 6e84371ae..d7dd6e71b 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd4j @@ -13,7 +13,11 @@ import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.ops.NDBase import org.nd4j.linalg.ops.transforms.Transforms import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.DefaultStrides +import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.Field import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra @@ -22,116 +26,133 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra /** * ND4J based [TensorAlgebra] implementation. */ -public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra { +public sealed interface Nd4jTensorAlgebra> : AnalyticTensorAlgebra { + /** * Wraps [INDArray] to [Nd4jArrayStructure]. */ public fun INDArray.wrap(): Nd4jArrayStructure /** - * Unwraps to or acquires [INDArray] from [StructureND]. + * Unwraps to or gets [INDArray] from [StructureND]. */ public val StructureND.ndArray: INDArray - public override fun T.plus(other: Tensor): Tensor = other.ndArray.add(this).wrap() - public override fun Tensor.plus(value: T): Tensor = ndArray.add(value).wrap() + override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure - public override fun Tensor.plus(other: Tensor): Tensor = ndArray.add(other.ndArray).wrap() + override fun StructureND.map(transform: A.(T) -> T): Nd4jArrayStructure = + structureND(shape) { index -> elementAlgebra.transform(get(index)) } - public override fun Tensor.plusAssign(value: T) { + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure = + structureND(shape) { index -> elementAlgebra.transform(index, get(index)) } + + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): Nd4jArrayStructure { + require(left.shape.contentEquals(right.shape)) + return structureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) } + } + + override fun T.plus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.add(this).wrap() + override fun StructureND.plus(arg: T): Nd4jArrayStructure = ndArray.add(arg).wrap() + + override fun StructureND.plus(arg: StructureND): Nd4jArrayStructure = ndArray.add(arg.ndArray).wrap() + + override fun Tensor.plusAssign(value: T) { ndArray.addi(value) } - public override fun Tensor.plusAssign(other: Tensor) { - ndArray.addi(other.ndArray) + override fun Tensor.plusAssign(arg: StructureND) { + ndArray.addi(arg.ndArray) } - public override fun T.minus(other: Tensor): Tensor = other.ndArray.rsub(this).wrap() - public override fun Tensor.minus(value: T): Tensor = ndArray.sub(value).wrap() - public override fun Tensor.minus(other: Tensor): Tensor = ndArray.sub(other.ndArray).wrap() + override fun T.minus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() + override fun StructureND.minus(arg: T): Nd4jArrayStructure = ndArray.sub(arg).wrap() + override fun StructureND.minus(arg: StructureND): Nd4jArrayStructure = ndArray.sub(arg.ndArray).wrap() - public override fun Tensor.minusAssign(value: T) { + override fun Tensor.minusAssign(value: T) { ndArray.rsubi(value) } - public override fun Tensor.minusAssign(other: Tensor) { - ndArray.subi(other.ndArray) + override fun Tensor.minusAssign(arg: StructureND) { + ndArray.subi(arg.ndArray) } - public override fun T.times(other: Tensor): Tensor = other.ndArray.mul(this).wrap() + override fun T.times(arg: StructureND): Nd4jArrayStructure = arg.ndArray.mul(this).wrap() - public override fun Tensor.times(value: T): Tensor = - ndArray.mul(value).wrap() + override fun StructureND.times(arg: T): Nd4jArrayStructure = + ndArray.mul(arg).wrap() - public override fun Tensor.times(other: Tensor): Tensor = ndArray.mul(other.ndArray).wrap() + override fun StructureND.times(arg: StructureND): Nd4jArrayStructure = ndArray.mul(arg.ndArray).wrap() - public override fun Tensor.timesAssign(value: T) { + override fun Tensor.timesAssign(value: T) { ndArray.muli(value) } - public override fun Tensor.timesAssign(other: Tensor) { - ndArray.mmuli(other.ndArray) + override fun Tensor.timesAssign(arg: StructureND) { + ndArray.mmuli(arg.ndArray) } - public override fun Tensor.unaryMinus(): Tensor = ndArray.neg().wrap() - public override fun Tensor.get(i: Int): Tensor = ndArray.slice(i.toLong()).wrap() - public override fun Tensor.transpose(i: Int, j: Int): Tensor = ndArray.swapAxes(i, j).wrap() - public override fun Tensor.dot(other: Tensor): Tensor = ndArray.mmul(other.ndArray).wrap() + override fun StructureND.unaryMinus(): Nd4jArrayStructure = ndArray.neg().wrap() + override fun Tensor.get(i: Int): Nd4jArrayStructure = ndArray.slice(i.toLong()).wrap() + override fun Tensor.transpose(i: Int, j: Int): Nd4jArrayStructure = ndArray.swapAxes(i, j).wrap() + override fun StructureND.dot(other: StructureND): Nd4jArrayStructure = ndArray.mmul(other.ndArray).wrap() - public override fun Tensor.min(dim: Int, keepDim: Boolean): Tensor = + override fun StructureND.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.min(keepDim, dim).wrap() - public override fun Tensor.sum(dim: Int, keepDim: Boolean): Tensor = + override fun StructureND.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.sum(keepDim, dim).wrap() - public override fun Tensor.max(dim: Int, keepDim: Boolean): Tensor = + override fun StructureND.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.max(keepDim, dim).wrap() - public override fun Tensor.view(shape: IntArray): Tensor = ndArray.reshape(shape).wrap() - public override fun Tensor.viewAs(other: Tensor): Tensor = view(other.shape) + override fun Tensor.view(shape: IntArray): Nd4jArrayStructure = ndArray.reshape(shape).wrap() + override fun Tensor.viewAs(other: StructureND): Nd4jArrayStructure = view(other.shape) - override fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor = + override fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor = ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure() - public override fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor = ndArray.mean(keepDim, dim).wrap() + override fun StructureND.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + ndArray.mean(keepDim, dim).wrap() - public override fun Tensor.exp(): Tensor = Transforms.exp(ndArray).wrap() - public override fun Tensor.ln(): Tensor = Transforms.log(ndArray).wrap() - public override fun Tensor.sqrt(): Tensor = Transforms.sqrt(ndArray).wrap() - public override fun Tensor.cos(): Tensor = Transforms.cos(ndArray).wrap() - public override fun Tensor.acos(): Tensor = Transforms.acos(ndArray).wrap() - public override fun Tensor.cosh(): Tensor = Transforms.cosh(ndArray).wrap() + override fun StructureND.exp(): Nd4jArrayStructure = Transforms.exp(ndArray).wrap() + override fun StructureND.ln(): Nd4jArrayStructure = Transforms.log(ndArray).wrap() + override fun StructureND.sqrt(): Nd4jArrayStructure = Transforms.sqrt(ndArray).wrap() + override fun StructureND.cos(): Nd4jArrayStructure = Transforms.cos(ndArray).wrap() + override fun StructureND.acos(): Nd4jArrayStructure = Transforms.acos(ndArray).wrap() + override fun StructureND.cosh(): Nd4jArrayStructure = Transforms.cosh(ndArray).wrap() - public override fun Tensor.acosh(): Tensor = + override fun StructureND.acosh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() - public override fun Tensor.sin(): Tensor = Transforms.sin(ndArray).wrap() - public override fun Tensor.asin(): Tensor = Transforms.asin(ndArray).wrap() - public override fun Tensor.sinh(): Tensor = Transforms.sinh(ndArray).wrap() + override fun StructureND.sin(): Nd4jArrayStructure = Transforms.sin(ndArray).wrap() + override fun StructureND.asin(): Nd4jArrayStructure = Transforms.asin(ndArray).wrap() + override fun StructureND.sinh(): Tensor = Transforms.sinh(ndArray).wrap() - public override fun Tensor.asinh(): Tensor = + override fun StructureND.asinh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() - public override fun Tensor.tan(): Tensor = Transforms.tan(ndArray).wrap() - public override fun Tensor.atan(): Tensor = Transforms.atan(ndArray).wrap() - public override fun Tensor.tanh(): Tensor = Transforms.tanh(ndArray).wrap() - public override fun Tensor.atanh(): Tensor = Transforms.atanh(ndArray).wrap() - public override fun Tensor.ceil(): Tensor = Transforms.ceil(ndArray).wrap() - public override fun Tensor.floor(): Tensor = Transforms.floor(ndArray).wrap() - public override fun Tensor.std(dim: Int, keepDim: Boolean): Tensor = ndArray.std(true, keepDim, dim).wrap() - public override fun T.div(other: Tensor): Tensor = other.ndArray.rdiv(this).wrap() - public override fun Tensor.div(value: T): Tensor = ndArray.div(value).wrap() - public override fun Tensor.div(other: Tensor): Tensor = ndArray.div(other.ndArray).wrap() + override fun StructureND.tan(): Nd4jArrayStructure = Transforms.tan(ndArray).wrap() + override fun StructureND.atan(): Nd4jArrayStructure = Transforms.atan(ndArray).wrap() + override fun StructureND.tanh(): Nd4jArrayStructure = Transforms.tanh(ndArray).wrap() + override fun StructureND.atanh(): Nd4jArrayStructure = Transforms.atanh(ndArray).wrap() + override fun StructureND.ceil(): Nd4jArrayStructure = Transforms.ceil(ndArray).wrap() + override fun StructureND.floor(): Nd4jArrayStructure = Transforms.floor(ndArray).wrap() + override fun StructureND.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + ndArray.std(true, keepDim, dim).wrap() - public override fun Tensor.divAssign(value: T) { + override fun T.div(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rdiv(this).wrap() + override fun StructureND.div(arg: T): Nd4jArrayStructure = ndArray.div(arg).wrap() + override fun StructureND.div(arg: StructureND): Nd4jArrayStructure = ndArray.div(arg.ndArray).wrap() + + override fun Tensor.divAssign(value: T) { ndArray.divi(value) } - public override fun Tensor.divAssign(other: Tensor) { - ndArray.divi(other.ndArray) + override fun Tensor.divAssign(arg: StructureND) { + ndArray.divi(arg.ndArray) } - public override fun Tensor.variance(dim: Int, keepDim: Boolean): Tensor = + override fun StructureND.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure = Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() private companion object { @@ -142,11 +163,24 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra /** * [Double] specialization of [Nd4jTensorAlgebra]. */ -public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { - public override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() +public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { + + override val elementAlgebra: DoubleField get() = DoubleField + + override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() + + override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): Nd4jArrayStructure { + val array: INDArray = Nd4j.zeros(*shape) + val indices = DefaultStrides(shape) + indices.asSequence().forEach { index -> + array.putScalar(index, elementAlgebra.initializer(index)) + } + return array.wrap() + } + @OptIn(PerformancePitfall::class) - public override val StructureND.ndArray: INDArray + override val StructureND.ndArray: INDArray get() = when (this) { is Nd4jArrayStructure -> ndArray else -> Nd4j.zeros(*shape).also { @@ -154,22 +188,21 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { } } - public override fun Tensor.valueOrNull(): Double? = + override fun StructureND.valueOrNull(): Double? = if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null // TODO rewrite - @PerformancePitfall - public override fun diagonalEmbedding( + override fun diagonalEmbedding( diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int, ): Tensor = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2) - public override fun Tensor.sum(): Double = ndArray.sumNumber().toDouble() - public override fun Tensor.min(): Double = ndArray.minNumber().toDouble() - public override fun Tensor.max(): Double = ndArray.maxNumber().toDouble() - public override fun Tensor.mean(): Double = ndArray.meanNumber().toDouble() - public override fun Tensor.std(): Double = ndArray.stdNumber().toDouble() - public override fun Tensor.variance(): Double = ndArray.varNumber().toDouble() + override fun StructureND.sum(): Double = ndArray.sumNumber().toDouble() + override fun StructureND.min(): Double = ndArray.minNumber().toDouble() + override fun StructureND.max(): Double = ndArray.maxNumber().toDouble() + override fun StructureND.mean(): Double = ndArray.meanNumber().toDouble() + override fun StructureND.std(): Double = ndArray.stdNumber().toDouble() + override fun StructureND.variance(): Double = ndArray.varNumber().toDouble() } diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt index 75a334ca7..cc9211b20 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt @@ -1,8 +1,10 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd4j -internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() } +import space.kscience.kmath.misc.toIntExact + +internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toIntExact() } diff --git a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt index 40da22763..103416120 100644 --- a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt +++ b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd4j @@ -8,6 +8,10 @@ package space.kscience.kmath.nd4j import org.nd4j.linalg.factory.Nd4j import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.one +import space.kscience.kmath.nd.structureND +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.invoke import kotlin.math.PI import kotlin.test.Test @@ -19,7 +23,7 @@ import kotlin.test.fail internal class Nd4jArrayAlgebraTest { @Test fun testProduce() { - val res = with(DoubleNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } + val res = DoubleField.nd4j.structureND(2, 2) { it.sum().toDouble() } val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure() expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 1)] = 1.0 @@ -30,7 +34,9 @@ internal class Nd4jArrayAlgebraTest { @Test fun testMap() { - val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one.map { it + it * 2 } } + val res = IntRing.nd4j { + one(2, 2).map { it + it * 2 } + } val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 3 expected[intArrayOf(0, 1)] = 3 @@ -41,7 +47,7 @@ internal class Nd4jArrayAlgebraTest { @Test fun testAdd() { - val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 } + val res = IntRing.nd4j { one(2, 2) + 25 } val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() expected[intArrayOf(0, 0)] = 26 expected[intArrayOf(0, 1)] = 26 @@ -51,10 +57,10 @@ internal class Nd4jArrayAlgebraTest { } @Test - fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke { - val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 } + fun testSin() = DoubleField.nd4j{ + val initial = structureND(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 } val transformed = sin(initial) - val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 } + val expected = structureND(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 } println(transformed) assertTrue { StructureND.contentEquals(transformed, expected) } diff --git a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt index 30d01338f..ff55ad521 100644 --- a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.nd4j diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 93dbe8f62..edf380f24 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -6,20 +6,22 @@ package space.kscience.kmath.noa import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.noa.memory.NoaScope +import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.core.TensorLinearStructure -typealias Slice = Pair +internal typealias Slice = Pair -public sealed class NoaAlgebra> +public sealed class NoaAlgebra, PrimitiveArray, TensorType : NoaTensor> protected constructor(protected val scope: NoaScope) : - TensorAlgebra { + TensorAlgebra { - protected abstract val Tensor.tensor: TensorType + protected abstract val StructureND.tensor: TensorType protected abstract fun wrap(tensorHandle: TensorHandle): TensorType @@ -29,14 +31,14 @@ protected constructor(protected val scope: NoaScope) : /** * A scalar tensor must have empty shape */ - override fun Tensor.valueOrNull(): T? = + override fun StructureND.valueOrNull(): T? = try { tensor.item() } catch (e: NoaException) { null } - override fun Tensor.value(): T = tensor.item() + override fun StructureND.value(): T = tensor.item() public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device = Device.CPU): TensorType @@ -46,43 +48,43 @@ protected constructor(protected val scope: NoaScope) : public abstract fun full(value: T, shape: IntArray, device: Device = Device.CPU): TensorType - override operator fun Tensor.times(other: Tensor): TensorType { - return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle)) + override operator fun StructureND.times(arg: StructureND): TensorType { + return wrap(JNoa.timesTensor(tensor.tensorHandle, arg.tensor.tensorHandle)) } - override operator fun Tensor.timesAssign(other: Tensor): Unit { - JNoa.timesTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) + override operator fun Tensor.timesAssign(arg: StructureND): Unit { + JNoa.timesTensorAssign(tensor.tensorHandle, arg.tensor.tensorHandle) } - override operator fun Tensor.plus(other: Tensor): TensorType { - return wrap(JNoa.plusTensor(tensor.tensorHandle, other.tensor.tensorHandle)) + override operator fun StructureND.plus(arg: StructureND): TensorType { + return wrap(JNoa.plusTensor(tensor.tensorHandle, arg.tensor.tensorHandle)) } - override operator fun Tensor.plusAssign(other: Tensor): Unit { - JNoa.plusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) + override operator fun Tensor.plusAssign(arg: StructureND): Unit { + JNoa.plusTensorAssign(tensor.tensorHandle, arg.tensor.tensorHandle) } - override operator fun Tensor.minus(other: Tensor): TensorType { - return wrap(JNoa.minusTensor(tensor.tensorHandle, other.tensor.tensorHandle)) + override operator fun StructureND.minus(arg: StructureND): TensorType { + return wrap(JNoa.minusTensor(tensor.tensorHandle, arg.tensor.tensorHandle)) } - override operator fun Tensor.minusAssign(other: Tensor): Unit { - JNoa.minusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) + override operator fun Tensor.minusAssign(arg: StructureND): Unit { + JNoa.minusTensorAssign(tensor.tensorHandle, arg.tensor.tensorHandle) } - override operator fun Tensor.unaryMinus(): TensorType = + override operator fun StructureND.unaryMinus(): TensorType = wrap(JNoa.unaryMinus(tensor.tensorHandle)) - override infix fun Tensor.dot(other: Tensor): TensorType { + override infix fun StructureND.dot(other: StructureND): TensorType { return wrap(JNoa.matmul(tensor.tensorHandle, other.tensor.tensorHandle)) } - public infix fun Tensor.dotAssign(other: Tensor): Unit { - JNoa.matmulAssign(tensor.tensorHandle, other.tensor.tensorHandle) + public infix fun Tensor.dotAssign(arg: StructureND): Unit { + JNoa.matmulAssign(tensor.tensorHandle, arg.tensor.tensorHandle) } - public infix fun Tensor.dotRightAssign(other: Tensor): Unit { - JNoa.matmulRightAssign(tensor.tensorHandle, other.tensor.tensorHandle) + public infix fun StructureND.dotRightAssign(arg: Tensor): Unit { + JNoa.matmulRightAssign(tensor.tensorHandle, arg.tensor.tensorHandle) } override operator fun Tensor.get(i: Int): TensorType = @@ -114,28 +116,28 @@ protected constructor(protected val scope: NoaScope) : return wrap(JNoa.viewTensor(tensor.tensorHandle, shape)) } - override fun Tensor.viewAs(other: Tensor): TensorType { + override fun Tensor.viewAs(other: StructureND): TensorType { return wrap(JNoa.viewAsTensor(tensor.tensorHandle, other.tensor.tensorHandle)) } - public fun Tensor.abs(): TensorType = wrap(JNoa.absTensor(tensor.tensorHandle)) + public fun StructureND.abs(): TensorType = wrap(JNoa.absTensor(tensor.tensorHandle)) - public fun Tensor.sumAll(): TensorType = wrap(JNoa.sumTensor(tensor.tensorHandle)) - override fun Tensor.sum(): T = sumAll().item() - override fun Tensor.sum(dim: Int, keepDim: Boolean): TensorType = + public fun StructureND.sumAll(): TensorType = wrap(JNoa.sumTensor(tensor.tensorHandle)) + override fun StructureND.sum(): T = sumAll().item() + override fun StructureND.sum(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.sumDimTensor(tensor.tensorHandle, dim, keepDim)) - public fun Tensor.minAll(): TensorType = wrap(JNoa.minTensor(tensor.tensorHandle)) - override fun Tensor.min(): T = minAll().item() - override fun Tensor.min(dim: Int, keepDim: Boolean): TensorType = + public fun StructureND.minAll(): TensorType = wrap(JNoa.minTensor(tensor.tensorHandle)) + override fun StructureND.min(): T = minAll().item() + override fun StructureND.min(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.minDimTensor(tensor.tensorHandle, dim, keepDim)) - public fun Tensor.maxAll(): TensorType = wrap(JNoa.maxTensor(tensor.tensorHandle)) - override fun Tensor.max(): T = maxAll().item() - override fun Tensor.max(dim: Int, keepDim: Boolean): TensorType = + public fun StructureND.maxAll(): TensorType = wrap(JNoa.maxTensor(tensor.tensorHandle)) + override fun StructureND.max(): T = maxAll().item() + override fun StructureND.max(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.maxDimTensor(tensor.tensorHandle, dim, keepDim)) - override fun Tensor.argMax(dim: Int, keepDim: Boolean): NoaIntTensor = + override fun StructureND.argMax(dim: Int, keepDim: Boolean): NoaIntTensor = NoaIntTensor(scope, JNoa.argMaxTensor(tensor.tensorHandle, dim, keepDim)) public fun Tensor.flatten(startDim: Int, endDim: Int): TensorType = @@ -175,119 +177,119 @@ protected constructor(protected val scope: NoaScope) : public fun NoaJitModule.setBuffer(name: String, buffer: Tensor): Unit = JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle) - public infix fun TensorType.swap(other: TensorType): Unit = - JNoa.swapTensors(tensorHandle, other.tensorHandle) + public infix fun TensorType.swap(arg: TensorType): Unit = + JNoa.swapTensors(tensorHandle, arg.tensorHandle) public abstract fun TensorType.assignFromArray(array: PrimitiveArray): Unit } -public sealed class NoaPartialDivisionAlgebra> +public sealed class NoaPartialDivisionAlgebra, PrimitiveArray, TensorType : NoaTensor> protected constructor(scope: NoaScope) : - NoaAlgebra(scope), - LinearOpsTensorAlgebra, - AnalyticTensorAlgebra { + NoaAlgebra(scope), + LinearOpsTensorAlgebra, + AnalyticTensorAlgebra { - override operator fun Tensor.div(other: Tensor): TensorType { - return wrap(JNoa.divTensor(tensor.tensorHandle, other.tensor.tensorHandle)) + override operator fun StructureND.div(arg: StructureND): TensorType { + return wrap(JNoa.divTensor(tensor.tensorHandle, arg.tensor.tensorHandle)) } - override operator fun Tensor.divAssign(other: Tensor): Unit { - JNoa.divTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) + override operator fun Tensor.divAssign(arg: StructureND): Unit { + JNoa.divTensorAssign(tensor.tensorHandle, arg.tensor.tensorHandle) } - public fun Tensor.meanAll(): TensorType = wrap(JNoa.meanTensor(tensor.tensorHandle)) - override fun Tensor.mean(): T = meanAll().item() - override fun Tensor.mean(dim: Int, keepDim: Boolean): TensorType = + public fun StructureND.meanAll(): TensorType = wrap(JNoa.meanTensor(tensor.tensorHandle)) + override fun StructureND.mean(): T = meanAll().item() + override fun StructureND.mean(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.meanDimTensor(tensor.tensorHandle, dim, keepDim)) - public fun Tensor.stdAll(): TensorType = wrap(JNoa.stdTensor(tensor.tensorHandle)) - override fun Tensor.std(): T = stdAll().item() - override fun Tensor.std(dim: Int, keepDim: Boolean): TensorType = + public fun StructureND.stdAll(): TensorType = wrap(JNoa.stdTensor(tensor.tensorHandle)) + override fun StructureND.std(): T = stdAll().item() + override fun StructureND.std(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.stdDimTensor(tensor.tensorHandle, dim, keepDim)) - public fun Tensor.varAll(): TensorType = wrap(JNoa.varTensor(tensor.tensorHandle)) - override fun Tensor.variance(): T = varAll().item() - override fun Tensor.variance(dim: Int, keepDim: Boolean): TensorType = + public fun StructureND.varAll(): TensorType = wrap(JNoa.varTensor(tensor.tensorHandle)) + override fun StructureND.variance(): T = varAll().item() + override fun StructureND.variance(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim)) public abstract fun randNormal(shape: IntArray, device: Device = Device.CPU): TensorType public abstract fun randUniform(shape: IntArray, device: Device = Device.CPU): TensorType - public fun Tensor.randUniform(): TensorType = + public fun StructureND.randUniform(): TensorType = wrap(JNoa.randLike(tensor.tensorHandle)) - public fun Tensor.randUniformAssign(): Unit = + public fun StructureND.randUniformAssign(): Unit = JNoa.randLikeAssign(tensor.tensorHandle) - public fun Tensor.randNormal(): TensorType = + public fun StructureND.randNormal(): TensorType = wrap(JNoa.randnLike(tensor.tensorHandle)) - public fun Tensor.randNormalAssign(): Unit = + public fun StructureND.randNormalAssign(): Unit = JNoa.randnLikeAssign(tensor.tensorHandle) - override fun Tensor.exp(): TensorType = + override fun StructureND.exp(): TensorType = wrap(JNoa.expTensor(tensor.tensorHandle)) - override fun Tensor.ln(): TensorType = + override fun StructureND.ln(): TensorType = wrap(JNoa.lnTensor(tensor.tensorHandle)) - override fun Tensor.sqrt(): TensorType = + override fun StructureND.sqrt(): TensorType = wrap(JNoa.sqrtTensor(tensor.tensorHandle)) - override fun Tensor.cos(): TensorType = + override fun StructureND.cos(): TensorType = wrap(JNoa.cosTensor(tensor.tensorHandle)) - override fun Tensor.acos(): TensorType = + override fun StructureND.acos(): TensorType = wrap(JNoa.acosTensor(tensor.tensorHandle)) - override fun Tensor.cosh(): TensorType = + override fun StructureND.cosh(): TensorType = wrap(JNoa.coshTensor(tensor.tensorHandle)) - override fun Tensor.acosh(): TensorType = + override fun StructureND.acosh(): TensorType = wrap(JNoa.acoshTensor(tensor.tensorHandle)) - override fun Tensor.sin(): TensorType = + override fun StructureND.sin(): TensorType = wrap(JNoa.sinTensor(tensor.tensorHandle)) - override fun Tensor.asin(): TensorType = + override fun StructureND.asin(): TensorType = wrap(JNoa.asinTensor(tensor.tensorHandle)) - override fun Tensor.sinh(): TensorType = + override fun StructureND.sinh(): TensorType = wrap(JNoa.sinhTensor(tensor.tensorHandle)) - override fun Tensor.asinh(): TensorType = + override fun StructureND.asinh(): TensorType = wrap(JNoa.asinhTensor(tensor.tensorHandle)) - override fun Tensor.tan(): TensorType = + override fun StructureND.tan(): TensorType = wrap(JNoa.tanTensor(tensor.tensorHandle)) - override fun Tensor.atan(): TensorType = + override fun StructureND.atan(): TensorType = wrap(JNoa.atanTensor(tensor.tensorHandle)) - override fun Tensor.tanh(): TensorType = + override fun StructureND.tanh(): TensorType = wrap(JNoa.tanhTensor(tensor.tensorHandle)) - override fun Tensor.atanh(): TensorType = + override fun StructureND.atanh(): TensorType = wrap(JNoa.atanhTensor(tensor.tensorHandle)) - override fun Tensor.ceil(): TensorType = + override fun StructureND.ceil(): TensorType = wrap(JNoa.ceilTensor(tensor.tensorHandle)) - override fun Tensor.floor(): TensorType = + override fun StructureND.floor(): TensorType = wrap(JNoa.floorTensor(tensor.tensorHandle)) - override fun Tensor.det(): Tensor = + override fun StructureND.det(): Tensor = wrap(JNoa.detTensor(tensor.tensorHandle)) - override fun Tensor.inv(): Tensor = + override fun StructureND.inv(): Tensor = wrap(JNoa.invTensor(tensor.tensorHandle)) - override fun Tensor.cholesky(): Tensor = + override fun StructureND.cholesky(): Tensor = wrap(JNoa.choleskyTensor(tensor.tensorHandle)) - override fun Tensor.qr(): Pair { + override fun StructureND.qr(): Pair { val Q = JNoa.emptyTensor() val R = JNoa.emptyTensor() JNoa.qrTensor(tensor.tensorHandle, Q, R) @@ -297,7 +299,7 @@ protected constructor(scope: NoaScope) : /** * this implementation satisfies `tensor = P dot L dot U` */ - override fun Tensor.lu(): Triple { + override fun StructureND.lu(): Triple { val P = JNoa.emptyTensor() val L = JNoa.emptyTensor() val U = JNoa.emptyTensor() @@ -305,7 +307,7 @@ protected constructor(scope: NoaScope) : return Triple(wrap(P), wrap(L), wrap(U)) } - override fun Tensor.svd(): Triple { + override fun StructureND.svd(): Triple { val U = JNoa.emptyTensor() val V = JNoa.emptyTensor() val S = JNoa.emptyTensor() @@ -313,7 +315,7 @@ protected constructor(scope: NoaScope) : return Triple(wrap(U), wrap(S), wrap(V)) } - override fun Tensor.symEig(): Pair { + override fun StructureND.symEig(): Pair { val V = JNoa.emptyTensor() val S = JNoa.emptyTensor() JNoa.symEigTensor(tensor.tensorHandle, S, V) @@ -344,15 +346,25 @@ protected constructor(scope: NoaScope) : public sealed class NoaDoubleAlgebra protected constructor(scope: NoaScope) : - NoaPartialDivisionAlgebra(scope) { + NoaPartialDivisionAlgebra(scope) { - private fun Tensor.castHelper(): NoaDoubleTensor = + override val elementAlgebra: DoubleField + get() = DoubleField + + override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): NoaDoubleTensor = copyFromArray( - TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toDoubleArray(), + TensorLinearStructure(shape).asSequence().map { DoubleField.initializer(it) }.toMutableList() + .toDoubleArray(), + shape, Device.CPU + ) + + private fun StructureND.castHelper(): NoaDoubleTensor = + copyFromArray( + TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().toDoubleArray(), this.shape, Device.CPU ) - override val Tensor.tensor: NoaDoubleTensor + override val StructureND.tensor: NoaDoubleTensor get() = when (this) { is NoaDoubleTensor -> this else -> castHelper() @@ -379,37 +391,37 @@ protected constructor(scope: NoaScope) : override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor = wrap(JNoa.randintDouble(low, high, shape, device.toInt())) - override operator fun Double.plus(other: Tensor): NoaDoubleTensor = - wrap(JNoa.plusDouble(this, other.tensor.tensorHandle)) + override operator fun Double.plus(arg: StructureND): NoaDoubleTensor = + wrap(JNoa.plusDouble(this, arg.tensor.tensorHandle)) - override fun Tensor.plus(value: Double): NoaDoubleTensor = + override fun StructureND.plus(value: Double): NoaDoubleTensor = wrap(JNoa.plusDouble(value, tensor.tensorHandle)) override fun Tensor.plusAssign(value: Double): Unit = JNoa.plusDoubleAssign(value, tensor.tensorHandle) - override operator fun Double.minus(other: Tensor): NoaDoubleTensor = - wrap(JNoa.plusDouble(-this, other.tensor.tensorHandle)) + override operator fun Double.minus(arg: StructureND): NoaDoubleTensor = + wrap(JNoa.plusDouble(-this, arg.tensor.tensorHandle)) - override fun Tensor.minus(value: Double): NoaDoubleTensor = + override fun StructureND.minus(value: Double): NoaDoubleTensor = wrap(JNoa.plusDouble(-value, tensor.tensorHandle)) override fun Tensor.minusAssign(value: Double): Unit = JNoa.plusDoubleAssign(-value, tensor.tensorHandle) - override operator fun Double.times(other: Tensor): NoaDoubleTensor = - wrap(JNoa.timesDouble(this, other.tensor.tensorHandle)) + override operator fun Double.times(arg: StructureND): NoaDoubleTensor = + wrap(JNoa.timesDouble(this, arg.tensor.tensorHandle)) - override fun Tensor.times(value: Double): NoaDoubleTensor = + override fun StructureND.times(value: Double): NoaDoubleTensor = wrap(JNoa.timesDouble(value, tensor.tensorHandle)) override fun Tensor.timesAssign(value: Double): Unit = JNoa.timesDoubleAssign(value, tensor.tensorHandle) - override fun Double.div(other: Tensor): NoaDoubleTensor = - other.tensor * (1 / this) + override fun Double.div(arg: StructureND): NoaDoubleTensor = + arg.tensor * (1 / this) - override fun Tensor.div(value: Double): NoaDoubleTensor = + override fun StructureND.div(value: Double): NoaDoubleTensor = tensor * (1 / value) override fun Tensor.divAssign(value: Double): Unit = @@ -436,15 +448,25 @@ protected constructor(scope: NoaScope) : public sealed class NoaFloatAlgebra protected constructor(scope: NoaScope) : - NoaPartialDivisionAlgebra(scope) { + NoaPartialDivisionAlgebra(scope) { - private fun Tensor.castHelper(): NoaFloatTensor = + override val elementAlgebra: FloatField + get() = FloatField + + override fun structureND(shape: IntArray, initializer: FloatField.(IntArray) -> Float): NoaFloatTensor = copyFromArray( - TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toFloatArray(), + TensorLinearStructure(shape).asSequence().map { FloatField.initializer(it) }.toMutableList() + .toFloatArray(), + shape, Device.CPU + ) + + private fun StructureND.castHelper(): NoaFloatTensor = + copyFromArray( + TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().toFloatArray(), this.shape, Device.CPU ) - override val Tensor.tensor: NoaFloatTensor + override val StructureND.tensor: NoaFloatTensor get() = when (this) { is NoaFloatTensor -> this else -> castHelper() @@ -471,37 +493,37 @@ protected constructor(scope: NoaScope) : override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor = wrap(JNoa.randintFloat(low, high, shape, device.toInt())) - override operator fun Float.plus(other: Tensor): NoaFloatTensor = - wrap(JNoa.plusFloat(this, other.tensor.tensorHandle)) + override operator fun Float.plus(arg: StructureND): NoaFloatTensor = + wrap(JNoa.plusFloat(this, arg.tensor.tensorHandle)) - override fun Tensor.plus(value: Float): NoaFloatTensor = + override fun StructureND.plus(value: Float): NoaFloatTensor = wrap(JNoa.plusFloat(value, tensor.tensorHandle)) override fun Tensor.plusAssign(value: Float): Unit = JNoa.plusFloatAssign(value, tensor.tensorHandle) - override operator fun Float.minus(other: Tensor): NoaFloatTensor = - wrap(JNoa.plusFloat(-this, other.tensor.tensorHandle)) + override operator fun Float.minus(arg: StructureND): NoaFloatTensor = + wrap(JNoa.plusFloat(-this, arg.tensor.tensorHandle)) - override fun Tensor.minus(value: Float): NoaFloatTensor = + override fun StructureND.minus(value: Float): NoaFloatTensor = wrap(JNoa.plusFloat(-value, tensor.tensorHandle)) override fun Tensor.minusAssign(value: Float): Unit = JNoa.plusFloatAssign(-value, tensor.tensorHandle) - override operator fun Float.times(other: Tensor): NoaFloatTensor = - wrap(JNoa.timesFloat(this, other.tensor.tensorHandle)) + override operator fun Float.times(arg: StructureND): NoaFloatTensor = + wrap(JNoa.timesFloat(this, arg.tensor.tensorHandle)) - override fun Tensor.times(value: Float): NoaFloatTensor = + override fun StructureND.times(value: Float): NoaFloatTensor = wrap(JNoa.timesFloat(value, tensor.tensorHandle)) override fun Tensor.timesAssign(value: Float): Unit = JNoa.timesFloatAssign(value, tensor.tensorHandle) - override fun Float.div(other: Tensor): NoaFloatTensor = - other.tensor * (1 / this) + override fun Float.div(arg: StructureND): NoaFloatTensor = + arg.tensor * (1 / this) - override fun Tensor.div(value: Float): NoaFloatTensor = + override fun StructureND.div(value: Float): NoaFloatTensor = tensor * (1 / value) override fun Tensor.divAssign(value: Float): Unit = @@ -529,15 +551,25 @@ protected constructor(scope: NoaScope) : public sealed class NoaLongAlgebra protected constructor(scope: NoaScope) : - NoaAlgebra(scope) { + NoaAlgebra(scope) { - private fun Tensor.castHelper(): NoaLongTensor = + override val elementAlgebra: LongRing + get() = LongRing + + override fun structureND(shape: IntArray, initializer: LongRing.(IntArray) -> Long): NoaLongTensor = copyFromArray( - TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toLongArray(), + TensorLinearStructure(shape).asSequence().map { LongRing.initializer(it) }.toMutableList() + .toLongArray(), + shape, Device.CPU + ) + + private fun StructureND.castHelper(): NoaLongTensor = + copyFromArray( + TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().toLongArray(), this.shape, Device.CPU ) - override val Tensor.tensor: NoaLongTensor + override val StructureND.tensor: NoaLongTensor get() = when (this) { is NoaLongTensor -> this else -> castHelper() @@ -558,28 +590,28 @@ protected constructor(scope: NoaScope) : override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor = wrap(JNoa.randintLong(low, high, shape, device.toInt())) - override operator fun Long.plus(other: Tensor): NoaLongTensor = - wrap(JNoa.plusLong(this, other.tensor.tensorHandle)) + override operator fun Long.plus(arg: StructureND): NoaLongTensor = + wrap(JNoa.plusLong(this, arg.tensor.tensorHandle)) - override fun Tensor.plus(value: Long): NoaLongTensor = + override fun StructureND.plus(value: Long): NoaLongTensor = wrap(JNoa.plusLong(value, tensor.tensorHandle)) override fun Tensor.plusAssign(value: Long): Unit = JNoa.plusLongAssign(value, tensor.tensorHandle) - override operator fun Long.minus(other: Tensor): NoaLongTensor = - wrap(JNoa.plusLong(-this, other.tensor.tensorHandle)) + override operator fun Long.minus(arg: StructureND): NoaLongTensor = + wrap(JNoa.plusLong(-this, arg.tensor.tensorHandle)) - override fun Tensor.minus(value: Long): NoaLongTensor = + override fun StructureND.minus(value: Long): NoaLongTensor = wrap(JNoa.plusLong(-value, tensor.tensorHandle)) override fun Tensor.minusAssign(value: Long): Unit = JNoa.plusLongAssign(-value, tensor.tensorHandle) - override operator fun Long.times(other: Tensor): NoaLongTensor = - wrap(JNoa.timesLong(this, other.tensor.tensorHandle)) + override operator fun Long.times(arg: StructureND): NoaLongTensor = + wrap(JNoa.timesLong(this, arg.tensor.tensorHandle)) - override fun Tensor.times(value: Long): NoaLongTensor = + override fun StructureND.times(value: Long): NoaLongTensor = wrap(JNoa.timesLong(value, tensor.tensorHandle)) override fun Tensor.timesAssign(value: Long): Unit = @@ -606,15 +638,25 @@ protected constructor(scope: NoaScope) : public sealed class NoaIntAlgebra protected constructor(scope: NoaScope) : - NoaAlgebra(scope) { + NoaAlgebra(scope) { - private fun Tensor.castHelper(): NoaIntTensor = + override val elementAlgebra: IntRing + get() = IntRing + + override fun structureND(shape: IntArray, initializer: IntRing.(IntArray) -> Int): NoaIntTensor = copyFromArray( - TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toIntArray(), + TensorLinearStructure(shape).asSequence().map { IntRing.initializer(it) }.toMutableList() + .toIntArray(), + shape, Device.CPU + ) + + private fun StructureND.castHelper(): NoaIntTensor = + copyFromArray( + TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().toIntArray(), this.shape, Device.CPU ) - override val Tensor.tensor: NoaIntTensor + override val StructureND.tensor: NoaIntTensor get() = when (this) { is NoaIntTensor -> this else -> castHelper() @@ -635,28 +677,28 @@ protected constructor(scope: NoaScope) : override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor = wrap(JNoa.randintInt(low, high, shape, device.toInt())) - override operator fun Int.plus(other: Tensor): NoaIntTensor = - wrap(JNoa.plusInt(this, other.tensor.tensorHandle)) + override operator fun Int.plus(arg: StructureND): NoaIntTensor = + wrap(JNoa.plusInt(this, arg.tensor.tensorHandle)) - override fun Tensor.plus(value: Int): NoaIntTensor = + override fun StructureND.plus(value: Int): NoaIntTensor = wrap(JNoa.plusInt(value, tensor.tensorHandle)) override fun Tensor.plusAssign(value: Int): Unit = JNoa.plusIntAssign(value, tensor.tensorHandle) - override operator fun Int.minus(other: Tensor): NoaIntTensor = - wrap(JNoa.plusInt(-this, other.tensor.tensorHandle)) + override operator fun Int.minus(arg: StructureND): NoaIntTensor = + wrap(JNoa.plusInt(-this, arg.tensor.tensorHandle)) - override fun Tensor.minus(value: Int): NoaIntTensor = + override fun StructureND.minus(value: Int): NoaIntTensor = wrap(JNoa.plusInt(-value, tensor.tensorHandle)) override fun Tensor.minusAssign(value: Int): Unit = JNoa.plusIntAssign(-value, tensor.tensorHandle) - override operator fun Int.times(other: Tensor): NoaIntTensor = - wrap(JNoa.timesInt(this, other.tensor.tensorHandle)) + override operator fun Int.times(arg: StructureND): NoaIntTensor = + wrap(JNoa.timesInt(this, arg.tensor.tensorHandle)) - override fun Tensor.times(value: Int): NoaIntTensor = + override fun StructureND.times(value: Int): NoaIntTensor = wrap(JNoa.timesInt(value, tensor.tensorHandle)) override fun Tensor.timesAssign(value: Int): Unit = diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/utils.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/utils.kt index 5ad18c2eb..f6bfeffc5 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/utils.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/utils.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.noa +import space.kscience.kmath.operations.Field + public fun cudaAvailable(): Boolean { return JNoa.cudaIsAvailable() } @@ -21,9 +23,9 @@ public fun setSeed(seed: Int): Unit { JNoa.setSeed(seed) } -public inline fun , ArrayT, GradTensorT : NoaTensorOverField, - GradAlgebraT : NoaPartialDivisionAlgebra> + GradAlgebraT : NoaPartialDivisionAlgebra> GradAlgebraT.withGradAt( tensor: GradTensorT, block: GradAlgebraT.(GradTensorT) -> GradTensorT diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt index 61e6f0362..0c3d9618e 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt @@ -6,6 +6,7 @@ package space.kscience.kmath.noa import space.kscience.kmath.noa.memory.NoaScope +import space.kscience.kmath.operations.Ring import kotlin.test.Test import kotlin.test.assertEquals @@ -13,7 +14,7 @@ import kotlin.test.assertEquals internal val SEED = 987654 internal val TOLERANCE = 1e-6 -internal fun , AlgebraT : NoaAlgebra> +internal fun , ArrayT, TensorT : NoaTensor, AlgebraT : NoaAlgebra> AlgebraT.withCuda(block: AlgebraT.(Device) -> Unit): Unit { this.block(Device.CPU) if (cudaAvailable()) this.block(Device.CUDA(0)) diff --git a/kmath-optimization/build.gradle.kts b/kmath-optimization/build.gradle.kts new file mode 100644 index 000000000..68b82ad65 --- /dev/null +++ b/kmath-optimization/build.gradle.kts @@ -0,0 +1,20 @@ +plugins { + id("ru.mipt.npm.gradle.mpp") + id("ru.mipt.npm.gradle.native") +} + +kscience { + useAtomic() +} + +kotlin.sourceSets { + commonMain { + dependencies { + api(project(":kmath-coroutines")) + } + } +} + +readme { + maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL +} diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt new file mode 100644 index 000000000..02602b068 --- /dev/null +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.optimization + +import space.kscience.kmath.expressions.DifferentiableExpression +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.misc.FeatureSet + +public class OptimizationValue(public val value: T) : OptimizationFeature { + override fun toString(): String = "Value($value)" +} + +public enum class FunctionOptimizationTarget : OptimizationFeature { + MAXIMIZE, + MINIMIZE +} + +public class FunctionOptimization( + override val features: FeatureSet, + public val expression: DifferentiableExpression, +) : OptimizationProblem { + + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as FunctionOptimization<*> + + if (features != other.features) return false + if (expression != other.expression) return false + + return true + } + + override fun hashCode(): Int { + var result = features.hashCode() + result = 31 * result + expression.hashCode() + return result + } + + override fun toString(): String = "FunctionOptimization(features=$features)" +} + +public fun FunctionOptimization.withFeatures( + vararg newFeature: OptimizationFeature, +): FunctionOptimization = FunctionOptimization( + features.with(*newFeature), + expression, +) + +/** + * Optimizes differentiable expression using specific [optimizer] form given [startingPoint]. + */ +public suspend fun DifferentiableExpression.optimizeWith( + optimizer: Optimizer>, + startingPoint: Map, + vararg features: OptimizationFeature, +): FunctionOptimization { + val problem = FunctionOptimization(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this) + return optimizer.optimize(problem) +} + +public val FunctionOptimization.resultValueOrNull: T? + get() = getFeature>()?.point?.let { expression(it) } + +public val FunctionOptimization.resultValue: T + get() = resultValueOrNull ?: error("Result is not present in $this") \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationBuilder.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationBuilder.kt new file mode 100644 index 000000000..416d0195d --- /dev/null +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationBuilder.kt @@ -0,0 +1,93 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.optimization + +import space.kscience.kmath.data.XYColumnarData +import space.kscience.kmath.expressions.DifferentiableExpression +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.misc.FeatureSet + +public abstract class OptimizationBuilder> { + public val features: MutableList = ArrayList() + + public fun addFeature(feature: OptimizationFeature) { + features.add(feature) + } + + public inline fun updateFeature(update: (T?) -> T) { + val existing = features.find { it.key == T::class } as? T + val new = update(existing) + if (existing != null) { + features.remove(existing) + } + addFeature(new) + } + + public abstract fun build(): R +} + +public fun OptimizationBuilder.startAt(startingPoint: Map) { + addFeature(OptimizationStartPoint(startingPoint)) +} + +public class FunctionOptimizationBuilder( + private val expression: DifferentiableExpression, +) : OptimizationBuilder>() { + override fun build(): FunctionOptimization = FunctionOptimization(FeatureSet.of(features), expression) +} + +public fun FunctionOptimization( + expression: DifferentiableExpression, + builder: FunctionOptimizationBuilder.() -> Unit, +): FunctionOptimization = FunctionOptimizationBuilder(expression).apply(builder).build() + +public suspend fun DifferentiableExpression.optimizeWith( + optimizer: Optimizer>, + startingPoint: Map, + builder: FunctionOptimizationBuilder.() -> Unit = {}, +): FunctionOptimization { + val problem = FunctionOptimization(this) { + startAt(startingPoint) + builder() + } + return optimizer.optimize(problem) +} + +public suspend fun DifferentiableExpression.optimizeWith( + optimizer: Optimizer>, + vararg startingPoint: Pair, + builder: FunctionOptimizationBuilder.() -> Unit = {}, +): FunctionOptimization { + val problem = FunctionOptimization(this) { + startAt(mapOf(*startingPoint)) + builder() + } + return optimizer.optimize(problem) +} + + +public class XYOptimizationBuilder( + public val data: XYColumnarData, + public val model: DifferentiableExpression, +) : OptimizationBuilder() { + + public var pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY + public var pointWeight: PointWeight = PointWeight.byYSigma + + override fun build(): XYFit = XYFit( + data, + model, + FeatureSet.of(features), + pointToCurveDistance, + pointWeight + ) +} + +public fun XYOptimization( + data: XYColumnarData, + model: DifferentiableExpression, + builder: XYOptimizationBuilder.() -> Unit, +): XYFit = XYOptimizationBuilder(data, model).apply(builder).build() \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt new file mode 100644 index 000000000..b42be4035 --- /dev/null +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.optimization + +import space.kscience.kmath.expressions.DifferentiableExpression +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.linear.Matrix +import space.kscience.kmath.misc.* +import kotlin.reflect.KClass + +public interface OptimizationFeature : Feature { + // enforce toString override + override fun toString(): String +} + +public interface OptimizationProblem : Featured { + public val features: FeatureSet + override fun getFeature(type: KClass): F? = features.getFeature(type) +} + +public inline fun OptimizationProblem<*>.getFeature(): F? = getFeature(F::class) + +public open class OptimizationStartPoint(public val point: Map) : OptimizationFeature { + override fun toString(): String = "StartPoint($point)" +} + + +public interface OptimizationPrior : OptimizationFeature, DifferentiableExpression { + override val key: FeatureKey get() = OptimizationPrior::class +} + +public class OptimizationCovariance(public val covariance: Matrix) : OptimizationFeature { + override fun toString(): String = "Covariance($covariance)" +} + +/** + * Get the starting point for optimization. Throws error if not defined. + */ +public val OptimizationProblem.startPoint: Map + get() = getFeature>()?.point + ?: error("Starting point not defined in $this") + +public open class OptimizationResult(public val point: Map) : OptimizationFeature { + override fun toString(): String = "Result($point)" +} + +public val OptimizationProblem.resultPointOrNull: Map? + get() = getFeature>()?.point + +public val OptimizationProblem.resultPoint: Map + get() = resultPointOrNull ?: error("Result is not present in $this") + +public class OptimizationLog(private val loggable: Loggable) : Loggable by loggable, OptimizationFeature { + override fun toString(): String = "Log($loggable)" +} + +public class OptimizationParameters(public val symbols: List) : OptimizationFeature { + public constructor(vararg symbols: Symbol) : this(listOf(*symbols)) + + override fun toString(): String = "Parameters($symbols)" +} + + diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/VectorSpaceTest.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/Optimizer.kt similarity index 52% rename from kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/VectorSpaceTest.kt rename to kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/Optimizer.kt index f2c7f1f90..78385a99b 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/VectorSpaceTest.kt +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/Optimizer.kt @@ -3,3 +3,8 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ +package space.kscience.kmath.optimization + +public interface Optimizer> { + public suspend fun optimize(problem: P): P +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/QowOptimizer.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/QowOptimizer.kt new file mode 100644 index 000000000..babbaf6cd --- /dev/null +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/QowOptimizer.kt @@ -0,0 +1,266 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.optimization + +import space.kscience.kmath.expressions.DifferentiableExpression +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.expressions.SymbolIndexer +import space.kscience.kmath.expressions.derivative +import space.kscience.kmath.linear.* +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.misc.log +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.DoubleL2Norm +import space.kscience.kmath.operations.algebra +import space.kscience.kmath.structures.DoubleBuffer + + +public class QowRuns(public val runs: Int) : OptimizationFeature { + init { + require(runs >= 1) { "Number of runs must be more than zero" } + } + + override fun toString(): String = "QowRuns(runs=$runs)" +} + + +/** + * An optimizer based onf Fyodor Tkachev's quasi-optimal weights method. + * See [the article](http://arxiv.org/abs/physics/0604127). + */ +@UnstableKMathAPI +public object QowOptimizer : Optimizer { + + private val linearSpace: LinearSpace = Double.algebra.linearSpace + private val solver: LinearSolver = linearSpace.lupSolver() + + @OptIn(UnstableKMathAPI::class) + private class QoWeight( + val problem: XYFit, + val parameters: Map, + ) : Map by parameters, SymbolIndexer { + override val symbols: List = parameters.keys.toList() + + val data get() = problem.data + + /** + * Derivatives of the spectrum over parameters. First index in the point number, second one - index of parameter + */ + val derivs: Matrix by lazy { + linearSpace.buildMatrix(problem.data.size, symbols.size) { d, s -> + problem.distance(d).derivative(symbols[s])(parameters) + } + } + + /** + * Array of dispersions in each point + */ + val dispersion: Point by lazy { + DoubleBuffer(problem.data.size) { d -> + 1.0/problem.weight(d).invoke(parameters) + } + } + + val prior: DifferentiableExpression? get() = problem.getFeature>() + + override fun toString(): String = parameters.toString() + } + + /** + * The signed distance from the model to the [d]-th point of data. + */ + private fun QoWeight.distance(d: Int, parameters: Map): Double = problem.distance(d)(parameters) + + + /** + * The derivative of [distance] + */ + private fun QoWeight.distanceDerivative(symbol: Symbol, d: Int, parameters: Map): Double = + problem.distance(d).derivative(symbol)(parameters) + + /** + * Теоретическая ковариация весовых функций. + * + * D(\phi)=E(\phi_k(\theta_0) \phi_l(\theta_0))= disDeriv_k * disDeriv_l /sigma^2 + */ + private fun QoWeight.covarF(): Matrix = + linearSpace.matrix(size, size).symmetric { s1, s2 -> + (0 until data.size).sumOf { d -> derivs[d, s1] * derivs[d, s2] / dispersion[d] } + } + + /** + * Экспериментальная ковариация весов. Формула (22) из + * http://arxiv.org/abs/physics/0604127 + */ + private fun QoWeight.covarFExp(theta: Map): Matrix = + with(linearSpace) { + /* + * Важно! Если не делать предварителього вычисления этих производных, то + * количество вызывов функции будет dim^2 вместо dim Первый индекс - + * номер точки, второй - номер переменной, по которой берется производная + */ + val eqvalues = linearSpace.buildMatrix(data.size, size) { d, s -> + distance(d, theta) * derivs[d, s] / dispersion[d] + } + + buildMatrix(size, size) { s1, s2 -> + (0 until data.size).sumOf { d -> eqvalues[d, s2] * eqvalues[d, s1] } + } + } + + /** + * Equation derivatives for Newton run + */ + private fun QoWeight.getEqDerivValues( + theta: Map = parameters, + ): Matrix = with(linearSpace) { + //Возвращает производную k-того Eq по l-тому параметру + //val res = Array(fitDim) { DoubleArray(fitDim) } + val sderiv = buildMatrix(data.size, size) { d, s -> + distanceDerivative(symbols[s], d, theta) + } + + buildMatrix(size, size) { s1, s2 -> + val base = (0 until data.size).sumOf { d -> + require(dispersion[d] > 0) + sderiv[d, s2] * derivs[d, s1] / dispersion[d] + } + prior?.let { prior -> + //Check if this one is correct + val pi = prior(theta) + val deriv1 = prior.derivative(symbols[s1])(theta) + val deriv2 = prior.derivative(symbols[s2])(theta) + base + deriv1 * deriv2 / pi / pi + } ?: base + } + } + + + /** + * Значения уравнений метода квазиоптимальных весов + */ + private fun QoWeight.getEqValues(theta: Map = this): Point { + val distances = DoubleBuffer(data.size) { d -> distance(d, theta) } + + return DoubleBuffer(size) { s -> + val base = (0 until data.size).sumOf { d -> distances[d] * derivs[d, s] / dispersion[d] } + //Поправка на априорную вероятность + prior?.let { prior -> + base - prior.derivative(symbols[s])(theta) / prior(theta) + } ?: base + } + } + + + private fun QoWeight.newtonianStep( + theta: Map, + eqvalues: Point, + ): QoWeight = linearSpace { + with(this@newtonianStep) { + val start = theta.toPoint() + val invJacob = solver.inverse(this@newtonianStep.getEqDerivValues(theta)) + + val step = invJacob.dot(eqvalues) + return QoWeight(problem, theta + (start - step).toMap()) + } + } + + private fun QoWeight.newtonianRun( + maxSteps: Int = 100, + tolerance: Double = 0.0, + fast: Boolean = false, + ): QoWeight { + + val logger = problem.getFeature() + + var dis: Double //discrepancy value + // Working with the full set of parameters + var par = problem.startPoint + + logger?.log { "Starting newtonian iteration from: \n\t$par" } + + var eqvalues = getEqValues(par) //Values of the weight functions + + dis = DoubleL2Norm.norm(eqvalues) // discrepancy + logger?.log { "Starting discrepancy is $dis" } + var i = 0 + var flag = false + while (!flag) { + i++ + logger?.log { "Starting step number $i" } + + val currentSolution = if (fast) { + //Берет значения матрицы в той точке, где считается вес + newtonianStep(this, eqvalues) + } else { + //Берет значения матрицы в точке par + newtonianStep(par, eqvalues) + } + // здесь должен стоять учет границ параметров + logger?.log { "Parameter values after step are: \n\t$currentSolution" } + + eqvalues = getEqValues(currentSolution) + val currentDis = DoubleL2Norm.norm(eqvalues)// невязка после шага + + logger?.log { "The discrepancy after step is: $currentDis." } + + if (currentDis >= dis && i > 1) { + //дополнительно проверяем, чтобы был сделан хотя бы один шаг + flag = true + logger?.log { "The discrepancy does not decrease. Stopping iteration." } + } else { + par = currentSolution + dis = currentDis + } + if (i >= maxSteps) { + flag = true + logger?.log { "Maximum number of iterations reached. Stopping iteration." } + } + if (dis <= tolerance) { + flag = true + logger?.log { "Tolerance threshold is reached. Stopping iteration." } + } + } + + return QoWeight(problem, par) + } + + private fun QoWeight.covariance(): Matrix { + val logger = problem.getFeature() + + logger?.log { + """ + Starting errors estimation using quasioptimal weights method. The starting weight is: + ${problem.startPoint} + """.trimIndent() + } + + val covar = solver.inverse(getEqDerivValues()) + //TODO fix eigenvalues check +// val decomposition = EigenDecomposition(covar.matrix) +// var valid = true +// for (lambda in decomposition.realEigenvalues) { +// if (lambda <= 0) { +// logger?.log { "The covariance matrix is not positive defined. Error estimation is not valid" } +// valid = false +// } +// } + return covar + } + + override suspend fun optimize(problem: XYFit): XYFit { + val qowRuns = problem.getFeature()?.runs ?: 2 + + + var qow = QoWeight(problem, problem.startPoint) + var res = qow.newtonianRun() + repeat(qowRuns - 1) { + qow = QoWeight(problem, res.parameters) + res = qow.newtonianRun() + } + return res.problem.withFeature(OptimizationResult(res.parameters)) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt new file mode 100644 index 000000000..07fea3126 --- /dev/null +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt @@ -0,0 +1,146 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ +@file:OptIn(UnstableKMathAPI::class) + +package space.kscience.kmath.optimization + +import space.kscience.kmath.data.XYColumnarData +import space.kscience.kmath.data.indices +import space.kscience.kmath.expressions.* +import space.kscience.kmath.misc.FeatureSet +import space.kscience.kmath.misc.Loggable +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.ExtendedField +import space.kscience.kmath.operations.bindSymbol +import kotlin.math.pow + +/** + * Specify the way to compute distance from point to the curve as DifferentiableExpression + */ +public interface PointToCurveDistance : OptimizationFeature { + public fun distance(problem: XYFit, index: Int): DifferentiableExpression + + public companion object { + public val byY: PointToCurveDistance = object : PointToCurveDistance { + override fun distance(problem: XYFit, index: Int): DifferentiableExpression { + val x = problem.data.x[index] + val y = problem.data.y[index] + + return object : DifferentiableExpression { + override fun derivativeOrNull( + symbols: List + ): Expression? = problem.model.derivativeOrNull(symbols)?.let { derivExpression -> + Expression { arguments -> + derivExpression.invoke(arguments + (Symbol.x to x)) + } + } + + override fun invoke(arguments: Map): Double = + problem.model(arguments + (Symbol.x to x)) - y + } + } + + override fun toString(): String = "PointToCurveDistanceByY" + } + } +} + +/** + * Compute a wight of the point. The more the weight, the more impact this point will have on the fit. + * By default, uses Dispersion^-1 + */ +public interface PointWeight : OptimizationFeature { + public fun weight(problem: XYFit, index: Int): DifferentiableExpression + + public companion object { + public fun bySigma(sigmaSymbol: Symbol): PointWeight = object : PointWeight { + override fun weight(problem: XYFit, index: Int): DifferentiableExpression = + object : DifferentiableExpression { + override fun invoke(arguments: Map): Double { + return problem.data[sigmaSymbol]?.get(index)?.pow(-2) ?: 1.0 + } + + override fun derivativeOrNull(symbols: List): Expression = Expression { 0.0 } + } + + override fun toString(): String = "PointWeightBySigma($sigmaSymbol)" + + } + + public val byYSigma: PointWeight = bySigma(Symbol.yError) + } +} + +/** + * A fit problem for X-Y-Yerr data. Also known as "least-squares" problem. + */ +public class XYFit( + public val data: XYColumnarData, + public val model: DifferentiableExpression, + override val features: FeatureSet, + internal val pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, + internal val pointWeight: PointWeight = PointWeight.byYSigma, + public val xSymbol: Symbol = Symbol.x, +) : OptimizationProblem { + public fun distance(index: Int): DifferentiableExpression = pointToCurveDistance.distance(this, index) + + public fun weight(index: Int): DifferentiableExpression = pointWeight.weight(this, index) +} + +public fun XYFit.withFeature(vararg features: OptimizationFeature): XYFit { + return XYFit(data, model, this.features.with(*features), pointToCurveDistance, pointWeight) +} + +/** + * Fit given dta with + */ +public suspend fun XYColumnarData.fitWith( + optimizer: Optimizer, + processor: AutoDiffProcessor, + startingPoint: Map, + vararg features: OptimizationFeature = emptyArray(), + xSymbol: Symbol = Symbol.x, + pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, + pointWeight: PointWeight = PointWeight.byYSigma, + model: A.(I) -> I +): XYFit where A : ExtendedField, A : ExpressionAlgebra { + val modelExpression = processor.differentiate { + val x = bindSymbol(xSymbol) + model(x) + } + + var actualFeatures = FeatureSet.of(*features, OptimizationStartPoint(startingPoint)) + + if (actualFeatures.getFeature() == null) { + actualFeatures = actualFeatures.with(OptimizationLog(Loggable.console)) + } + val problem = XYFit( + this, + modelExpression, + actualFeatures, + pointToCurveDistance, + pointWeight, + xSymbol + ) + return optimizer.optimize(problem) +} + +/** + * Compute chi squared value for completed fit. Return null for incomplete fit + */ +public val XYFit.chiSquaredOrNull: Double? get() { + val result = resultPointOrNull ?: return null + + return data.indices.sumOf { index-> + + val x = data.x[index] + val y = data.y[index] + val yErr = data[Symbol.yError]?.get(index) ?: 1.0 + + val mu = model.invoke(result + (xSymbol to x) ) + + ((y - mu)/yErr).pow(2) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/logLikelihood.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/logLikelihood.kt new file mode 100644 index 000000000..b4cb2f1cf --- /dev/null +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/logLikelihood.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.optimization + +import space.kscience.kmath.data.XYColumnarData +import space.kscience.kmath.data.indices +import space.kscience.kmath.expressions.DifferentiableExpression +import space.kscience.kmath.expressions.Expression +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.expressions.derivative +import space.kscience.kmath.misc.UnstableKMathAPI +import kotlin.math.PI +import kotlin.math.ln +import kotlin.math.pow +import kotlin.math.sqrt + + +private val oneOver2Pi = 1.0 / sqrt(2 * PI) + +@UnstableKMathAPI +internal fun XYFit.logLikelihood(): DifferentiableExpression = object : DifferentiableExpression { + override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + data.indices.sumOf { index -> + val d = distance(index)(arguments) + val weight = weight(index)(arguments) + val weightDerivative = weight(index).derivative(symbols)(arguments) + + // -1 / (sqrt(2 PI) * sigma) + 2 (x-mu)/ 2 sigma^2 * d mu/ d theta - (x-mu)^2 / 2 * d w/ d theta + return@sumOf -oneOver2Pi * sqrt(weight) + //offset derivative + d * model.derivative(symbols)(arguments) * weight - //model derivative + d.pow(2) * weightDerivative / 2 //weight derivative + } + } + + override fun invoke(arguments: Map): Double { + return data.indices.sumOf { index -> + val d = distance(index)(arguments) + val weight = weight(index)(arguments) + //1/sqrt(2 PI sigma^2) - (x-mu)^2/ (2 * sigma^2) + oneOver2Pi * ln(weight) - d.pow(2) * weight + } / 2 + } + +} + +/** + * Optimize given XY (least squares) [problem] using this function [Optimizer]. + * The problem is treated as maximum likelihood problem and is done via maximizing logarithmic likelihood, respecting + * possible weight dependency on the model and parameters. + */ +@UnstableKMathAPI +public suspend fun Optimizer>.maximumLogLikelihood(problem: XYFit): XYFit { + val functionOptimization = FunctionOptimization(problem.features, problem.logLikelihood()) + val result = optimize(functionOptimization.withFeatures(FunctionOptimizationTarget.MAXIMIZE)) + return XYFit(problem.data, problem.model, result.features) +} + +@UnstableKMathAPI +public suspend fun Optimizer>.maximumLogLikelihood( + data: XYColumnarData, + model: DifferentiableExpression, + builder: XYOptimizationBuilder.() -> Unit, +): XYFit = maximumLogLikelihood(XYOptimization(data, model, builder)) diff --git a/kmath-optimization/src/commonMain/tmp/QowFit.kt b/kmath-optimization/src/commonMain/tmp/QowFit.kt new file mode 100644 index 000000000..c78aef401 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/QowFit.kt @@ -0,0 +1,372 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.optimization.qow + +import space.kscience.kmath.data.ColumnarData +import space.kscience.kmath.data.XYErrorColumnarData +import space.kscience.kmath.expressions.* +import space.kscience.kmath.linear.* +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.Field +import space.kscience.kmath.optimization.OptimizationFeature +import space.kscience.kmath.optimization.OptimizationProblemFactory +import space.kscience.kmath.optimization.OptimizationResult +import space.kscience.kmath.optimization.XYOptimization +import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.DoubleL2Norm +import kotlin.math.pow + + +private typealias ParamSet = Map + +@OptIn(UnstableKMathAPI::class) +public class QowFit( + override val symbols: List, + private val space: LinearSpace, + private val solver: LinearSolver, +) : XYOptimization, SymbolIndexer { + + private var logger: FitLogger? = null + + private var startingPoint: Map = TODO() + private var covariance: Matrix? = TODO() + private val prior: DifferentiableExpression>? = TODO() + private var data: XYErrorColumnarData = TODO() + private var model: DifferentiableExpression> = TODO() + + private val features = HashSet() + + override fun update(result: OptimizationResult) { + TODO("Not yet implemented") + } + + override val algebra: Field + get() = TODO("Not yet implemented") + + override fun data( + dataSet: ColumnarData, + xSymbol: Symbol, + ySymbol: Symbol, + xErrSymbol: Symbol?, + yErrSymbol: Symbol?, + ) { + TODO("Not yet implemented") + } + + override fun model(model: (Double) -> DifferentiableExpression) { + TODO("Not yet implemented") + } + + private var x: Symbol = Symbol.x + + /** + * The signed distance from the model to the [i]-th point of data. + */ + private fun distance(i: Int, parameters: Map): Double = + model(parameters + (x to data.x[i])) - data.y[i] + + + /** + * The derivative of [distance] + * TODO use expressions instead + */ + private fun distanceDerivative(symbol: Symbol, i: Int, parameters: Map): Double = + model.derivative(symbol)(parameters + (x to data.x[i])) + + /** + * The dispersion of [i]-th data point + */ + private fun getDispersion(i: Int, parameters: Map): Double = data.yErr[i].pow(2) + + private fun getCovariance(weight: QoWeight): Matrix = solver.inverse(getEqDerivValues(weight)) + + /** + * Теоретическая ковариация весовых функций. + * + * D(\phi)=E(\phi_k(\theta_0) \phi_l(\theta_0))= disDeriv_k * disDeriv_l /sigma^2 + */ + private fun covarF(weight: QoWeight): Matrix = space.buildSymmetricMatrix(symbols.size) { k, l -> + (0 until data.size).sumOf { i -> weight.derivs[k, i] * weight.derivs[l, i] / weight.dispersion[i] } + } + + /** + * Экспериментальная ковариация весов. Формула (22) из + * http://arxiv.org/abs/physics/0604127 + * + * @param source + * @param set + * @param fitPars + * @param weight + * @return + */ + private fun covarFExp(weight: QoWeight, theta: Map): Matrix = space.run { + /* + * Важно! Если не делать предварителього вычисления этих производных, то + * количество вызывов функции будет dim^2 вместо dim Первый индекс - + * номер точки, второй - номер переменной, по которой берется производная + */ + val eqvalues = buildMatrix(data.size, symbols.size) { i, l -> + distance(i, theta) * weight.derivs[l, i] / weight.dispersion[i] + } + + buildMatrix(symbols.size, symbols.size) { k, l -> + (0 until data.size).sumOf { i -> eqvalues[i, l] * eqvalues[i, k] } + } + } + + /** + * производные уравнений для метода Ньютона + * + * @param source + * @param set + * @param fitPars + * @param weight + * @return + */ + private fun getEqDerivValues( + weight: QoWeight, theta: Map = weight.theta, + ): Matrix = space.run { + val fitDim = symbols.size + //Возвращает производную k-того Eq по l-тому параметру + val res = Array(fitDim) { DoubleArray(fitDim) } + val sderiv = buildMatrix(data.size, symbols.size) { i, l -> + distanceDerivative(symbols[l], i, theta) + } + + buildMatrix(symbols.size, symbols.size) { k, l -> + val base = (0 until data.size).sumOf { i -> + require(weight.dispersion[i] > 0) + sderiv[i, l] * weight.derivs[k, i] / weight.dispersion[i] + } + prior?.let { prior -> + //Check if this one is correct + val pi = prior(theta) + val deriv1 = prior.derivative(symbols[k])(theta) + val deriv2 = prior.derivative(symbols[l])(theta) + base + deriv1 * deriv2 / pi / pi + } ?: base + } + } + + + /** + * Значения уравнений метода квазиоптимальных весов + * + * @param source + * @param set + * @param fitPars + * @param weight + * @return + */ + private fun getEqValues(weight: QoWeight, theta: Map = weight.theta): Point { + val distances = DoubleBuffer(data.size) { i -> distance(i, theta) } + + return DoubleBuffer(symbols.size) { k -> + val base = (0 until data.size).sumOf { i -> distances[i] * weight.derivs[k, i] / weight.dispersion[i] } + //Поправка на априорную вероятность + prior?.let { prior -> + base - prior.derivative(symbols[k])(theta) / prior(theta) + } ?: base + } + } + + + /** + * The state of QOW fitter + * Created by Alexander Nozik on 17-Oct-16. + */ + private inner class QoWeight( + val theta: Map, + ) { + + init { + require(data.size > 0) { "The state does not contain data" } + } + + /** + * Derivatives of the spectrum over parameters. First index in the point number, second one - index of parameter + */ + val derivs: Matrix by lazy { + space.buildMatrix(data.size, symbols.size) { i, k -> + distanceDerivative(symbols[k], i, theta) + } + } + + /** + * Array of dispersions in each point + */ + val dispersion: Point by lazy { + DoubleBuffer(data.size) { i -> getDispersion(i, theta) } + } + + } + + private fun newtonianStep( + weight: QoWeight, + par: Map, + eqvalues: Point, + ): Map = space.run { + val start = par.toPoint() + val invJacob = solver.inverse(getEqDerivValues(weight, par)) + + val step = invJacob.dot(eqvalues) + return par + (start - step).toMap() + } + + private fun newtonianRun( + weight: QoWeight, + maxSteps: Int = 100, + tolerance: Double = 0.0, + fast: Boolean = false, + ): ParamSet { + + var dis: Double//норма невязки + // Для удобства работаем всегда с полным набором параметров + var par = startingPoint + + logger?.log { "Starting newtonian iteration from: \n\t$par" } + + var eqvalues = getEqValues(weight, par)//значения функций + + dis = DoubleL2Norm.norm(eqvalues)// невязка + logger?.log { "Starting discrepancy is $dis" } + var i = 0 + var flag = false + while (!flag) { + i++ + logger?.log { "Starting step number $i" } + + val currentSolution = if (fast) { + //Берет значения матрицы в той точке, где считается вес + newtonianStep(weight, weight.theta, eqvalues) + } else { + //Берет значения матрицы в точке par + newtonianStep(weight, par, eqvalues) + } + // здесь должен стоять учет границ параметров + logger?.log { "Parameter values after step are: \n\t$currentSolution" } + + eqvalues = getEqValues(weight, currentSolution) + val currentDis = DoubleL2Norm.norm(eqvalues)// невязка после шага + + logger?.log { "The discrepancy after step is: $currentDis." } + + if (currentDis >= dis && i > 1) { + //дополнительно проверяем, чтобы был сделан хотя бы один шаг + flag = true + logger?.log { "The discrepancy does not decrease. Stopping iteration." } + } else { + par = currentSolution + dis = currentDis + } + if (i >= maxSteps) { + flag = true + logger?.log { "Maximum number of iterations reached. Stopping iteration." } + } + if (dis <= tolerance) { + flag = true + logger?.log { "Tolerance threshold is reached. Stopping iteration." } + } + } + + return par + } + + +// +// override fun run(state: FitState, parentLog: History?, meta: Meta): FitResult { +// val log = Chronicle("QOW", parentLog) +// val action = meta.getString(FIT_STAGE_TYPE, TASK_RUN) +// log.report("QOW fit engine started task '{}'", action) +// return when (action) { +// TASK_SINGLE -> makeRun(state, log, meta) +// TASK_COVARIANCE -> generateErrors(state, log, meta) +// TASK_RUN -> { +// var res = makeRun(state, log, meta) +// res = makeRun(res.optState().get(), log, meta) +// generateErrors(res.optState().get(), log, meta) +// } +// else -> throw IllegalArgumentException("Unknown task") +// } +// } + +// private fun makeRun(state: FitState, log: History, meta: Meta): FitResult { +// /*Инициализация объектов, задание исходных значений*/ +// log.report("Starting fit using quasioptimal weights method.") +// +// val fitPars = getFitPars(state, meta) +// +// val curWeight = QoWeight(state, fitPars, state.parameters) +// +// // вычисляем вес в allPar. Потом можно будет попробовать ручное задание веса +// log.report("The starting weight is: \n\t{}", +// MathUtils.toString(curWeight.theta)) +// +// //Стартовая точка такая же как и параметр веса +// /*Фитирование*/ +// val res = newtonianRun(state, curWeight, log, meta) +// +// /*Генерация результата*/ +// +// return FitResult.build(state.edit().setPars(res).build(), *fitPars) +// } + + /** + * generateErrors. + */ + private fun generateErrors(): Matrix { + logger?.log { """ + Starting errors estimation using quasioptimal weights method. The starting weight is: + ${curWeight.theta} + """.trimIndent()} + val curWeight = QoWeight(startingPoint) + + val covar = getCovariance(curWeight) + + val decomposition = EigenDecomposition(covar.matrix) + var valid = true + for (lambda in decomposition.realEigenvalues) { + if (lambda <= 0) { + log.report("The covariance matrix is not positive defined. Error estimation is not valid") + valid = false + } + } + } + + + override suspend fun optimize(): OptimizationResult { + val curWeight = QoWeight(startingPoint) + logger?.log { + """ + Starting fit using quasioptimal weights method. The starting weight is: + ${curWeight.theta} + """.trimIndent() + } + val res = newtonianRun(curWeight) + } + + + companion object : OptimizationProblemFactory { + override fun build(symbols: List): QowFit { + TODO("Not yet implemented") + } + + + /** + * Constant `QOW_ENGINE_NAME="QOW"` + */ + const val QOW_ENGINE_NAME = "QOW" + + /** + * Constant `QOW_METHOD_FAST="fast"` + */ + const val QOW_METHOD_FAST = "fast" + + + } +} + diff --git a/kmath-optimization/src/commonMain/tmp/minuit/AnalyticalGradientCalculator.kt b/kmath-optimization/src/commonMain/tmp/minuit/AnalyticalGradientCalculator.kt new file mode 100644 index 000000000..912fa22eb --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/AnalyticalGradientCalculator.kt @@ -0,0 +1,61 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction + +/** + * + * @version $Id$ + */ +internal class AnalyticalGradientCalculator(fcn: MultiFunction?, state: MnUserTransformation, checkGradient: Boolean) : + GradientCalculator { + private val function: MultiFunction? + private val theCheckGradient: Boolean + private val theTransformation: MnUserTransformation + fun checkGradient(): Boolean { + return theCheckGradient + } + + /** {@inheritDoc} */ + fun gradient(par: MinimumParameters): FunctionGradient { +// double[] grad = theGradCalc.gradientValue(theTransformation.andThen(par.vec()).data()); + val point: DoubleArray = theTransformation.transform(par.vec()).toArray() + require(!(function.getDimension() !== theTransformation.parameters().size())) { "Invalid parameter size" } + val v: RealVector = ArrayRealVector(par.vec().getDimension()) + for (i in 0 until par.vec().getDimension()) { + val ext: Int = theTransformation.extOfInt(i) + if (theTransformation.parameter(ext).hasLimits()) { + val dd: Double = theTransformation.dInt2Ext(i, par.vec().getEntry(i)) + v.setEntry(i, dd * function.derivValue(ext, point)) + } else { + v.setEntry(i, function.derivValue(ext, point)) + } + } + return FunctionGradient(v) + } + + /** {@inheritDoc} */ + fun gradient(par: MinimumParameters, grad: FunctionGradient?): FunctionGradient { + return gradient(par) + } + + init { + function = fcn + theTransformation = state + theCheckGradient = checkGradient + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/CombinedMinimizer.kt b/kmath-optimization/src/commonMain/tmp/minuit/CombinedMinimizer.kt new file mode 100644 index 000000000..9363492ad --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/CombinedMinimizer.kt @@ -0,0 +1,32 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class CombinedMinimizer : ModularFunctionMinimizer() { + private val theMinBuilder: CombinedMinimumBuilder = CombinedMinimumBuilder() + private val theMinSeedGen: MnSeedGenerator = MnSeedGenerator() + override fun builder(): MinimumBuilder { + return theMinBuilder + } + + override fun seedGenerator(): MinimumSeedGenerator { + return theMinSeedGen + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/CombinedMinimumBuilder.kt b/kmath-optimization/src/commonMain/tmp/minuit/CombinedMinimumBuilder.kt new file mode 100644 index 000000000..8c5452575 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/CombinedMinimumBuilder.kt @@ -0,0 +1,58 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MINUITPlugin +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * + * @version $Id$ + */ +internal class CombinedMinimumBuilder : MinimumBuilder { + private val theSimplexMinimizer: SimplexMinimizer = SimplexMinimizer() + private val theVMMinimizer: VariableMetricMinimizer = VariableMetricMinimizer() + + /** {@inheritDoc} */ + override fun minimum( + fcn: MnFcn?, + gc: GradientCalculator?, + seed: MinimumSeed?, + strategy: MnStrategy?, + maxfcn: Int, + toler: Double + ): FunctionMinimum { + val min: FunctionMinimum = theVMMinimizer.minimize(fcn!!, gc, seed, strategy, maxfcn, toler) + if (!min.isValid()) { + MINUITPlugin.logStatic("CombinedMinimumBuilder: migrad method fails, will try with simplex method first.") + val str = MnStrategy(2) + val min1: FunctionMinimum = theSimplexMinimizer.minimize(fcn, gc, seed, str, maxfcn, toler) + if (!min1.isValid()) { + MINUITPlugin.logStatic("CombinedMinimumBuilder: both migrad and simplex method fail.") + return min1 + } + val seed1: MinimumSeed = theVMMinimizer.seedGenerator().generate(fcn, gc, min1.userState(), str) + val min2: FunctionMinimum = theVMMinimizer.minimize(fcn, gc, seed1, str, maxfcn, toler) + if (!min2.isValid()) { + MINUITPlugin.logStatic("CombinedMinimumBuilder: both migrad and method fails also at 2nd attempt.") + MINUITPlugin.logStatic("CombinedMinimumBuilder: return simplex minimum.") + return min1 + } + return min2 + } + return min + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/ContoursError.kt b/kmath-optimization/src/commonMain/tmp/minuit/ContoursError.kt new file mode 100644 index 000000000..214d94c80 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/ContoursError.kt @@ -0,0 +1,150 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * ContoursError class. + * + * @author Darksnake + * @version $Id$ + */ +class ContoursError internal constructor( + private val theParX: Int, + private val theParY: Int, + points: List, + xmnos: MinosError, + ymnos: MinosError, + nfcn: Int +) { + private val theNFcn: Int + private val thePoints: List = points + private val theXMinos: MinosError + private val theYMinos: MinosError + + /** + * + * nfcn. + * + * @return a int. + */ + fun nfcn(): Int { + return theNFcn + } + + /** + * + * points. + * + * @return a [List] object. + */ + fun points(): List { + return thePoints + } + + /** + * {@inheritDoc} + */ + override fun toString(): String { + return MnPrint.toString(this) + } + + /** + * + * xMinosError. + * + * @return a [hep.dataforge.MINUIT.MinosError] object. + */ + fun xMinosError(): MinosError { + return theXMinos + } + + /** + * + * xRange. + * + * @return + */ + fun xRange(): Range { + return theXMinos.range() + } + + /** + * + * xmin. + * + * @return a double. + */ + fun xmin(): Double { + return theXMinos.min() + } + + /** + * + * xpar. + * + * @return a int. + */ + fun xpar(): Int { + return theParX + } + + /** + * + * yMinosError. + * + * @return a [hep.dataforge.MINUIT.MinosError] object. + */ + fun yMinosError(): MinosError { + return theYMinos + } + + /** + * + * yRange. + * + * @return + */ + fun yRange(): Range { + return theYMinos.range() + } + + /** + * + * ymin. + * + * @return a double. + */ + fun ymin(): Double { + return theYMinos.min() + } + + /** + * + * ypar. + * + * @return a int. + */ + fun ypar(): Int { + return theParY + } + + init { + theXMinos = xmnos + theYMinos = ymnos + theNFcn = nfcn + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/DavidonErrorUpdator.kt b/kmath-optimization/src/commonMain/tmp/minuit/DavidonErrorUpdator.kt new file mode 100644 index 000000000..9eb2443e4 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/DavidonErrorUpdator.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.RealVector +import ru.inr.mass.minuit.* + +/** + * + * @version $Id$ + */ +internal class DavidonErrorUpdator : MinimumErrorUpdator { + /** {@inheritDoc} */ + fun update(s0: MinimumState, p1: MinimumParameters, g1: FunctionGradient): MinimumError { + val V0: MnAlgebraicSymMatrix = s0.error().invHessian() + val dx: RealVector = MnUtils.sub(p1.vec(), s0.vec()) + val dg: RealVector = MnUtils.sub(g1.getGradient(), s0.gradient().getGradient()) + val delgam: Double = MnUtils.innerProduct(dx, dg) + val gvg: Double = MnUtils.similarity(dg, V0) + val vg: RealVector = MnUtils.mul(V0, dg) + var Vupd: MnAlgebraicSymMatrix = + MnUtils.sub(MnUtils.div(MnUtils.outerProduct(dx), delgam), MnUtils.div(MnUtils.outerProduct(vg), gvg)) + if (delgam > gvg) { + Vupd = MnUtils.add(Vupd, + MnUtils.mul(MnUtils.outerProduct(MnUtils.sub(MnUtils.div(dx, delgam), MnUtils.div(vg, gvg))), gvg)) + } + val sum_upd: Double = MnUtils.absoluteSumOfElements(Vupd) + Vupd = MnUtils.add(Vupd, V0) + val dcov: Double = 0.5 * (s0.error().dcovar() + sum_upd / MnUtils.absoluteSumOfElements(Vupd)) + return MinimumError(Vupd, dcov) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/FunctionGradient.kt b/kmath-optimization/src/commonMain/tmp/minuit/FunctionGradient.kt new file mode 100644 index 000000000..a0866d916 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/FunctionGradient.kt @@ -0,0 +1,72 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector + +/** + * + * @version $Id$ + */ +class FunctionGradient { + private var theAnalytical = false + private var theG2ndDerivative: RealVector + private var theGStepSize: RealVector + private var theGradient: RealVector + private var theValid = false + + constructor(n: Int) { + theGradient = ArrayRealVector(n) + theG2ndDerivative = ArrayRealVector(n) + theGStepSize = ArrayRealVector(n) + } + + constructor(grd: RealVector) { + theGradient = grd + theG2ndDerivative = ArrayRealVector(grd.getDimension()) + theGStepSize = ArrayRealVector(grd.getDimension()) + theValid = true + theAnalytical = true + } + + constructor(grd: RealVector, g2: RealVector, gstep: RealVector) { + theGradient = grd + theG2ndDerivative = g2 + theGStepSize = gstep + theValid = true + theAnalytical = false + } + + fun getGradient(): RealVector { + return theGradient + } + + fun getGradientDerivative(): RealVector { + return theG2ndDerivative + } + + fun getStep(): RealVector { + return theGStepSize + } + + fun isAnalytical(): Boolean { + return theAnalytical + } + + fun isValid(): Boolean { + return theValid + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/FunctionMinimum.kt b/kmath-optimization/src/commonMain/tmp/minuit/FunctionMinimum.kt new file mode 100644 index 000000000..e43523291 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/FunctionMinimum.kt @@ -0,0 +1,260 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.minuit.* +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * Result of the minimization. + * + * + * The FunctionMinimum is the output of the minimizers and contains the + * minimization result. The methods + * + * * userState(), + * * userParameters() and + * * userCovariance() + * + * are provided. These can be used as new input to a new minimization after some + * manipulation. The parameters and/or the FunctionMinimum can be printed using + * the toString() method or the MnPrint class. + * + * @author Darksnake + */ +class FunctionMinimum { + private var theAboveMaxEdm = false + private var theErrorDef: Double + private var theReachedCallLimit = false + private var theSeed: MinimumSeed + private var theStates: MutableList + private var theUserState: MnUserParameterState + + internal constructor(seed: MinimumSeed, up: Double) { + theSeed = seed + theStates = ArrayList() + theStates.add(MinimumState(seed.parameters(), + seed.error(), + seed.gradient(), + seed.parameters().fval(), + seed.nfcn())) + theErrorDef = up + theUserState = MnUserParameterState() + } + + internal constructor(seed: MinimumSeed, states: MutableList, up: Double) { + theSeed = seed + theStates = states + theErrorDef = up + theUserState = MnUserParameterState() + } + + internal constructor(seed: MinimumSeed, states: MutableList, up: Double, x: MnReachedCallLimit?) { + theSeed = seed + theStates = states + theErrorDef = up + theReachedCallLimit = true + theUserState = MnUserParameterState() + } + + internal constructor(seed: MinimumSeed, states: MutableList, up: Double, x: MnAboveMaxEdm?) { + theSeed = seed + theStates = states + theErrorDef = up + theAboveMaxEdm = true + theReachedCallLimit = false + theUserState = MnUserParameterState() + } + + // why not + fun add(state: MinimumState) { + theStates.add(state) + } + + /** + * returns the expected vertical distance to the minimum (EDM) + * + * @return a double. + */ + fun edm(): Double { + return lastState().edm() + } + + fun error(): MinimumError { + return lastState().error() + } + + /** + * + * + * errorDef. + * + * @return a double. + */ + fun errorDef(): Double { + return theErrorDef + } + + /** + * Returns the function value at the minimum. + * + * @return a double. + */ + fun fval(): Double { + return lastState().fval() + } + + fun grad(): FunctionGradient { + return lastState().gradient() + } + + fun hasAccurateCovar(): Boolean { + return state().error().isAccurate() + } + + fun hasCovariance(): Boolean { + return state().error().isAvailable() + } + + fun hasMadePosDefCovar(): Boolean { + return state().error().isMadePosDef() + } + + fun hasPosDefCovar(): Boolean { + return state().error().isPosDef() + } + + fun hasReachedCallLimit(): Boolean { + return theReachedCallLimit + } + + fun hasValidCovariance(): Boolean { + return state().error().isValid() + } + + fun hasValidParameters(): Boolean { + return state().parameters().isValid() + } + + fun hesseFailed(): Boolean { + return state().error().hesseFailed() + } + + fun isAboveMaxEdm(): Boolean { + return theAboveMaxEdm + } + + /** + * In general, if this returns true, the minimizer did find a + * minimum without running into troubles. However, in some cases a minimum + * cannot be found, then the return value will be false. + * Reasons for the minimization to fail are + * + * * the number of allowed function calls has been exhausted + * * the minimizer could not improve the values of the parameters (and + * knowing that it has not converged yet) + * * a problem with the calculation of the covariance matrix + * + * Additional methods for the analysis of the state at the minimum are + * provided. + * + * @return a boolean. + */ + fun isValid(): Boolean { + return state().isValid() && !isAboveMaxEdm() && !hasReachedCallLimit() + } + + private fun lastState(): MinimumState { + return theStates[theStates.size - 1] + } + // forward interface of last state + /** + * returns the total number of function calls during the minimization. + * + * @return a int. + */ + fun nfcn(): Int { + return lastState().nfcn() + } + + fun parameters(): MinimumParameters { + return lastState().parameters() + } + + fun seed(): MinimumSeed { + return theSeed + } + + fun state(): MinimumState { + return lastState() + } + + fun states(): List { + return theStates + } + + /** + * {@inheritDoc} + * + * @return + */ + override fun toString(): String { + return MnPrint.toString(this) + } + + /** + * + * + * userCovariance. + * + * @return a [hep.dataforge.MINUIT.MnUserCovariance] object. + */ + fun userCovariance(): MnUserCovariance { + if (!theUserState.isValid()) { + theUserState = MnUserParameterState(state(), errorDef(), seed().trafo()) + } + return theUserState.covariance() + } + + /** + * + * + * userParameters. + * + * @return a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + fun userParameters(): MnUserParameters { + if (!theUserState.isValid()) { + theUserState = MnUserParameterState(state(), errorDef(), seed().trafo()) + } + return theUserState.parameters() + } + + /** + * user representation of state at minimum + * + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun userState(): MnUserParameterState { + if (!theUserState.isValid()) { + theUserState = MnUserParameterState(state(), errorDef(), seed().trafo()) + } + return theUserState + } + + internal class MnAboveMaxEdm + internal class MnReachedCallLimit +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/GradientCalculator.kt b/kmath-optimization/src/commonMain/tmp/minuit/GradientCalculator.kt new file mode 100644 index 000000000..379de1b6d --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/GradientCalculator.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +interface GradientCalculator { + /** + * + * gradient. + * + * @param par a [hep.dataforge.MINUIT.MinimumParameters] object. + * @return a [hep.dataforge.MINUIT.FunctionGradient] object. + */ + fun gradient(par: MinimumParameters?): FunctionGradient + + /** + * + * gradient. + * + * @param par a [hep.dataforge.MINUIT.MinimumParameters] object. + * @param grad a [hep.dataforge.MINUIT.FunctionGradient] object. + * @return a [hep.dataforge.MINUIT.FunctionGradient] object. + */ + fun gradient(par: MinimumParameters?, grad: FunctionGradient?): FunctionGradient +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/HessianGradientCalculator.kt b/kmath-optimization/src/commonMain/tmp/minuit/HessianGradientCalculator.kt new file mode 100644 index 000000000..150d192f9 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/HessianGradientCalculator.kt @@ -0,0 +1,137 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector +import ru.inr.mass.minuit.* + +/** + * + * @version $Id$ + */ +internal class HessianGradientCalculator(fcn: MnFcn, par: MnUserTransformation, stra: MnStrategy) : GradientCalculator { + private val theFcn: MnFcn = fcn + private val theStrategy: MnStrategy + private val theTransformation: MnUserTransformation + fun deltaGradient(par: MinimumParameters, gradient: FunctionGradient): Pair { + require(par.isValid()) { "parameters are invalid" } + val x: RealVector = par.vec().copy() + val grd: RealVector = gradient.getGradient().copy() + val g2: RealVector = gradient.getGradientDerivative() + val gstep: RealVector = gradient.getStep() + val fcnmin: Double = par.fval() + // std::cout<<"fval: "< optstp) { + d = optstp + } + if (d < dmin) { + d = dmin + } + var chgold = 10000.0 + var dgmin = 0.0 + var grdold = 0.0 + var grdnew = 0.0 + for (j in 0 until ncycle()) { + x.setEntry(i, xtf + d) + val fs1: Double = theFcn.value(x) + x.setEntry(i, xtf - d) + val fs2: Double = theFcn.value(x) + x.setEntry(i, xtf) + // double sag = 0.5*(fs1+fs2-2.*fcnmin); + grdold = grd.getEntry(i) + grdnew = (fs1 - fs2) / (2.0 * d) + dgmin = precision().eps() * (abs(fs1) + abs(fs2)) / d + if (abs(grdnew) < precision().eps()) { + break + } + val change: Double = abs((grdold - grdnew) / grdnew) + if (change > chgold && j > 1) { + break + } + chgold = change + grd.setEntry(i, grdnew) + if (change < 0.05) { + break + } + if (abs(grdold - grdnew) < dgmin) { + break + } + if (d < dmin) { + break + } + d *= 0.2 + } + dgrd.setEntry(i, max(dgmin, abs(grdold - grdnew))) + } + return Pair(FunctionGradient(grd, g2, gstep), dgrd) + } + + fun fcn(): MnFcn { + return theFcn + } + + fun gradTolerance(): Double { + return strategy().gradientTolerance() + } + + /** {@inheritDoc} */ + fun gradient(par: MinimumParameters): FunctionGradient { + val gc = InitialGradientCalculator(theFcn, theTransformation, theStrategy) + val gra: FunctionGradient = gc.gradient(par) + return gradient(par, gra) + } + + /** {@inheritDoc} */ + fun gradient(par: MinimumParameters, gradient: FunctionGradient): FunctionGradient { + return deltaGradient(par, gradient).getFirst() + } + + fun ncycle(): Int { + return strategy().hessianGradientNCycles() + } + + fun precision(): MnMachinePrecision { + return theTransformation.precision() + } + + fun stepTolerance(): Double { + return strategy().gradientStepTolerance() + } + + fun strategy(): MnStrategy { + return theStrategy + } + + fun trafo(): MnUserTransformation { + return theTransformation + } + + init { + theTransformation = par + theStrategy = stra + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/InitialGradientCalculator.kt b/kmath-optimization/src/commonMain/tmp/minuit/InitialGradientCalculator.kt new file mode 100644 index 000000000..794556414 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/InitialGradientCalculator.kt @@ -0,0 +1,116 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector +import ru.inr.mass.minuit.* + +/** + * Calculating derivatives via finite differences + * @version $Id$ + */ +internal class InitialGradientCalculator(fcn: MnFcn, par: MnUserTransformation, stra: MnStrategy) { + private val theFcn: MnFcn = fcn + private val theStrategy: MnStrategy + private val theTransformation: MnUserTransformation + fun fcn(): MnFcn { + return theFcn + } + + fun gradTolerance(): Double { + return strategy().gradientTolerance() + } + + fun gradient(par: MinimumParameters): FunctionGradient { + require(par.isValid()) { "Parameters are invalid" } + val n: Int = trafo().variableParameters() + require(n == par.vec().getDimension()) { "Parameters have invalid size" } + val gr: RealVector = ArrayRealVector(n) + val gr2: RealVector = ArrayRealVector(n) + val gst: RealVector = ArrayRealVector(n) + + // initial starting values + for (i in 0 until n) { + val exOfIn: Int = trafo().extOfInt(i) + val `var`: Double = par.vec().getEntry(i) //parameter value + val werr: Double = trafo().parameter(exOfIn).error() //parameter error + val sav: Double = trafo().int2ext(i, `var`) //value after transformation + var sav2 = sav + werr //value after transfomation + error + if (trafo().parameter(exOfIn).hasLimits()) { + if (trafo().parameter(exOfIn).hasUpperLimit() + && sav2 > trafo().parameter(exOfIn).upperLimit() + ) { + sav2 = trafo().parameter(exOfIn).upperLimit() + } + } + var var2: Double = trafo().ext2int(exOfIn, sav2) + val vplu = var2 - `var` + sav2 = sav - werr + if (trafo().parameter(exOfIn).hasLimits()) { + if (trafo().parameter(exOfIn).hasLowerLimit() + && sav2 < trafo().parameter(exOfIn).lowerLimit() + ) { + sav2 = trafo().parameter(exOfIn).lowerLimit() + } + } + var2 = trafo().ext2int(exOfIn, sav2) + val vmin = var2 - `var` + val dirin: Double = 0.5 * (abs(vplu) + abs(vmin)) + val g2: Double = 2.0 * theFcn.errorDef() / (dirin * dirin) + val gsmin: Double = 8.0 * precision().eps2() * (abs(`var`) + precision().eps2()) + var gstep: Double = max(gsmin, 0.1 * dirin) + val grd = g2 * dirin + if (trafo().parameter(exOfIn).hasLimits()) { + if (gstep > 0.5) { + gstep = 0.5 + } + } + gr.setEntry(i, grd) + gr2.setEntry(i, g2) + gst.setEntry(i, gstep) + } + return FunctionGradient(gr, gr2, gst) + } + + fun gradient(par: MinimumParameters, gra: FunctionGradient?): FunctionGradient { + return gradient(par) + } + + fun ncycle(): Int { + return strategy().gradientNCycles() + } + + fun precision(): MnMachinePrecision { + return theTransformation.precision() + } + + fun stepTolerance(): Double { + return strategy().gradientStepTolerance() + } + + fun strategy(): MnStrategy { + return theStrategy + } + + fun trafo(): MnUserTransformation { + return theTransformation + } + + init { + theTransformation = par + theStrategy = stra + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MINOSResult.kt b/kmath-optimization/src/commonMain/tmp/minuit/MINOSResult.kt new file mode 100644 index 000000000..c33994648 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MINOSResult.kt @@ -0,0 +1,70 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package space.kscience.kmath.optimization.minuit + + +/** + * Контейнер для несимметричных оценок и доверительных интервалов + * + * @author Darksnake + * @version $Id: $Id + */ +class MINOSResult +/** + * + * Constructor for MINOSResult. + * + * @param list an array of [String] objects. + */(private val names: Array, private val errl: DoubleArray?, private val errp: DoubleArray?) : + IntervalEstimate { + fun getNames(): NameList { + return NameList(names) + } + + fun getInterval(parName: String?): Pair { + val index: Int = getNames().getNumberByName(parName) + return Pair(ValueFactory.of(errl!![index]), ValueFactory.of(errp!![index])) + } + + val cL: Double + get() = 0.68 + + /** {@inheritDoc} */ + fun print(out: PrintWriter) { + if (errl != null || errp != null) { + out.println() + out.println("Assymetrical errors:") + out.println() + out.println("Name\tLower\tUpper") + for (i in 0 until getNames().size()) { + out.print(getNames().get(i)) + out.print("\t") + if (errl != null) { + out.print(errl[i]) + } else { + out.print("---") + } + out.print("\t") + if (errp != null) { + out.print(errp[i]) + } else { + out.print("---") + } + out.println() + } + } + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MINUITFitter.kt b/kmath-optimization/src/commonMain/tmp/minuit/MINUITFitter.kt new file mode 100644 index 000000000..a26321249 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MINUITFitter.kt @@ -0,0 +1,205 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ +package space.kscience.kmath.optimization.minuit + +import ru.inr.mass.minuit.* + +/** + * + * + * MINUITFitter class. + * + * @author Darksnake + * @version $Id: $Id + */ +class MINUITFitter : Fitter { + fun run(state: FitState, parentLog: History?, meta: Meta): FitResult { + val log = Chronicle("MINUIT", parentLog) + val action: String = meta.getString("action", TASK_RUN) + log.report("MINUIT fit engine started action '{}'", action) + return when (action) { + TASK_COVARIANCE -> runHesse(state, log, meta) + TASK_SINGLE, TASK_RUN -> runFit(state, log, meta) + else -> throw IllegalArgumentException("Unknown task") + } + } + + @NotNull + fun getName(): String { + return MINUIT_ENGINE_NAME + } + + /** + * + * + * runHesse. + * + * @param state a [hep.dataforge.stat.fit.FitState] object. + * @param log + * @return a [FitResult] object. + */ + fun runHesse(state: FitState, log: History, meta: Meta?): FitResult { + val strategy: Int + strategy = Global.INSTANCE.getInt("MINUIT_STRATEGY", 2) + log.report("Generating errors using MnHesse 2-nd order gradient calculator.") + val fcn: MultiFunction + val fitPars: Array = Fitter.Companion.getFitPars(state, meta) + val pars: ParamSet = state.getParameters() + fcn = MINUITUtils.getFcn(state, pars, fitPars) + val hesse = MnHesse(strategy) + val mnState: MnUserParameterState = hesse.calculate(fcn, MINUITUtils.getFitParameters(pars, fitPars)) + val allPars: ParamSet = pars.copy() + for (fitPar in fitPars) { + allPars.setParValue(fitPar, mnState.value(fitPar)) + allPars.setParError(fitPar, mnState.error(fitPar)) + } + val newState: FitState.Builder = state.edit() + newState.setPars(allPars) + if (mnState.hasCovariance()) { + val mnCov: MnUserCovariance = mnState.covariance() + var j: Int + val cov = Array(mnState.variableParameters()) { DoubleArray(mnState.variableParameters()) } + for (i in 0 until mnState.variableParameters()) { + j = 0 + while (j < mnState.variableParameters()) { + cov[i][j] = mnCov.get(i, j) + j++ + } + } + newState.setCovariance(NamedMatrix(fitPars, cov), true) + } + return FitResult.build(newState.build(), fitPars) + } + + fun runFit(state: FitState, log: History, meta: Meta): FitResult { + val minuit: MnApplication + log.report("Starting fit using Minuit.") + val strategy: Int + strategy = Global.INSTANCE.getInt("MINUIT_STRATEGY", 2) + var force: Boolean + force = Global.INSTANCE.getBoolean("FORCE_DERIVS", false) + val fitPars: Array = Fitter.Companion.getFitPars(state, meta) + for (fitPar in fitPars) { + if (!state.modelProvidesDerivs(fitPar)) { + force = true + log.reportError("Model does not provide derivatives for parameter '{}'", fitPar) + } + } + if (force) { + log.report("Using MINUIT gradient calculator.") + } + val fcn: MultiFunction + val pars: ParamSet = state.getParameters().copy() + fcn = MINUITUtils.getFcn(state, pars, fitPars) + val method: String = meta.getString("method", MINUIT_MIGRAD) + when (method) { + MINUIT_MINOS, MINUIT_MINIMIZE -> minuit = + MnMinimize(fcn, MINUITUtils.getFitParameters(pars, fitPars), strategy) + MINUIT_SIMPLEX -> minuit = MnSimplex(fcn, MINUITUtils.getFitParameters(pars, fitPars), strategy) + else -> minuit = MnMigrad(fcn, MINUITUtils.getFitParameters(pars, fitPars), strategy) + } + if (force) { + minuit.setUseAnalyticalDerivatives(false) + log.report("Forced to use MINUIT internal derivative calculator!") + } + +// minuit.setUseAnalyticalDerivatives(true); + val minimum: FunctionMinimum + val maxSteps: Int = meta.getInt("iterations", -1) + val tolerance: Double = meta.getDouble("tolerance", -1) + minimum = if (maxSteps > 0) { + if (tolerance > 0) { + minuit.minimize(maxSteps, tolerance) + } else { + minuit.minimize(maxSteps) + } + } else { + minuit.minimize() + } + if (!minimum.isValid()) { + log.report("Minimization failed!") + } + log.report("MINUIT run completed in {} function calls.", minimum.nfcn()) + + /* + * Генерация результата + */ + val allPars: ParamSet = pars.copy() + for (fitPar in fitPars) { + allPars.setParValue(fitPar, minimum.userParameters().value(fitPar)) + allPars.setParError(fitPar, minimum.userParameters().error(fitPar)) + } + val newState: FitState.Builder = state.edit() + newState.setPars(allPars) + var valid: Boolean = minimum.isValid() + if (minimum.userCovariance().nrow() > 0) { + var j: Int + val cov = Array(minuit.variableParameters()) { DoubleArray(minuit.variableParameters()) } + if (cov[0].length == 1) { + cov[0][0] = minimum.userParameters().error(0) * minimum.userParameters().error(0) + } else { + for (i in 0 until minuit.variableParameters()) { + j = 0 + while (j < minuit.variableParameters()) { + cov[i][j] = minimum.userCovariance().get(i, j) + j++ + } + } + } + newState.setCovariance(NamedMatrix(fitPars, cov), true) + } + if (method == MINUIT_MINOS) { + log.report("Starting MINOS procedure for precise error estimation.") + val minos = MnMinos(fcn, minimum, strategy) + var mnError: MinosError + val errl = DoubleArray(fitPars.size) + val errp = DoubleArray(fitPars.size) + for (i in fitPars.indices) { + mnError = minos.minos(i) + if (mnError.isValid()) { + errl[i] = mnError.lower() + errp[i] = mnError.upper() + } else { + valid = false + } + } + val minosErrors = MINOSResult(fitPars, errl, errp) + newState.setInterval(minosErrors) + } + return FitResult.build(newState.build(), valid, fitPars) + } + + companion object { + /** + * Constant `MINUIT_MIGRAD="MIGRAD"` + */ + const val MINUIT_MIGRAD = "MIGRAD" + + /** + * Constant `MINUIT_MINIMIZE="MINIMIZE"` + */ + const val MINUIT_MINIMIZE = "MINIMIZE" + + /** + * Constant `MINUIT_SIMPLEX="SIMPLEX"` + */ + const val MINUIT_SIMPLEX = "SIMPLEX" + + /** + * Constant `MINUIT_MINOS="MINOS"` + */ + const val MINUIT_MINOS = "MINOS" //MINOS errors + + /** + * Constant `MINUIT_HESSE="HESSE"` + */ + const val MINUIT_HESSE = "HESSE" //HESSE errors + + /** + * Constant `MINUIT_ENGINE_NAME="MINUIT"` + */ + const val MINUIT_ENGINE_NAME = "MINUIT" + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MINUITPlugin.kt b/kmath-optimization/src/commonMain/tmp/minuit/MINUITPlugin.kt new file mode 100644 index 000000000..7eaefd9d2 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MINUITPlugin.kt @@ -0,0 +1,86 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ +package space.kscience.kmath.optimization.minuit + +import hep.dataforge.context.* + +/** + * Мэнеджер для MINUITа. Пока не играет никакой активной роли кроме ведения + * внутреннего лога. + * + * @author Darksnake + * @version $Id: $Id + */ +@PluginDef(group = "hep.dataforge", + name = "MINUIT", + dependsOn = ["hep.dataforge:fitting"], + info = "The MINUIT fitter engine for DataForge fitting") +class MINUITPlugin : BasicPlugin() { + fun attach(@NotNull context: Context?) { + super.attach(context) + clearStaticLog() + } + + @Provides(Fitter.FITTER_TARGET) + fun getFitter(fitterName: String): Fitter? { + return if (fitterName == "MINUIT") { + MINUITFitter() + } else { + null + } + } + + @ProvidesNames(Fitter.FITTER_TARGET) + fun listFitters(): List { + return listOf("MINUIT") + } + + fun detach() { + clearStaticLog() + super.detach() + } + + class Factory : PluginFactory() { + fun build(meta: Meta?): Plugin { + return MINUITPlugin() + } + + fun getType(): java.lang.Class { + return MINUITPlugin::class.java + } + } + + companion object { + /** + * Constant `staticLog` + */ + private val staticLog: Chronicle? = Chronicle("MINUIT-STATIC", Global.INSTANCE.getHistory()) + + /** + * + * + * clearStaticLog. + */ + fun clearStaticLog() { + staticLog.clear() + } + + /** + * + * + * logStatic. + * + * @param str a [String] object. + * @param pars a [Object] object. + */ + fun logStatic(str: String?, vararg pars: Any?) { + checkNotNull(staticLog) { "MINUIT log is not initialized." } + staticLog.report(str, pars) + LoggerFactory.getLogger("MINUIT").info(String.format(str, *pars)) + // Out.out.printf(str,pars); +// Out.out.println(); + } + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MINUITUtils.kt b/kmath-optimization/src/commonMain/tmp/minuit/MINUITUtils.kt new file mode 100644 index 000000000..44c70cb42 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MINUITUtils.kt @@ -0,0 +1,121 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ +package space.kscience.kmath.optimization.minuit + +import hep.dataforge.MINUIT.FunctionMinimum + +internal object MINUITUtils { + fun getFcn(source: FitState, allPar: ParamSet, fitPars: Array): MultiFunction { + return MnFunc(source, allPar, fitPars) + } + + fun getFitParameters(set: ParamSet, fitPars: Array): MnUserParameters { + val pars = MnUserParameters() + var i: Int + var par: Param + i = 0 + while (i < fitPars.size) { + par = set.getByName(fitPars[i]) + pars.add(fitPars[i], par.getValue(), par.getErr()) + if (par.getLowerBound() > Double.NEGATIVE_INFINITY && par.getUpperBound() < Double.POSITIVE_INFINITY) { + pars.setLimits(i, par.getLowerBound(), par.getUpperBound()) + } else if (par.getLowerBound() > Double.NEGATIVE_INFINITY) { + pars.setLowerLimit(i, par.getLowerBound()) + } else if (par.getUpperBound() < Double.POSITIVE_INFINITY) { + pars.setUpperLimit(i, par.getUpperBound()) + } + i++ + } + return pars + } + + fun getValueSet(allPar: ParamSet, names: Array, values: DoubleArray): ParamSet { + assert(values.size == names.size) + assert(allPar.getNames().contains(names)) + val vector: ParamSet = allPar.copy() + for (i in values.indices) { + vector.setParValue(names[i], values[i]) + } + return vector + } + + fun isValidArray(ar: DoubleArray): Boolean { + for (i in ar.indices) { + if (java.lang.Double.isNaN(ar[i])) { + return false + } + } + return true + } + + /** + * + * + * printMINUITResult. + * + * @param out a [PrintWriter] object. + * @param minimum a [hep.dataforge.MINUIT.FunctionMinimum] object. + */ + fun printMINUITResult(out: PrintWriter, minimum: FunctionMinimum?) { + out.println() + out.println("***MINUIT INTERNAL FIT INFORMATION***") + out.println() + MnPrint.print(out, minimum) + out.println() + out.println("***END OF MINUIT INTERNAL FIT INFORMATION***") + out.println() + } + + internal class MnFunc(source: FitState, allPar: ParamSet, fitPars: Array) : MultiFunction { + var source: FitState + var allPar: ParamSet + var fitPars: Array + fun value(doubles: DoubleArray): Double { + assert(isValidArray(doubles)) + assert(doubles.size == fitPars.size) + return -2 * source.getLogProb(getValueSet(allPar, fitPars, doubles)) + // source.getChi2(getValueSet(allPar, fitPars, doubles)); + } + + @Throws(NotDefinedException::class) + fun derivValue(n: Int, doubles: DoubleArray): Double { + assert(isValidArray(doubles)) + assert(doubles.size == getDimension()) + val set: ParamSet = getValueSet(allPar, fitPars, doubles) + +// double res; +// double d, s, deriv; +// +// res = 0; +// for (int i = 0; i < source.getDataNum(); i++) { +// d = source.getDis(i, set); +// s = source.getDispersion(i, set); +// if (source.modelProvidesDerivs(fitPars[n])) { +// deriv = source.getDisDeriv(fitPars[n], i, set); +// } else { +// throw new NotDefinedException(); +// // Такого не должно быть, поскольку мы где-то наверху должы были проверить, что производные все есть. +// } +// res += 2 * d * deriv / s; +// } + return -2 * source.getLogProbDeriv(fitPars[n], set) + } + + fun getDimension(): Int { + return fitPars.size + } + + fun providesDeriv(n: Int): Boolean { + return source.modelProvidesDerivs(fitPars[n]) + } + + init { + this.source = source + this.allPar = allPar + this.fitPars = fitPars + assert(source.getModel().getNames().contains(fitPars)) + } + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinimumBuilder.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinimumBuilder.kt new file mode 100644 index 000000000..7d918c339 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinimumBuilder.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * + * @version $Id$ + */ +interface MinimumBuilder { + /** + * + * minimum. + * + * @param fcn a [hep.dataforge.MINUIT.MnFcn] object. + * @param gc a [hep.dataforge.MINUIT.GradientCalculator] object. + * @param seed a [hep.dataforge.MINUIT.MinimumSeed] object. + * @param strategy a [hep.dataforge.MINUIT.MnStrategy] object. + * @param maxfcn a int. + * @param toler a double. + * @return a [hep.dataforge.MINUIT.FunctionMinimum] object. + */ + fun minimum( + fcn: MnFcn?, + gc: GradientCalculator?, + seed: MinimumSeed?, + strategy: MnStrategy?, + maxfcn: Int, + toler: Double + ): FunctionMinimum +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinimumError.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinimumError.kt new file mode 100644 index 000000000..6993b9e6d --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinimumError.kt @@ -0,0 +1,155 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MINUITPlugin + +/** + * MinimumError keeps the inverse 2nd derivative (inverse Hessian) used for + * calculating the parameter step size (-V*g) and for the covariance update + * (ErrorUpdator). The covariance matrix is equal to twice the inverse Hessian. + * + * @version $Id$ + */ +class MinimumError { + private var theAvailable = false + private var theDCovar: Double + private var theHesseFailed = false + private var theInvertFailed = false + private var theMadePosDef = false + private var theMatrix: MnAlgebraicSymMatrix + private var thePosDef = false + private var theValid = false + + constructor(n: Int) { + theMatrix = MnAlgebraicSymMatrix(n) + theDCovar = 1.0 + } + + constructor(mat: MnAlgebraicSymMatrix, dcov: Double) { + theMatrix = mat + theDCovar = dcov + theValid = true + thePosDef = true + theAvailable = true + } + + constructor(mat: MnAlgebraicSymMatrix, x: MnHesseFailed?) { + theMatrix = mat + theDCovar = 1.0 + theValid = false + thePosDef = false + theMadePosDef = false + theHesseFailed = true + theInvertFailed = false + theAvailable = true + } + + constructor(mat: MnAlgebraicSymMatrix, x: MnMadePosDef?) { + theMatrix = mat + theDCovar = 1.0 + theValid = false + thePosDef = false + theMadePosDef = true + theHesseFailed = false + theInvertFailed = false + theAvailable = true + } + + constructor(mat: MnAlgebraicSymMatrix, x: MnInvertFailed?) { + theMatrix = mat + theDCovar = 1.0 + theValid = false + thePosDef = true + theMadePosDef = false + theHesseFailed = false + theInvertFailed = true + theAvailable = true + } + + constructor(mat: MnAlgebraicSymMatrix, x: MnNotPosDef?) { + theMatrix = mat + theDCovar = 1.0 + theValid = false + thePosDef = false + theMadePosDef = false + theHesseFailed = false + theInvertFailed = false + theAvailable = true + } + + fun dcovar(): Double { + return theDCovar + } + + fun hesseFailed(): Boolean { + return theHesseFailed + } + + fun hessian(): MnAlgebraicSymMatrix { + return try { + val tmp: MnAlgebraicSymMatrix = theMatrix.copy() + tmp.invert() + tmp + } catch (x: SingularMatrixException) { + MINUITPlugin.logStatic("BasicMinimumError inversion fails; return diagonal matrix.") + val tmp = MnAlgebraicSymMatrix(theMatrix.nrow()) + var i = 0 + while (i < theMatrix.nrow()) { + tmp[i, i] = 1.0 / theMatrix[i, i] + i++ + } + tmp + } + } + + fun invHessian(): MnAlgebraicSymMatrix { + return theMatrix + } + + fun invertFailed(): Boolean { + return theInvertFailed + } + + fun isAccurate(): Boolean { + return theDCovar < 0.1 + } + + fun isAvailable(): Boolean { + return theAvailable + } + + fun isMadePosDef(): Boolean { + return theMadePosDef + } + + fun isPosDef(): Boolean { + return thePosDef + } + + fun isValid(): Boolean { + return theValid + } + + fun matrix(): MnAlgebraicSymMatrix { + return MnUtils.mul(theMatrix, 2) + } + + internal class MnHesseFailed + internal class MnInvertFailed + internal class MnMadePosDef + internal class MnNotPosDef +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinimumErrorUpdator.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinimumErrorUpdator.kt new file mode 100644 index 000000000..6022aa5b7 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinimumErrorUpdator.kt @@ -0,0 +1,33 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal interface MinimumErrorUpdator { + /** + * + * update. + * + * @param state a [hep.dataforge.MINUIT.MinimumState] object. + * @param par a [hep.dataforge.MINUIT.MinimumParameters] object. + * @param grad a [hep.dataforge.MINUIT.FunctionGradient] object. + * @return a [hep.dataforge.MINUIT.MinimumError] object. + */ + fun update(state: MinimumState?, par: MinimumParameters?, grad: FunctionGradient?): MinimumError? +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinimumParameters.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinimumParameters.kt new file mode 100644 index 000000000..bed13ea0b --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinimumParameters.kt @@ -0,0 +1,70 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector + +/** + * + * @version $Id$ + */ +class MinimumParameters { + private var theFVal = 0.0 + private var theHasStep = false + private var theParameters: RealVector + private var theStepSize: RealVector + private var theValid = false + + constructor(n: Int) { + theParameters = ArrayRealVector(n) + theStepSize = ArrayRealVector(n) + } + + constructor(avec: RealVector, fval: Double) { + theParameters = avec + theStepSize = ArrayRealVector(avec.getDimension()) + theFVal = fval + theValid = true + } + + constructor(avec: RealVector, dirin: RealVector, fval: Double) { + theParameters = avec + theStepSize = dirin + theFVal = fval + theValid = true + theHasStep = true + } + + fun dirin(): RealVector { + return theStepSize + } + + fun fval(): Double { + return theFVal + } + + fun hasStepSize(): Boolean { + return theHasStep + } + + fun isValid(): Boolean { + return theValid + } + + fun vec(): RealVector { + return theParameters + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinimumSeed.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinimumSeed.kt new file mode 100644 index 000000000..53a78da75 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinimumSeed.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package space.kscience.kmath.optimization.minuit + +import ru.inr.mass.minuit.* + +/** + * + * @version $Id$ + */ +class MinimumSeed(state: MinimumState, trafo: MnUserTransformation) { + private val theState: MinimumState = state + private val theTrafo: MnUserTransformation = trafo + private val theValid: Boolean = true + val edm: Double get() = state().edm() + + fun error(): MinimumError { + return state().error() + } + + fun fval(): Double { + return state().fval() + } + + fun gradient(): FunctionGradient { + return state().gradient() + } + + fun isValid(): Boolean { + return theValid + } + + fun nfcn(): Int { + return state().nfcn() + } + + fun parameters(): MinimumParameters { + return state().parameters() + } + + fun precision(): MnMachinePrecision { + return theTrafo.precision() + } + + fun state(): MinimumState { + return theState + } + + fun trafo(): MnUserTransformation { + return theTrafo + } + +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinimumSeedGenerator.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinimumSeedGenerator.kt new file mode 100644 index 000000000..e152559b5 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinimumSeedGenerator.kt @@ -0,0 +1,39 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * base class for seed generators (starting values); the seed generator prepares + * initial starting values from the input (MnUserParameterState) for the + * minimization; + * + * @version $Id$ + */ +interface MinimumSeedGenerator { + /** + * + * generate. + * + * @param fcn a [hep.dataforge.MINUIT.MnFcn] object. + * @param calc a [hep.dataforge.MINUIT.GradientCalculator] object. + * @param user a [hep.dataforge.MINUIT.MnUserParameterState] object. + * @param stra a [hep.dataforge.MINUIT.MnStrategy] object. + * @return a [hep.dataforge.MINUIT.MinimumSeed] object. + */ + fun generate(fcn: MnFcn?, calc: GradientCalculator?, user: MnUserParameterState?, stra: MnStrategy?): MinimumSeed +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinimumState.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinimumState.kt new file mode 100644 index 000000000..9f63e0e1f --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinimumState.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.RealVector + +/** + * MinimumState keeps the information (position, gradient, 2nd deriv, etc) after + * one minimization step (usually in MinimumBuilder). + * + * @version $Id$ + */ +class MinimumState { + private var theEDM = 0.0 + private var theError: MinimumError + private var theGradient: FunctionGradient + private var theNFcn = 0 + private var theParameters: MinimumParameters + + constructor(n: Int) { + theParameters = MinimumParameters(n) + theError = MinimumError(n) + theGradient = FunctionGradient(n) + } + + constructor(states: MinimumParameters, err: MinimumError, grad: FunctionGradient, edm: Double, nfcn: Int) { + theParameters = states + theError = err + theGradient = grad + theEDM = edm + theNFcn = nfcn + } + + constructor(states: MinimumParameters, edm: Double, nfcn: Int) { + theParameters = states + theError = MinimumError(states.vec().getDimension()) + theGradient = FunctionGradient(states.vec().getDimension()) + theEDM = edm + theNFcn = nfcn + } + + fun edm(): Double { + return theEDM + } + + fun error(): MinimumError { + return theError + } + + fun fval(): Double { + return theParameters.fval() + } + + fun gradient(): FunctionGradient { + return theGradient + } + + fun hasCovariance(): Boolean { + return theError.isAvailable() + } + + fun hasParameters(): Boolean { + return theParameters.isValid() + } + + fun isValid(): Boolean { + return if (hasParameters() && hasCovariance()) { + parameters().isValid() && error().isValid() + } else if (hasParameters()) { + parameters().isValid() + } else { + false + } + } + + fun nfcn(): Int { + return theNFcn + } + + fun parameters(): MinimumParameters { + return theParameters + } + + fun size(): Int { + return theParameters.vec().getDimension() + } + + fun vec(): RealVector { + return theParameters.vec() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinosError.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinosError.kt new file mode 100644 index 000000000..c7cf10523 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinosError.kt @@ -0,0 +1,219 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * MinosError class. + * + * @author Darksnake + * @version $Id$ + */ +class MinosError { + private var theLower: MnCross + private var theMinValue = 0.0 + private var theParameter = 0 + private var theUpper: MnCross + + internal constructor() { + theUpper = MnCross() + theLower = MnCross() + } + + internal constructor(par: Int, min: Double, low: MnCross, up: MnCross) { + theParameter = par + theMinValue = min + theUpper = up + theLower = low + } + + /** + * + * atLowerLimit. + * + * @return a boolean. + */ + fun atLowerLimit(): Boolean { + return theLower.atLimit() + } + + /** + * + * atLowerMaxFcn. + * + * @return a boolean. + */ + fun atLowerMaxFcn(): Boolean { + return theLower.atMaxFcn() + } + + /** + * + * atUpperLimit. + * + * @return a boolean. + */ + fun atUpperLimit(): Boolean { + return theUpper.atLimit() + } + + /** + * + * atUpperMaxFcn. + * + * @return a boolean. + */ + fun atUpperMaxFcn(): Boolean { + return theUpper.atMaxFcn() + } + + /** + * + * isValid. + * + * @return a boolean. + */ + fun isValid(): Boolean { + return theLower.isValid() && theUpper.isValid() + } + + /** + * + * lower. + * + * @return a double. + */ + fun lower(): Double { + return -1.0 * lowerState().error(parameter()) * (1.0 + theLower.value()) + } + + /** + * + * lowerNewMin. + * + * @return a boolean. + */ + fun lowerNewMin(): Boolean { + return theLower.newMinimum() + } + + /** + * + * lowerState. + * + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun lowerState(): MnUserParameterState { + return theLower.state() + } + + /** + * + * lowerValid. + * + * @return a boolean. + */ + fun lowerValid(): Boolean { + return theLower.isValid() + } + + /** + * + * min. + * + * @return a double. + */ + fun min(): Double { + return theMinValue + } + + /** + * + * nfcn. + * + * @return a int. + */ + fun nfcn(): Int { + return theUpper.nfcn() + theLower.nfcn() + } + + /** + * + * parameter. + * + * @return a int. + */ + fun parameter(): Int { + return theParameter + } + + /** + * + * range. + * + * @return + */ + fun range(): Range { + return Range(lower(), upper()) + } + + /** + * {@inheritDoc} + */ + override fun toString(): String { + return MnPrint.toString(this) + } + + /** + * + * upper. + * + * @return a double. + */ + fun upper(): Double { + return upperState().error(parameter()) * (1.0 + theUpper.value()) + } + + /** + * + * upperNewMin. + * + * @return a boolean. + */ + fun upperNewMin(): Boolean { + return theUpper.newMinimum() + } + + /** + * + * upperState. + * + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun upperState(): MnUserParameterState { + return theUpper.state() + } + + /** + * + * upperValid. + * + * @return a boolean. + */ + fun upperValid(): Boolean { + return theUpper.isValid() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MinuitParameter.kt b/kmath-optimization/src/commonMain/tmp/minuit/MinuitParameter.kt new file mode 100644 index 000000000..ff6834df4 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MinuitParameter.kt @@ -0,0 +1,314 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +class MinuitParameter { + private var theConst = false + private var theError = 0.0 + private var theFix = false + private var theLoLimValid = false + private var theLoLimit = 0.0 + private var theName: String + private var theNum: Int + private var theUpLimValid = false + private var theUpLimit = 0.0 + private var theValue: Double + + /** + * constructor for constant parameter + * + * @param num a int. + * @param name a [String] object. + * @param val a double. + */ + constructor(num: Int, name: String, `val`: Double) { + theNum = num + theValue = `val` + theConst = true + theName = name + } + + /** + * constructor for standard parameter + * + * @param num a int. + * @param name a [String] object. + * @param val a double. + * @param err a double. + */ + constructor(num: Int, name: String, `val`: Double, err: Double) { + theNum = num + theValue = `val` + theError = err + theName = name + } + + /** + * constructor for limited parameter + * + * @param num a int. + * @param name a [String] object. + * @param val a double. + * @param err a double. + * @param min a double. + * @param max a double. + */ + constructor(num: Int, name: String, `val`: Double, err: Double, min: Double, max: Double) { + theNum = num + theValue = `val` + theError = err + theLoLimit = min + theUpLimit = max + theLoLimValid = true + theUpLimValid = true + require(min != max) { "min == max" } + if (min > max) { + theLoLimit = max + theUpLimit = min + } + theName = name + } + + private constructor(other: MinuitParameter) { + theNum = other.theNum + theName = other.theName + theValue = other.theValue + theError = other.theError + theConst = other.theConst + theFix = other.theFix + theLoLimit = other.theLoLimit + theUpLimit = other.theUpLimit + theLoLimValid = other.theLoLimValid + theUpLimValid = other.theUpLimValid + } + + /** + * + * copy. + * + * @return a [hep.dataforge.MINUIT.MinuitParameter] object. + */ + fun copy(): MinuitParameter { + return MinuitParameter(this) + } + + /** + * + * error. + * + * @return a double. + */ + fun error(): Double { + return theError + } + + /** + * + * fix. + */ + fun fix() { + theFix = true + } + + /** + * + * hasLimits. + * + * @return a boolean. + */ + fun hasLimits(): Boolean { + return theLoLimValid || theUpLimValid + } + + /** + * + * hasLowerLimit. + * + * @return a boolean. + */ + fun hasLowerLimit(): Boolean { + return theLoLimValid + } + + /** + * + * hasUpperLimit. + * + * @return a boolean. + */ + fun hasUpperLimit(): Boolean { + return theUpLimValid + } + //state of parameter (fixed/const/limited) + /** + * + * isConst. + * + * @return a boolean. + */ + fun isConst(): Boolean { + return theConst + } + + /** + * + * isFixed. + * + * @return a boolean. + */ + fun isFixed(): Boolean { + return theFix + } + + /** + * + * lowerLimit. + * + * @return a double. + */ + fun lowerLimit(): Double { + return theLoLimit + } + + /** + * + * name. + * + * @return a [String] object. + */ + fun name(): String { + return theName + } + //access methods + /** + * + * number. + * + * @return a int. + */ + fun number(): Int { + return theNum + } + + /** + * + * release. + */ + fun release() { + theFix = false + } + + /** + * + * removeLimits. + */ + fun removeLimits() { + theLoLimit = 0.0 + theUpLimit = 0.0 + theLoLimValid = false + theUpLimValid = false + } + + /** + * + * setError. + * + * @param err a double. + */ + fun setError(err: Double) { + theError = err + theConst = false + } + + /** + * + * setLimits. + * + * @param low a double. + * @param up a double. + */ + fun setLimits(low: Double, up: Double) { + require(low != up) { "min == max" } + theLoLimit = low + theUpLimit = up + theLoLimValid = true + theUpLimValid = true + if (low > up) { + theLoLimit = up + theUpLimit = low + } + } + + /** + * + * setLowerLimit. + * + * @param low a double. + */ + fun setLowerLimit(low: Double) { + theLoLimit = low + theUpLimit = 0.0 + theLoLimValid = true + theUpLimValid = false + } + + /** + * + * setUpperLimit. + * + * @param up a double. + */ + fun setUpperLimit(up: Double) { + theLoLimit = 0.0 + theUpLimit = up + theLoLimValid = false + theUpLimValid = true + } + //interaction + /** + * + * setValue. + * + * @param val a double. + */ + fun setValue(`val`: Double) { + theValue = `val` + } + + /** + * + * upperLimit. + * + * @return a double. + */ + fun upperLimit(): Double { + return theUpLimit + } + + /** + * + * value. + * + * @return a double. + */ + fun value(): Double { + return theValue + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnAlgebraicSymMatrix.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnAlgebraicSymMatrix.kt new file mode 100644 index 000000000..4b75858e1 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnAlgebraicSymMatrix.kt @@ -0,0 +1,458 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector + +/** + * + * @version $Id$ + */ +class MnAlgebraicSymMatrix(n: Int) { + private val theData: DoubleArray + private val theNRow: Int + private val theSize: Int + + /** + * + * copy. + * + * @return a [hep.dataforge.MINUIT.MnAlgebraicSymMatrix] object. + */ + fun copy(): MnAlgebraicSymMatrix { + val copy = MnAlgebraicSymMatrix(theNRow) + java.lang.System.arraycopy(theData, 0, copy.theData, 0, theSize) + return copy + } + + fun data(): DoubleArray { + return theData + } + + fun eigenvalues(): ArrayRealVector { + val nrow = theNRow + val tmp = DoubleArray((nrow + 1) * (nrow + 1)) + val work = DoubleArray(1 + 2 * nrow) + for (i in 0 until nrow) { + for (j in 0..i) { + tmp[1 + i + (1 + j) * nrow] = get(i, j) + tmp[(1 + i) * nrow + (1 + j)] = get(i, j) + } + } + val info = mneigen(tmp, nrow, nrow, work.size, work, 1e-6) + if (info != 0) { + throw EigenvaluesException() + } + val result = ArrayRealVector(nrow) + for (i in 0 until nrow) { + result.setEntry(i, work[1 + i]) + } + return result + } + + operator fun get(row: Int, col: Int): Double { + if (row >= theNRow || col >= theNRow) { + throw ArrayIndexOutOfBoundsException() + } + return theData[theIndex(row, col)] + } + + @Throws(SingularMatrixException::class) + fun invert() { + if (theSize == 1) { + val tmp = theData[0] + if (tmp <= 0.0) { + throw SingularMatrixException() + } + theData[0] = 1.0 / tmp + } else { + val nrow = theNRow + val s = DoubleArray(nrow) + val q = DoubleArray(nrow) + val pp = DoubleArray(nrow) + for (i in 0 until nrow) { + val si = theData[theIndex(i, i)] + if (si < 0.0) { + throw SingularMatrixException() + } + s[i] = 1.0 / sqrt(si) + } + for (i in 0 until nrow) { + for (j in i until nrow) { + theData[theIndex(i, j)] *= s[i] * s[j] + } + } + for (i in 0 until nrow) { + var k = i + if (theData[theIndex(k, k)] == 0.0) { + throw SingularMatrixException() + } + q[k] = 1.0 / theData[theIndex(k, k)] + pp[k] = 1.0 + theData[theIndex(k, k)] = 0.0 + val kp1 = k + 1 + if (k != 0) { + for (j in 0 until k) { + val index = theIndex(j, k) + pp[j] = theData[index] + q[j] = theData[index] * q[k] + theData[index] = 0.0 + } + } + if (k != nrow - 1) { + for (j in kp1 until nrow) { + val index = theIndex(k, j) + pp[j] = theData[index] + q[j] = -theData[index] * q[k] + theData[index] = 0.0 + } + } + for (j in 0 until nrow) { + k = j + while (k < nrow) { + theData[theIndex(j, k)] += pp[j] * q[k] + k++ + } + } + } + for (j in 0 until nrow) { + for (k in j until nrow) { + theData[theIndex(j, k)] *= s[j] * s[k] + } + } + } + } + + fun ncol(): Int { + return nrow() + } + + fun nrow(): Int { + return theNRow + } + + operator fun set(row: Int, col: Int, value: Double) { + if (row >= theNRow || col >= theNRow) { + throw ArrayIndexOutOfBoundsException() + } + theData[theIndex(row, col)] = value + } + + fun size(): Int { + return theSize + } + + private fun theIndex(row: Int, col: Int): Int { + return if (row > col) { + col + row * (row + 1) / 2 + } else { + row + col * (col + 1) / 2 + } + } + + /** {@inheritDoc} */ + override fun toString(): String { + return MnPrint.toString(this) + } /* mneig_ */ + + private inner class EigenvaluesException : RuntimeException() + companion object { + private fun mneigen(a: DoubleArray, ndima: Int, n: Int, mits: Int, work: DoubleArray, precis: Double): Int { + + /* System generated locals */ + var i__2: Int + var i__3: Int + + /* Local variables */ + var b: Double + var c__: Double + var f: Double + var h__: Double + var i__: Int + var j: Int + var k: Int + var l: Int + var m = 0 + var r__: Double + var s: Double + var i0: Int + var i1: Int + var j1: Int + var m1: Int + var hh: Double + var gl: Double + var pr: Double + var pt: Double + + /* PRECIS is the machine precision EPSMAC */ + /* Parameter adjustments */ + val a_dim1: Int = ndima + val a_offset: Int = 1 + a_dim1 * 1 + + /* Function Body */ + var ifault = 1 + i__ = n + var i__1: Int = n + i1 = 2 + while (i1 <= i__1) { + l = i__ - 2 + f = a[i__ + (i__ - 1) * a_dim1] + gl = 0.0 + if (l >= 1) { + i__2 = l + k = 1 + while (k <= i__2) { + + /* Computing 2nd power */ + val r__1 = a[i__ + k * a_dim1] + gl += r__1 * r__1 + ++k + } + } + /* Computing 2nd power */h__ = gl + f * f + if (gl <= 1e-35) { + work[i__] = 0.0 + work[n + i__] = f + } else { + ++l + gl = sqrt(h__) + if (f >= 0.0) { + gl = -gl + } + work[n + i__] = gl + h__ -= f * gl + a[i__ + (i__ - 1) * a_dim1] = f - gl + f = 0.0 + i__2 = l + j = 1 + while (j <= i__2) { + a[j + i__ * a_dim1] = a[i__ + j * a_dim1] / h__ + gl = 0.0 + i__3 = j + k = 1 + while (k <= i__3) { + gl += a[j + k * a_dim1] * a[i__ + k * a_dim1] + ++k + } + if (j < l) { + j1 = j + 1 + i__3 = l + k = j1 + while (k <= i__3) { + gl += a[k + j * a_dim1] * a[i__ + k * a_dim1] + ++k + } + } + work[n + j] = gl / h__ + f += gl * a[j + i__ * a_dim1] + ++j + } + hh = f / (h__ + h__) + i__2 = l + j = 1 + while (j <= i__2) { + f = a[i__ + j * a_dim1] + gl = work[n + j] - hh * f + work[n + j] = gl + i__3 = j + k = 1 + while (k <= i__3) { + a[j + k * a_dim1] = a[j + k * a_dim1] - f * work[n + k] - (gl + * a[i__ + k * a_dim1]) + ++k + } + ++j + } + work[i__] = h__ + } + --i__ + ++i1 + } + work[1] = 0.0 + work[n + 1] = 0.0 + i__1 = n + i__ = 1 + while (i__ <= i__1) { + l = i__ - 1 + if (work[i__] != 0.0 && l != 0) { + i__3 = l + j = 1 + while (j <= i__3) { + gl = 0.0 + i__2 = l + k = 1 + while (k <= i__2) { + gl += a[i__ + k * a_dim1] * a[k + j * a_dim1] + ++k + } + i__2 = l + k = 1 + while (k <= i__2) { + a[k + j * a_dim1] -= gl * a[k + i__ * a_dim1] + ++k + } + ++j + } + } + work[i__] = a[i__ + i__ * a_dim1] + a[i__ + i__ * a_dim1] = 1.0 + if (l != 0) { + i__2 = l + j = 1 + while (j <= i__2) { + a[i__ + j * a_dim1] = 0.0 + a[j + i__ * a_dim1] = 0.0 + ++j + } + } + ++i__ + } + val n1: Int = n - 1 + i__1 = n + i__ = 2 + while (i__ <= i__1) { + i0 = n + i__ - 1 + work[i0] = work[i0 + 1] + ++i__ + } + work[n + n] = 0.0 + b = 0.0 + f = 0.0 + i__1 = n + l = 1 + while (l <= i__1) { + j = 0 + h__ = precis * (abs(work[l]) + abs(work[n + l])) + if (b < h__) { + b = h__ + } + i__2 = n + m1 = l + while (m1 <= i__2) { + m = m1 + if (abs(work[n + m]) <= b) { + break + } + ++m1 + } + if (m != l) { + while (true) { + if (j == mits) { + return ifault + } + ++j + pt = (work[l + 1] - work[l]) / (work[n + l] * 2.0) + r__ = sqrt(pt * pt + 1.0) + pr = pt + r__ + if (pt < 0.0) { + pr = pt - r__ + } + h__ = work[l] - work[n + l] / pr + i__2 = n + i__ = l + while (i__ <= i__2) { + work[i__] -= h__ + ++i__ + } + f += h__ + pt = work[m] + c__ = 1.0 + s = 0.0 + m1 = m - 1 + i__ = m + i__2 = m1 + i1 = l + while (i1 <= i__2) { + j = i__ + --i__ + gl = c__ * work[n + i__] + h__ = c__ * pt + if (abs(pt) < abs(work[n + i__])) { + c__ = pt / work[n + i__] + r__ = sqrt(c__ * c__ + 1.0) + work[n + j] = s * work[n + i__] * r__ + s = 1.0 / r__ + c__ /= r__ + } else { + c__ = work[n + i__] / pt + r__ = sqrt(c__ * c__ + 1.0) + work[n + j] = s * pt * r__ + s = c__ / r__ + c__ = 1.0 / r__ + } + pt = c__ * work[i__] - s * gl + work[j] = h__ + s * (c__ * gl + s * work[i__]) + i__3 = n + k = 1 + while (k <= i__3) { + h__ = a[k + j * a_dim1] + a[k + j * a_dim1] = s * a[k + i__ * a_dim1] + c__ * h__ + a[k + i__ * a_dim1] = c__ * a[k + i__ * a_dim1] - s * h__ + ++k + } + ++i1 + } + work[n + l] = s * pt + work[l] = c__ * pt + if (abs(work[n + l]) <= b) { + break + } + } + } + work[l] += f + ++l + } + i__1 = n1 + i__ = 1 + while (i__ <= i__1) { + k = i__ + pt = work[i__] + i1 = i__ + 1 + i__3 = n + j = i1 + while (j <= i__3) { + if (work[j] < pt) { + k = j + pt = work[j] + } + ++j + } + if (k != i__) { + work[k] = work[i__] + work[i__] = pt + i__3 = n + j = 1 + while (j <= i__3) { + pt = a[j + i__ * a_dim1] + a[j + i__ * a_dim1] = a[j + k * a_dim1] + a[j + k * a_dim1] = pt + ++j + } + } + ++i__ + } + ifault = 0 + return ifault + } /* mneig_ */ + } + + init { + require(n >= 0) { "Invalid matrix size: $n" } + theSize = n * (n + 1) / 2 + theNRow = n + theData = DoubleArray(theSize) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnApplication.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnApplication.kt new file mode 100644 index 000000000..025eea4ae --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnApplication.kt @@ -0,0 +1,554 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* + +/** + * Base class for minimizers. + * + * @version $Id$ + * @author Darksnake + */ +abstract class MnApplication { + /* package protected */ + var checkAnalyticalDerivatives: Boolean + + /* package protected */ /* package protected */ + var theErrorDef = 1.0 /* package protected */ + var theFCN: MultiFunction? + + /* package protected */ /* package protected */ + var theNumCall /* package protected */ = 0 + var theState: MnUserParameterState + + /* package protected */ + var theStrategy: MnStrategy + + /* package protected */ + var useAnalyticalDerivatives: Boolean + + /* package protected */ + internal constructor(fcn: MultiFunction?, state: MnUserParameterState, stra: MnStrategy) { + theFCN = fcn + theState = state + theStrategy = stra + checkAnalyticalDerivatives = true + useAnalyticalDerivatives = true + } + + internal constructor(fcn: MultiFunction?, state: MnUserParameterState, stra: MnStrategy, nfcn: Int) { + theFCN = fcn + theState = state + theStrategy = stra + theNumCall = nfcn + checkAnalyticalDerivatives = true + useAnalyticalDerivatives = true + } + + /** + * + * MultiFunction. + * + * @return a [MultiFunction] object. + */ + fun MultiFunction(): MultiFunction? { + return theFCN + } + + /** + * add free parameter + * + * @param err a double. + * @param val a double. + * @param name a [String] object. + */ + fun add(name: String, `val`: Double, err: Double) { + theState.add(name, `val`, err) + } + + /** + * add limited parameter + * + * @param up a double. + * @param low a double. + * @param name a [String] object. + * @param val a double. + * @param err a double. + */ + fun add(name: String, `val`: Double, err: Double, low: Double, up: Double) { + theState.add(name, `val`, err, low, up) + } + + /** + * add const parameter + * + * @param name a [String] object. + * @param val a double. + */ + fun add(name: String, `val`: Double) { + theState.add(name, `val`) + } + + /** + * + * checkAnalyticalDerivatives. + * + * @return a boolean. + */ + fun checkAnalyticalDerivatives(): Boolean { + return checkAnalyticalDerivatives + } + + /** + * + * covariance. + * + * @return a [hep.dataforge.MINUIT.MnUserCovariance] object. + */ + fun covariance(): MnUserCovariance { + return theState.covariance() + } + + /** + * + * error. + * + * @param index a int. + * @return a double. + */ + fun error(index: Int): Double { + return theState.error(index) + } + + /** + * + * error. + * + * @param name a [String] object. + * @return a double. + */ + fun error(name: String?): Double { + return theState.error(name) + } + + /** + * + * errorDef. + * + * @return a double. + */ + fun errorDef(): Double { + return theErrorDef + } + + /** + * + * errors. + * + * @return an array of double. + */ + fun errors(): DoubleArray { + return theState.errors() + } + + fun ext2int(i: Int, value: Double): Double { + return theState.ext2int(i, value) + } + + fun extOfInt(i: Int): Int { + return theState.extOfInt(i) + } + //interaction via external number of parameter + /** + * + * fix. + * + * @param index a int. + */ + fun fix(index: Int) { + theState.fix(index) + } + //interaction via name of parameter + /** + * + * fix. + * + * @param name a [String] object. + */ + fun fix(name: String?) { + theState.fix(name) + } + + /** + * convert name into external number of parameter + * + * @param name a [String] object. + * @return a int. + */ + fun index(name: String?): Int { + return theState.index(name) + } + + // transformation internal <-> external + fun int2ext(i: Int, value: Double): Double { + return theState.int2ext(i, value) + } + + fun intOfExt(i: Int): Int { + return theState.intOfExt(i) + } + + /** + * + * minimize. + * + * @return a [hep.dataforge.MINUIT.FunctionMinimum] object. + */ + fun minimize(): FunctionMinimum { + return minimize(DEFAULT_MAXFCN) + } + + /** + * + * minimize. + * + * @param maxfcn a int. + * @return a [hep.dataforge.MINUIT.FunctionMinimum] object. + */ + fun minimize(maxfcn: Int): FunctionMinimum { + return minimize(maxfcn, DEFAULT_TOLER) + } + + /** + * Causes minimization of the FCN and returns the result in form of a + * FunctionMinimum. + * + * @param maxfcn specifies the (approximate) maximum number of function + * calls after which the calculation will be stopped even if it has not yet + * converged. + * @param toler specifies the required tolerance on the function value at + * the minimum. The default tolerance value is 0.1, and the minimization + * will stop when the estimated vertical distance to the minimum (EDM) is + * less than 0:001*tolerance*errorDef + * @return a [hep.dataforge.MINUIT.FunctionMinimum] object. + */ + fun minimize(maxfcn: Int, toler: Double): FunctionMinimum { + var maxfcn = maxfcn + check(theState.isValid()) { "Invalid state" } + val npar = variableParameters() + if (maxfcn == 0) { + maxfcn = 200 + 100 * npar + 5 * npar * npar + } + val min: FunctionMinimum = minimizer().minimize(theFCN, + theState, + theStrategy, + maxfcn, + toler, + theErrorDef, + useAnalyticalDerivatives, + checkAnalyticalDerivatives) + theNumCall += min.nfcn() + theState = min.userState() + return min + } + + abstract fun minimizer(): ModularFunctionMinimizer + + // facade: forward interface of MnUserParameters and MnUserTransformation + fun minuitParameters(): List { + return theState.minuitParameters() + } + + /** + * convert external number into name of parameter + * + * @param index a int. + * @return a [String] object. + */ + fun name(index: Int): String { + return theState.name(index) + } + + /** + * + * numOfCalls. + * + * @return a int. + */ + fun numOfCalls(): Int { + return theNumCall + } + + /** + * access to single parameter + * @param i + * @return + */ + fun parameter(i: Int): MinuitParameter { + return theState.parameter(i) + } + + /** + * + * parameters. + * + * @return a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + fun parameters(): MnUserParameters { + return theState.parameters() + } + + /** + * access to parameters and errors in column-wise representation + * + * @return an array of double. + */ + fun params(): DoubleArray { + return theState.params() + } + + /** + * + * precision. + * + * @return a [hep.dataforge.MINUIT.MnMachinePrecision] object. + */ + fun precision(): MnMachinePrecision { + return theState.precision() + } + + /** + * + * release. + * + * @param index a int. + */ + fun release(index: Int) { + theState.release(index) + } + + /** + * + * release. + * + * @param name a [String] object. + */ + fun release(name: String?) { + theState.release(name) + } + + /** + * + * removeLimits. + * + * @param index a int. + */ + fun removeLimits(index: Int) { + theState.removeLimits(index) + } + + /** + * + * removeLimits. + * + * @param name a [String] object. + */ + fun removeLimits(name: String?) { + theState.removeLimits(name) + } + + /** + * Minuit does a check of the user gradient at the beginning, if this is not + * wanted the set this to "false". + * + * @param check a boolean. + */ + fun setCheckAnalyticalDerivatives(check: Boolean) { + checkAnalyticalDerivatives = check + } + + /** + * + * setError. + * + * @param index a int. + * @param err a double. + */ + fun setError(index: Int, err: Double) { + theState.setError(index, err) + } + + /** + * + * setError. + * + * @param name a [String] object. + * @param err a double. + */ + fun setError(name: String?, err: Double) { + theState.setError(name, err) + } + + /** + * errorDef() is the error definition of the function. E.g. is 1 if function + * is Chi2 and 0.5 if function is -logLikelihood. If the user wants instead + * the 2-sigma errors, errorDef() = 4, as Chi2(x+n*sigma) = Chi2(x) + n*n. + * + * @param errorDef a double. + */ + fun setErrorDef(errorDef: Double) { + theErrorDef = errorDef + } + + /** + * + * setLimits. + * + * @param index a int. + * @param low a double. + * @param up a double. + */ + fun setLimits(index: Int, low: Double, up: Double) { + theState.setLimits(index, low, up) + } + + /** + * + * setLimits. + * + * @param name a [String] object. + * @param low a double. + * @param up a double. + */ + fun setLimits(name: String?, low: Double, up: Double) { + theState.setLimits(name, low, up) + } + + /** + * + * setPrecision. + * + * @param prec a double. + */ + fun setPrecision(prec: Double) { + theState.setPrecision(prec) + } + + /** + * By default if the function to be minimized implements MultiFunction then + * the analytical gradient provided by the function will be used. Set this + * to + * false to disable this behaviour and force numerical + * calculation of the gradient. + * + * @param use a boolean. + */ + fun setUseAnalyticalDerivatives(use: Boolean) { + useAnalyticalDerivatives = use + } + + /** + * + * setValue. + * + * @param index a int. + * @param val a double. + */ + fun setValue(index: Int, `val`: Double) { + theState.setValue(index, `val`) + } + + /** + * + * setValue. + * + * @param name a [String] object. + * @param val a double. + */ + fun setValue(name: String?, `val`: Double) { + theState.setValue(name, `val`) + } + + /** + * + * state. + * + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun state(): MnUserParameterState { + return theState + } + + /** + * + * strategy. + * + * @return a [hep.dataforge.MINUIT.MnStrategy] object. + */ + fun strategy(): MnStrategy { + return theStrategy + } + + /** + * + * useAnalyticalDerivaties. + * + * @return a boolean. + */ + fun useAnalyticalDerivaties(): Boolean { + return useAnalyticalDerivatives + } + + /** + * + * value. + * + * @param index a int. + * @return a double. + */ + fun value(index: Int): Double { + return theState.value(index) + } + + /** + * + * value. + * + * @param name a [String] object. + * @return a double. + */ + fun value(name: String?): Double { + return theState.value(name) + } + + /** + * + * variableParameters. + * + * @return a int. + */ + fun variableParameters(): Int { + return theState.variableParameters() + } + + companion object { + var DEFAULT_MAXFCN = 0 + var DEFAULT_STRATEGY = 1 + var DEFAULT_TOLER = 0.1 + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnContours.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnContours.kt new file mode 100644 index 000000000..1b700f4e2 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnContours.kt @@ -0,0 +1,283 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* + +/** + * API class for Contours error analysis (2-dim errors). Minimization has to be + * done before and minimum must be valid. Possibility to ask only for the points + * or the points and associated Minos errors. + * + * @version $Id$ + * @author Darksnake + */ +class MnContours(fcn: MultiFunction?, min: FunctionMinimum?, stra: MnStrategy?) { + private var theFCN: MultiFunction? = null + private var theMinimum: FunctionMinimum? = null + private var theStrategy: MnStrategy? = null + + /** + * construct from FCN + minimum + * + * @param fcn a [MultiFunction] object. + * @param min a [hep.dataforge.MINUIT.FunctionMinimum] object. + */ + constructor(fcn: MultiFunction?, min: FunctionMinimum?) : this(fcn, min, MnApplication.DEFAULT_STRATEGY) + + /** + * construct from FCN + minimum + strategy + * + * @param stra a int. + * @param min a [hep.dataforge.MINUIT.FunctionMinimum] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, min: FunctionMinimum?, stra: Int) : this(fcn, min, MnStrategy(stra)) + + /** + * + * contour. + * + * @param px a int. + * @param py a int. + * @return a [hep.dataforge.MINUIT.ContoursError] object. + */ + fun contour(px: Int, py: Int): ContoursError { + return contour(px, py, 1.0) + } + + /** + * + * contour. + * + * @param px a int. + * @param py a int. + * @param errDef a double. + * @return a [hep.dataforge.MINUIT.ContoursError] object. + */ + fun contour(px: Int, py: Int, errDef: Double): ContoursError { + return contour(px, py, errDef, 20) + } + + /** + * Causes a CONTOURS error analysis and returns the result in form of + * ContoursError. As a by-product ContoursError keeps the MinosError + * information of parameters parx and pary. The result ContoursError can be + * easily printed using MnPrint or toString(). + * + * @param npoints a int. + * @param px a int. + * @param py a int. + * @param errDef a double. + * @return a [hep.dataforge.MINUIT.ContoursError] object. + */ + fun contour(px: Int, py: Int, errDef: Double, npoints: Int): ContoursError { + var errDef = errDef + errDef *= theMinimum!!.errorDef() + assert(npoints > 3) + val maxcalls: Int = 100 * (npoints + 5) * (theMinimum!!.userState().variableParameters() + 1) + var nfcn = 0 + val result: MutableList = java.util.ArrayList(npoints) + val states: List = java.util.ArrayList() + val toler = 0.05 + + //get first four points + val minos = MnMinos(theFCN, theMinimum, theStrategy) + val valx: Double = theMinimum!!.userState().value(px) + val valy: Double = theMinimum!!.userState().value(py) + val mex: MinosError = minos.minos(px, errDef) + nfcn += mex.nfcn() + if (!mex.isValid()) { + MINUITPlugin.logStatic("MnContours is unable to find first two points.") + return ContoursError(px, py, result, mex, mex, nfcn) + } + val ex: Range = mex.range() + val mey: MinosError = minos.minos(py, errDef) + nfcn += mey.nfcn() + if (!mey.isValid()) { + MINUITPlugin.logStatic("MnContours is unable to find second two points.") + return ContoursError(px, py, result, mex, mey, nfcn) + } + val ey: Range = mey.range() + val migrad = MnMigrad(theFCN, + theMinimum!!.userState().copy(), + MnStrategy(max(0, theStrategy!!.strategy() - 1))) + migrad.fix(px) + migrad.setValue(px, valx + ex.getSecond()) + val exy_up: FunctionMinimum = migrad.minimize() + nfcn += exy_up.nfcn() + if (!exy_up.isValid()) { + MINUITPlugin.logStatic("MnContours is unable to find upper y value for x parameter $px.") + return ContoursError(px, py, result, mex, mey, nfcn) + } + migrad.setValue(px, valx + ex.getFirst()) + val exy_lo: FunctionMinimum = migrad.minimize() + nfcn += exy_lo.nfcn() + if (!exy_lo.isValid()) { + MINUITPlugin.logStatic("MnContours is unable to find lower y value for x parameter $px.") + return ContoursError(px, py, result, mex, mey, nfcn) + } + val migrad1 = MnMigrad(theFCN, + theMinimum!!.userState().copy(), + MnStrategy(max(0, theStrategy!!.strategy() - 1))) + migrad1.fix(py) + migrad1.setValue(py, valy + ey.getSecond()) + val eyx_up: FunctionMinimum = migrad1.minimize() + nfcn += eyx_up.nfcn() + if (!eyx_up.isValid()) { + MINUITPlugin.logStatic("MnContours is unable to find upper x value for y parameter $py.") + return ContoursError(px, py, result, mex, mey, nfcn) + } + migrad1.setValue(py, valy + ey.getFirst()) + val eyx_lo: FunctionMinimum = migrad1.minimize() + nfcn += eyx_lo.nfcn() + if (!eyx_lo.isValid()) { + MINUITPlugin.logStatic("MnContours is unable to find lower x value for y parameter $py.") + return ContoursError(px, py, result, mex, mey, nfcn) + } + val scalx: Double = 1.0 / (ex.getSecond() - ex.getFirst()) + val scaly: Double = 1.0 / (ey.getSecond() - ey.getFirst()) + result.add(Range(valx + ex.getFirst(), exy_lo.userState().value(py))) + result.add(Range(eyx_lo.userState().value(px), valy + ey.getFirst())) + result.add(Range(valx + ex.getSecond(), exy_up.userState().value(py))) + result.add(Range(eyx_up.userState().value(px), valy + ey.getSecond())) + val upar: MnUserParameterState = theMinimum!!.userState().copy() + upar.fix(px) + upar.fix(py) + val par = intArrayOf(px, py) + val cross = MnFunctionCross(theFCN, upar, theMinimum!!.fval(), theStrategy, errDef) + for (i in 4 until npoints) { + var idist1: Range = result[result.size - 1] + var idist2: Range = result[0] + var pos2 = 0 + val distx: Double = idist1.getFirst() - idist2.getFirst() + val disty: Double = idist1.getSecond() - idist2.getSecond() + var bigdis = scalx * scalx * distx * distx + scaly * scaly * disty * disty + for (j in 0 until result.size - 1) { + val ipair: Range = result[j] + val distx2: Double = ipair.getFirst() - result[j + 1].getFirst() + val disty2: Double = ipair.getSecond() - result[j + 1].getSecond() + val dist = scalx * scalx * distx2 * distx2 + scaly * scaly * disty2 * disty2 + if (dist > bigdis) { + bigdis = dist + idist1 = ipair + idist2 = result[j + 1] + pos2 = j + 1 + } + } + val a1 = 0.5 + val a2 = 0.5 + var sca = 1.0 + while (true) { + if (nfcn > maxcalls) { + MINUITPlugin.logStatic("MnContours: maximum number of function calls exhausted.") + return ContoursError(px, py, result, mex, mey, nfcn) + } + val xmidcr: Double = a1 * idist1.getFirst() + a2 * idist2.getFirst() + val ymidcr: Double = a1 * idist1.getSecond() + a2 * idist2.getSecond() + val xdir: Double = idist2.getSecond() - idist1.getSecond() + val ydir: Double = idist1.getFirst() - idist2.getFirst() + val scalfac: Double = + sca * max(abs(xdir * scalx), abs(ydir * scaly)) + val xdircr = xdir / scalfac + val ydircr = ydir / scalfac + val pmid = doubleArrayOf(xmidcr, ymidcr) + val pdir = doubleArrayOf(xdircr, ydircr) + val opt: MnCross = cross.cross(par, pmid, pdir, toler, maxcalls) + nfcn += opt.nfcn() + if (opt.isValid()) { + val aopt: Double = opt.value() + if (pos2 == 0) { + result.add(Range(xmidcr + aopt * xdircr, ymidcr + aopt * ydircr)) + } else { + result.add(pos2, Range(xmidcr + aopt * xdircr, ymidcr + aopt * ydircr)) + } + break + } + if (sca < 0.0) { + MINUITPlugin.logStatic("MnContours is unable to find point " + (i + 1) + " on contour.") + MINUITPlugin.logStatic("MnContours finds only $i points.") + return ContoursError(px, py, result, mex, mey, nfcn) + } + sca = -1.0 + } + } + return ContoursError(px, py, result, mex, mey, nfcn) + } + + /** + * + * points. + * + * @param px a int. + * @param py a int. + * @return a [List] object. + */ + fun points(px: Int, py: Int): List { + return points(px, py, 1.0) + } + + /** + * + * points. + * + * @param px a int. + * @param py a int. + * @param errDef a double. + * @return a [List] object. + */ + fun points(px: Int, py: Int, errDef: Double): List { + return points(px, py, errDef, 20) + } + + /** + * Calculates one function contour of FCN with respect to parameters parx + * and pary. The return value is a list of (x,y) points. FCN minimized + * always with respect to all other n - 2 variable parameters (if any). + * MINUITPlugin will try to find n points on the contour (default 20). To + * calculate more than one contour, the user needs to set the error + * definition in its FCN to the appropriate value for the desired confidence + * level and call this method for each contour. + * + * @param npoints a int. + * @param px a int. + * @param py a int. + * @param errDef a double. + * @return a [List] object. + */ + fun points(px: Int, py: Int, errDef: Double, npoints: Int): List { + val cont: ContoursError = contour(px, py, errDef, npoints) + return cont.points() + } + + fun strategy(): MnStrategy? { + return theStrategy + } + + /** + * construct from FCN + minimum + strategy + * + * @param stra a [hep.dataforge.MINUIT.MnStrategy] object. + * @param min a [hep.dataforge.MINUIT.FunctionMinimum] object. + * @param fcn a [MultiFunction] object. + */ + init { + theFCN = fcn + theMinimum = min + theStrategy = stra + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnCovarianceSqueeze.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnCovarianceSqueeze.kt new file mode 100644 index 000000000..7614a93b0 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnCovarianceSqueeze.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MINUITPlugin + +/** + * + * @version $Id$ + */ +internal object MnCovarianceSqueeze { + fun squeeze(cov: MnUserCovariance, n: Int): MnUserCovariance { + assert(cov.nrow() > 0) + assert(n < cov.nrow()) + val hess = MnAlgebraicSymMatrix(cov.nrow()) + for (i in 0 until cov.nrow()) { + for (j in i until cov.nrow()) { + hess[i, j] = cov[i, j] + } + } + try { + hess.invert() + } catch (x: SingularMatrixException) { + MINUITPlugin.logStatic("MnUserCovariance inversion failed; return diagonal matrix;") + val result = MnUserCovariance(cov.nrow() - 1) + var i = 0 + var j = 0 + while (i < cov.nrow()) { + if (i == n) { + i++ + continue + } + result[j, j] = cov[i, i] + j++ + i++ + } + return result + } + val squeezed: MnAlgebraicSymMatrix = squeeze(hess, n) + try { + squeezed.invert() + } catch (x: SingularMatrixException) { + MINUITPlugin.logStatic("MnUserCovariance back-inversion failed; return diagonal matrix;") + val result = MnUserCovariance(squeezed.nrow()) + var i = 0 + while (i < squeezed.nrow()) { + result[i, i] = 1.0 / squeezed[i, i] + i++ + } + return result + } + return MnUserCovariance(squeezed.data(), squeezed.nrow()) + } + + fun squeeze(err: MinimumError, n: Int): MinimumError { + val hess: MnAlgebraicSymMatrix = err.hessian() + val squeezed: MnAlgebraicSymMatrix = squeeze(hess, n) + try { + squeezed.invert() + } catch (x: SingularMatrixException) { + MINUITPlugin.logStatic("MnCovarianceSqueeze: MinimumError inversion fails; return diagonal matrix.") + val tmp = MnAlgebraicSymMatrix(squeezed.nrow()) + var i = 0 + while (i < squeezed.nrow()) { + tmp[i, i] = 1.0 / squeezed[i, i] + i++ + } + return MinimumError(tmp, MnInvertFailed()) + } + return MinimumError(squeezed, err.dcovar()) + } + + fun squeeze(hess: MnAlgebraicSymMatrix, n: Int): MnAlgebraicSymMatrix { + assert(hess.nrow() > 0) + assert(n < hess.nrow()) + val hs = MnAlgebraicSymMatrix(hess.nrow() - 1) + var i = 0 + var j = 0 + while (i < hess.nrow()) { + if (i == n) { + i++ + continue + } + var k = i + var l = j + while (k < hess.nrow()) { + if (k == n) { + k++ + continue + } + hs[j, l] = hess[i, k] + l++ + k++ + } + j++ + i++ + } + return hs + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnCross.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnCross.kt new file mode 100644 index 000000000..f1487b106 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnCross.kt @@ -0,0 +1,99 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * MnCross class. + * + * @version $Id$ + * @author Darksnake + */ +class MnCross { + private var theLimset = false + private var theMaxFcn = false + private var theNFcn = 0 + private var theNewMin = false + private var theState: MnUserParameterState + private var theValid = false + private var theValue = 0.0 + + internal constructor() { + theState = MnUserParameterState() + } + + internal constructor(nfcn: Int) { + theState = MnUserParameterState() + theNFcn = nfcn + } + + internal constructor(value: Double, state: MnUserParameterState, nfcn: Int) { + theValue = value + theState = state + theNFcn = nfcn + theValid = true + } + + internal constructor(state: MnUserParameterState, nfcn: Int, x: CrossParLimit?) { + theState = state + theNFcn = nfcn + theLimset = true + } + + internal constructor(state: MnUserParameterState, nfcn: Int, x: CrossFcnLimit?) { + theState = state + theNFcn = nfcn + theMaxFcn = true + } + + internal constructor(state: MnUserParameterState, nfcn: Int, x: CrossNewMin?) { + theState = state + theNFcn = nfcn + theNewMin = true + } + + fun atLimit(): Boolean { + return theLimset + } + + fun atMaxFcn(): Boolean { + return theMaxFcn + } + + fun isValid(): Boolean { + return theValid + } + + fun newMinimum(): Boolean { + return theNewMin + } + + fun nfcn(): Int { + return theNFcn + } + + fun state(): MnUserParameterState { + return theState + } + + fun value(): Double { + return theValue + } + + internal class CrossFcnLimit + internal class CrossNewMin + internal class CrossParLimit +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnEigen.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnEigen.kt new file mode 100644 index 000000000..d7aade0c9 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnEigen.kt @@ -0,0 +1,50 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.RealVector + +/** + * Calculates and the eigenvalues of the user covariance matrix + * MnUserCovariance. + * + * @version $Id$ + * @author Darksnake + */ +object MnEigen { + /* Calculate eigenvalues of the covariance matrix. + * Will perform the calculation of the eigenvalues of the covariance matrix + * and return the result in the form of a double array. + * The eigenvalues are ordered from the smallest to the largest eigenvalue. + */ + /** + * + * eigenvalues. + * + * @param covar a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @return an array of double. + */ + fun eigenvalues(covar: MnUserCovariance): DoubleArray { + val cov = MnAlgebraicSymMatrix(covar.nrow()) + for (i in 0 until covar.nrow()) { + for (j in i until covar.nrow()) { + cov[i, j] = covar[i, j] + } + } + val eigen: RealVector = cov.eigenvalues() + return eigen.toArray() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnFcn.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnFcn.kt new file mode 100644 index 000000000..b11f71035 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnFcn.kt @@ -0,0 +1,50 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction + +/** + * Функция, которая помнит количество вызовов себя и ErrorDef + * @version $Id$ + */ +class MnFcn(fcn: MultiFunction?, errorDef: Double) { + private val theErrorDef: Double + private val theFCN: MultiFunction? + protected var theNumCall: Int + fun errorDef(): Double { + return theErrorDef + } + + fun fcn(): MultiFunction? { + return theFCN + } + + fun numOfCalls(): Int { + return theNumCall + } + + fun value(v: RealVector): Double { + theNumCall++ + return theFCN.value(v.toArray()) + } + + init { + theFCN = fcn + theNumCall = 0 + theErrorDef = errorDef + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnFunctionCross.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnFunctionCross.kt new file mode 100644 index 000000000..a05590e53 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnFunctionCross.kt @@ -0,0 +1,369 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* +import kotlin.math.* + +/** + * + * @version $Id$ + */ +internal class MnFunctionCross( + fcn: MultiFunction?, + state: MnUserParameterState, + fval: Double, + stra: MnStrategy?, + errorDef: Double +) { + private val theErrorDef: Double + private val theFCN: MultiFunction? + private val theFval: Double + private val theState: MnUserParameterState + private val theStrategy: MnStrategy? + fun cross(par: IntArray, pmid: DoubleArray, pdir: DoubleArray, tlr: Double, maxcalls: Int): MnCross { + val npar = par.size + var nfcn = 0 + val prec: MnMachinePrecision = theState.precision() + val tlf = tlr * theErrorDef + var tla = tlr + val maxitr = 15 + var ipt = 0 + val aminsv = theFval + val aim = aminsv + theErrorDef + var aopt = 0.0 + var limset = false + val alsb = DoubleArray(3) + val flsb = DoubleArray(3) + val up = theErrorDef + var aulim = 100.0 + for (i in par.indices) { + val kex = par[i] + if (theState.parameter(kex).hasLimits()) { + val zmid = pmid[i] + val zdir = pdir[i] + if (abs(zdir) < theState.precision().eps()) { + continue + } + if (zdir > 0.0 && theState.parameter(kex).hasUpperLimit()) { + val zlim: Double = theState.parameter(kex).upperLimit() + aulim = min(aulim, (zlim - zmid) / zdir) + } else if (zdir < 0.0 && theState.parameter(kex).hasLowerLimit()) { + val zlim: Double = theState.parameter(kex).lowerLimit() + aulim = min(aulim, (zlim - zmid) / zdir) + } + } + } + if (aulim < aopt + tla) { + limset = true + } + val migrad = MnMigrad(theFCN, theState, MnStrategy(max(0, theStrategy!!.strategy() - 1))) + for (i in 0 until npar) { + migrad.setValue(par[i], pmid[i]) + } + val min0: FunctionMinimum = migrad.minimize(maxcalls, tlr) + nfcn += min0.nfcn() + if (min0.hasReachedCallLimit()) { + return MnCross(min0.userState(), nfcn, MnCross.CrossFcnLimit()) + } + if (!min0.isValid()) { + return MnCross(nfcn) + } + if (limset && min0.fval() < aim) { + return MnCross(min0.userState(), nfcn, MnCross.CrossParLimit()) + } + ipt++ + alsb[0] = 0.0 + flsb[0] = min0.fval() + flsb[0] = max(flsb[0], aminsv + 0.1 * up) + aopt = sqrt(up / (flsb[0] - aminsv)) - 1.0 + if (abs(flsb[0] - aim) < tlf) { + return MnCross(aopt, min0.userState(), nfcn) + } + if (aopt > 1.0) { + aopt = 1.0 + } + if (aopt < -0.5) { + aopt = -0.5 + } + limset = false + if (aopt > aulim) { + aopt = aulim + limset = true + } + for (i in 0 until npar) { + migrad.setValue(par[i], pmid[i] + aopt * pdir[i]) + } + var min1: FunctionMinimum = migrad.minimize(maxcalls, tlr) + nfcn += min1.nfcn() + if (min1.hasReachedCallLimit()) { + return MnCross(min1.userState(), nfcn, MnCross.CrossFcnLimit()) + } + if (!min1.isValid()) { + return MnCross(nfcn) + } + if (limset && min1.fval() < aim) { + return MnCross(min1.userState(), nfcn, MnCross.CrossParLimit()) + } + ipt++ + alsb[1] = aopt + flsb[1] = min1.fval() + var dfda = (flsb[1] - flsb[0]) / (alsb[1] - alsb[0]) + var ecarmn = 0.0 + var ecarmx = 0.0 + var ibest = 0 + var iworst = 0 + var noless = 0 + var min2: FunctionMinimum? = null + L300@ while (true) { + if (dfda < 0.0) { + val maxlk = maxitr - ipt + for (it in 0 until maxlk) { + alsb[0] = alsb[1] + flsb[0] = flsb[1] + aopt = alsb[0] + 0.2 * it + limset = false + if (aopt > aulim) { + aopt = aulim + limset = true + } + for (i in 0 until npar) { + migrad.setValue(par[i], pmid[i] + aopt * pdir[i]) + } + min1 = migrad.minimize(maxcalls, tlr) + nfcn += min1.nfcn() + if (min1.hasReachedCallLimit()) { + return MnCross(min1.userState(), nfcn, MnCross.CrossFcnLimit()) + } + if (!min1.isValid()) { + return MnCross(nfcn) + } + if (limset && min1.fval() < aim) { + return MnCross(min1.userState(), nfcn, MnCross.CrossParLimit()) + } + ipt++ + alsb[1] = aopt + flsb[1] = min1.fval() + dfda = (flsb[1] - flsb[0]) / (alsb[1] - alsb[0]) + if (dfda > 0.0) { + break + } + } + if (ipt > maxitr) { + return MnCross(nfcn) + } + } + L460@ while (true) { + aopt = alsb[1] + (aim - flsb[1]) / dfda + val fdist: Double = + min(abs(aim - flsb[0]), abs(aim - flsb[1])) + val adist: Double = + min(abs(aopt - alsb[0]), abs(aopt - alsb[1])) + tla = tlr + if (abs(aopt) > 1.0) { + tla = tlr * abs(aopt) + } + if (adist < tla && fdist < tlf) { + return MnCross(aopt, min1.userState(), nfcn) + } + if (ipt > maxitr) { + return MnCross(nfcn) + } + val bmin: Double = min(alsb[0], alsb[1]) - 1.0 + if (aopt < bmin) { + aopt = bmin + } + val bmax: Double = max(alsb[0], alsb[1]) + 1.0 + if (aopt > bmax) { + aopt = bmax + } + limset = false + if (aopt > aulim) { + aopt = aulim + limset = true + } + for (i in 0 until npar) { + migrad.setValue(par[i], pmid[i] + aopt * pdir[i]) + } + min2 = migrad.minimize(maxcalls, tlr) + nfcn += min2.nfcn() + if (min2.hasReachedCallLimit()) { + return MnCross(min2.userState(), nfcn, CrossFcnLimit()) + } + if (!min2.isValid()) { + return MnCross(nfcn) + } + if (limset && min2.fval() < aim) { + return MnCross(min2.userState(), nfcn, MnCross.CrossParLimit()) + } + ipt++ + alsb[2] = aopt + flsb[2] = min2.fval() + ecarmn = abs(flsb[2] - aim) + ecarmx = 0.0 + ibest = 2 + iworst = 0 + noless = 0 + for (i in 0..2) { + val ecart: Double = abs(flsb[i] - aim) + if (ecart > ecarmx) { + ecarmx = ecart + iworst = i + } + if (ecart < ecarmn) { + ecarmn = ecart + ibest = i + } + if (flsb[i] < aim) { + noless++ + } + } + if (noless == 1 || noless == 2) { + break@L300 + } + if (noless == 0 && ibest != 2) { + return MnCross(nfcn) + } + if (noless == 3 && ibest != 2) { + alsb[1] = alsb[2] + flsb[1] = flsb[2] + continue@L300 + } + flsb[iworst] = flsb[2] + alsb[iworst] = alsb[2] + dfda = (flsb[1] - flsb[0]) / (alsb[1] - alsb[0]) + } + } + do { + val parbol: MnParabola = MnParabolaFactory.create(MnParabolaPoint(alsb[0], flsb[0]), + MnParabolaPoint(alsb[1], flsb[1]), + MnParabolaPoint( + alsb[2], flsb[2])) + val coeff1: Double = parbol.c() + val coeff2: Double = parbol.b() + val coeff3: Double = parbol.a() + val determ = coeff2 * coeff2 - 4.0 * coeff3 * (coeff1 - aim) + if (determ < prec.eps()) { + return MnCross(nfcn) + } + val rt: Double = sqrt(determ) + val x1 = (-coeff2 + rt) / (2.0 * coeff3) + val x2 = (-coeff2 - rt) / (2.0 * coeff3) + val s1 = coeff2 + 2.0 * x1 * coeff3 + val s2 = coeff2 + 2.0 * x2 * coeff3 + if (s1 * s2 > 0.0) { + MINUITPlugin.logStatic("MnFunctionCross problem 1") + } + aopt = x1 + var slope = s1 + if (s2 > 0.0) { + aopt = x2 + slope = s2 + } + tla = tlr + if (abs(aopt) > 1.0) { + tla = tlr * abs(aopt) + } + if (abs(aopt - alsb[ibest]) < tla && abs(flsb[ibest] - aim) < tlf) { + return MnCross(aopt, min2!!.userState(), nfcn) + } + var ileft = 3 + var iright = 3 + var iout = 3 + ibest = 0 + ecarmx = 0.0 + ecarmn = abs(aim - flsb[0]) + for (i in 0..2) { + val ecart: Double = abs(flsb[i] - aim) + if (ecart < ecarmn) { + ecarmn = ecart + ibest = i + } + if (ecart > ecarmx) { + ecarmx = ecart + } + if (flsb[i] > aim) { + if (iright == 3) { + iright = i + } else if (flsb[i] > flsb[iright]) { + iout = i + } else { + iout = iright + iright = i + } + } else if (ileft == 3) { + ileft = i + } else if (flsb[i] < flsb[ileft]) { + iout = i + } else { + iout = ileft + ileft = i + } + } + if (ecarmx > 10.0 * abs(flsb[iout] - aim)) { + aopt = 0.5 * (aopt + 0.5 * (alsb[iright] + alsb[ileft])) + } + var smalla = 0.1 * tla + if (slope * smalla > tlf) { + smalla = tlf / slope + } + val aleft = alsb[ileft] + smalla + val aright = alsb[iright] - smalla + if (aopt < aleft) { + aopt = aleft + } + if (aopt > aright) { + aopt = aright + } + if (aleft > aright) { + aopt = 0.5 * (aleft + aright) + } + limset = false + if (aopt > aulim) { + aopt = aulim + limset = true + } + for (i in 0 until npar) { + migrad.setValue(par[i], pmid[i] + aopt * pdir[i]) + } + min2 = migrad.minimize(maxcalls, tlr) + nfcn += min2.nfcn() + if (min2.hasReachedCallLimit()) { + return MnCross(min2.userState(), nfcn, CrossFcnLimit()) + } + if (!min2.isValid()) { + return MnCross(nfcn) + } + if (limset && min2.fval() < aim) { + return MnCross(min2.userState(), nfcn, CrossParLimit()) + } + ipt++ + alsb[iout] = aopt + flsb[iout] = min2.fval() + ibest = iout + } while (ipt < maxitr) + return MnCross(nfcn) + } + + init { + theFCN = fcn + theState = state + theFval = fval + theStrategy = stra + theErrorDef = errorDef + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnGlobalCorrelationCoeff.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnGlobalCorrelationCoeff.kt new file mode 100644 index 000000000..939dd7fa0 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnGlobalCorrelationCoeff.kt @@ -0,0 +1,79 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.SingularMatrixException + +/** + * + * MnGlobalCorrelationCoeff class. + * + * @version $Id$ + * @author Darksnake + */ +class MnGlobalCorrelationCoeff { + private var theGlobalCC: DoubleArray + private var theValid = false + + internal constructor() { + theGlobalCC = DoubleArray(0) + } + + internal constructor(cov: MnAlgebraicSymMatrix) { + try { + val inv: MnAlgebraicSymMatrix = cov.copy() + inv.invert() + theGlobalCC = DoubleArray(cov.nrow()) + for (i in 0 until cov.nrow()) { + val denom: Double = inv[i, i] * cov[i, i] + if (denom < 1.0 && denom > 0.0) { + theGlobalCC[i] = 0 + } else { + theGlobalCC[i] = sqrt(1.0 - 1.0 / denom) + } + } + theValid = true + } catch (x: SingularMatrixException) { + theValid = false + theGlobalCC = DoubleArray(0) + } + } + + /** + * + * globalCC. + * + * @return an array of double. + */ + fun globalCC(): DoubleArray { + return theGlobalCC + } + + /** + * + * isValid. + * + * @return a boolean. + */ + fun isValid(): Boolean { + return theValid + } + + /** {@inheritDoc} */ + override fun toString(): String { + return MnPrint.toString(this) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnHesse.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnHesse.kt new file mode 100644 index 000000000..3bb6c4551 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnHesse.kt @@ -0,0 +1,371 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* + +/** + * With MnHesse the user can instructs MINUITPlugin to calculate, by finite + * differences, the Hessian or error matrix. That is, it calculates the full + * matrix of second derivatives of the function with respect to the currently + * variable parameters, and inverts it. + * + * @version $Id$ + * @author Darksnake + */ +class MnHesse { + private var theStrategy: MnStrategy + + /** + * default constructor with default strategy + */ + constructor() { + theStrategy = MnStrategy(1) + } + + /** + * constructor with user-defined strategy level + * + * @param stra a int. + */ + constructor(stra: Int) { + theStrategy = MnStrategy(stra) + } + + /** + * conctructor with specific strategy + * + * @param stra a [hep.dataforge.MINUIT.MnStrategy] object. + */ + constructor(stra: MnStrategy) { + theStrategy = stra + } + /// + /// low-level API + /// + /** + * + * calculate. + * + * @param fcn a [MultiFunction] object. + * @param par an array of double. + * @param err an array of double. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray): MnUserParameterState { + return calculate(fcn, par, err, 0) + } + + /** + * FCN + parameters + errors + * + * @param maxcalls a int. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + * @param err an array of double. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray, maxcalls: Int): MnUserParameterState { + return calculate(fcn, MnUserParameterState(par, err), maxcalls) + } + + /** + * + * calculate. + * + * @param fcn a [MultiFunction] object. + * @param par an array of double. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance): MnUserParameterState { + return calculate(fcn, par, cov, 0) + } + + /** + * FCN + parameters + MnUserCovariance + * + * @param maxcalls a int. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance, maxcalls: Int): MnUserParameterState { + return calculate(fcn, MnUserParameterState(par, cov), maxcalls) + } + /// + /// high-level API + /// + /** + * + * calculate. + * + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, par: MnUserParameters): MnUserParameterState { + return calculate(fcn, par, 0) + } + + /** + * FCN + MnUserParameters + * + * @param maxcalls a int. + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, par: MnUserParameters, maxcalls: Int): MnUserParameterState { + return calculate(fcn, MnUserParameterState(par), maxcalls) + } + + /** + * + * calculate. + * + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance?): MnUserParameterState { + return calculate(fcn, par, 0) + } + + /** + * FCN + MnUserParameters + MnUserCovariance + * + * @param maxcalls a int. + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate( + fcn: MultiFunction?, + par: MnUserParameters, + cov: MnUserCovariance, + maxcalls: Int + ): MnUserParameterState { + return calculate(fcn, MnUserParameterState(par, cov), maxcalls) + } + + /** + * FCN + MnUserParameterState + * + * @param maxcalls a int. + * @param fcn a [MultiFunction] object. + * @param state a [hep.dataforge.MINUIT.MnUserParameterState] object. + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun calculate(fcn: MultiFunction?, state: MnUserParameterState, maxcalls: Int): MnUserParameterState { + val errDef = 1.0 // FixMe! + val n: Int = state.variableParameters() + val mfcn = MnUserFcn(fcn, errDef, state.getTransformation()) + val x: RealVector = ArrayRealVector(n) + for (i in 0 until n) { + x.setEntry(i, state.intParameters()[i]) + } + val amin: Double = mfcn.value(x) + val gc = Numerical2PGradientCalculator(mfcn, state.getTransformation(), theStrategy) + val par = MinimumParameters(x, amin) + val gra: FunctionGradient = gc.gradient(par) + val tmp: MinimumState = calculate(mfcn, + MinimumState(par, MinimumError(MnAlgebraicSymMatrix(n), 1.0), gra, state.edm(), state.nfcn()), + state.getTransformation(), + maxcalls) + return MnUserParameterState(tmp, errDef, state.getTransformation()) + } + + /// + /// internal interface + /// + fun calculate(mfcn: MnFcn, st: MinimumState, trafo: MnUserTransformation, maxcalls: Int): MinimumState { + var maxcalls = maxcalls + val prec: MnMachinePrecision = trafo.precision() + // make sure starting at the right place + val amin: Double = mfcn.value(st.vec()) + val aimsag: Double = sqrt(prec.eps2()) * (abs(amin) + mfcn.errorDef()) + + // diagonal elements first + val n: Int = st.parameters().vec().getDimension() + if (maxcalls == 0) { + maxcalls = 200 + 100 * n + 5 * n * n + } + var vhmat = MnAlgebraicSymMatrix(n) + var g2: RealVector = st.gradient().getGradientDerivative().copy() + var gst: RealVector = st.gradient().getStep().copy() + var grd: RealVector = st.gradient().getGradient().copy() + var dirin: RealVector = st.gradient().getStep().copy() + val yy: RealVector = ArrayRealVector(n) + if (st.gradient().isAnalytical()) { + val igc = InitialGradientCalculator(mfcn, trafo, theStrategy) + val tmp: FunctionGradient = igc.gradient(st.parameters()) + gst = tmp.getStep().copy() + dirin = tmp.getStep().copy() + g2 = tmp.getGradientDerivative().copy() + } + return try { + val x: RealVector = st.parameters().vec().copy() + for (i in 0 until n) { + val xtf: Double = x.getEntry(i) + val dmin: Double = 8.0 * prec.eps2() * (abs(xtf) + prec.eps2()) + var d: Double = abs(gst.getEntry(i)) + if (d < dmin) { + d = dmin + } + for (icyc in 0 until ncycles()) { + var sag = 0.0 + var fs1 = 0.0 + var fs2 = 0.0 + var multpy = 0 + while (multpy < 5) { + x.setEntry(i, xtf + d) + fs1 = mfcn.value(x) + x.setEntry(i, xtf - d) + fs2 = mfcn.value(x) + x.setEntry(i, xtf) + sag = 0.5 * (fs1 + fs2 - 2.0 * amin) + if (sag > prec.eps2()) { + break + } + if (trafo.parameter(i).hasLimits()) { + if (d > 0.5) { + throw MnHesseFailedException("MnHesse: 2nd derivative zero for parameter") + } + d *= 10.0 + if (d > 0.5) { + d = 0.51 + } + multpy++ + continue + } + d *= 10.0 + multpy++ + } + if (multpy >= 5) { + throw MnHesseFailedException("MnHesse: 2nd derivative zero for parameter") + } + val g2bfor: Double = g2.getEntry(i) + g2.setEntry(i, 2.0 * sag / (d * d)) + grd.setEntry(i, (fs1 - fs2) / (2.0 * d)) + gst.setEntry(i, d) + dirin.setEntry(i, d) + yy.setEntry(i, fs1) + val dlast = d + d = sqrt(2.0 * aimsag / abs(g2.getEntry(i))) + if (trafo.parameter(i).hasLimits()) { + d = min(0.5, d) + } + if (d < dmin) { + d = dmin + } + + // see if converged + if (abs((d - dlast) / d) < tolerstp()) { + break + } + if (abs((g2.getEntry(i) - g2bfor) / g2.getEntry(i)) < tolerg2()) { + break + } + d = min(d, 10.0 * dlast) + d = max(d, 0.1 * dlast) + } + vhmat[i, i] = g2.getEntry(i) + if (mfcn.numOfCalls() - st.nfcn() > maxcalls) { + throw MnHesseFailedException("MnHesse: maximum number of allowed function calls exhausted.") + } + } + if (theStrategy.strategy() > 0) { + // refine first derivative + val hgc = HessianGradientCalculator(mfcn, trafo, theStrategy) + val gr: FunctionGradient = hgc.gradient(st.parameters(), FunctionGradient(grd, g2, gst)) + grd = gr.getGradient() + } + + //off-diagonal elements + for (i in 0 until n) { + x.setEntry(i, x.getEntry(i) + dirin.getEntry(i)) + for (j in i + 1 until n) { + x.setEntry(j, x.getEntry(j) + dirin.getEntry(j)) + val fs1: Double = mfcn.value(x) + val elem: Double = + (fs1 + amin - yy.getEntry(i) - yy.getEntry(j)) / (dirin.getEntry(i) * dirin.getEntry(j)) + vhmat[i, j] = elem + x.setEntry(j, x.getEntry(j) - dirin.getEntry(j)) + } + x.setEntry(i, x.getEntry(i) - dirin.getEntry(i)) + } + + //verify if matrix pos-def (still 2nd derivative) + val tmp: MinimumError = MnPosDef.test(MinimumError(vhmat, 1.0), prec) + vhmat = tmp.invHessian() + try { + vhmat.invert() + } catch (xx: SingularMatrixException) { + throw MnHesseFailedException("MnHesse: matrix inversion fails!") + } + val gr = FunctionGradient(grd, g2, gst) + if (tmp.isMadePosDef()) { + MINUITPlugin.logStatic("MnHesse: matrix is invalid!") + MINUITPlugin.logStatic("MnHesse: matrix is not pos. def.!") + MINUITPlugin.logStatic("MnHesse: matrix was forced pos. def.") + return MinimumState(st.parameters(), + MinimumError(vhmat, MnMadePosDef()), + gr, + st.edm(), + mfcn.numOfCalls()) + } + + //calculate edm + val err = MinimumError(vhmat, 0.0) + val edm: Double = VariableMetricEDMEstimator().estimate(gr, err) + MinimumState(st.parameters(), err, gr, edm, mfcn.numOfCalls()) + } catch (x: MnHesseFailedException) { + MINUITPlugin.logStatic(x.message) + MINUITPlugin.logStatic("MnHesse fails and will return diagonal matrix ") + var j = 0 + while (j < n) { + val tmp = if (g2.getEntry(j) < prec.eps2()) 1.0 else 1.0 / g2.getEntry(j) + vhmat[j, j] = if (tmp < prec.eps2()) 1.0 else tmp + j++ + } + MinimumState(st.parameters(), + MinimumError(vhmat, MnHesseFailed()), + st.gradient(), + st.edm(), + st.nfcn() + mfcn.numOfCalls()) + } + } + + /// forward interface of MnStrategy + fun ncycles(): Int { + return theStrategy.hessianNCycles() + } + + fun tolerg2(): Double { + return theStrategy.hessianG2Tolerance() + } + + fun tolerstp(): Double { + return theStrategy.hessianStepTolerance() + } + + private inner class MnHesseFailedException(message: String?) : java.lang.Exception(message) +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnLineSearch.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnLineSearch.kt new file mode 100644 index 000000000..7b1171d3c --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnLineSearch.kt @@ -0,0 +1,204 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.RealVector +import ru.inr.mass.minuit.* + +/** + * + * @version $Id$ + */ +internal object MnLineSearch { + fun search( + fcn: MnFcn, + st: MinimumParameters, + step: RealVector, + gdel: Double, + prec: MnMachinePrecision + ): MnParabolaPoint { + var overal = 1000.0 + var undral = -100.0 + val toler = 0.05 + var slamin = 0.0 + val slambg = 5.0 + val alpha = 2.0 + val maxiter = 12 + var niter = 0 + for (i in 0 until step.getDimension()) { + if (abs(step.getEntry(i)) < prec.eps()) { + continue + } + val ratio: Double = abs(st.vec().getEntry(i) / step.getEntry(i)) + if (abs(slamin) < prec.eps()) { + slamin = ratio + } + if (ratio < slamin) { + slamin = ratio + } + } + if (abs(slamin) < prec.eps()) { + slamin = prec.eps() + } + slamin *= prec.eps2() + val F0: Double = st.fval() + val F1: Double = fcn.value(MnUtils.add(st.vec(), step)) + var fvmin: Double = st.fval() + var xvmin = 0.0 + if (F1 < F0) { + fvmin = F1 + xvmin = 1.0 + } + var toler8 = toler + var slamax = slambg + var flast = F1 + var slam = 1.0 + var iterate = false + var p0 = MnParabolaPoint(0.0, F0) + var p1 = MnParabolaPoint(slam, flast) + var F2 = 0.0 + do { + // cut toler8 as function goes up + iterate = false + val pb: MnParabola = MnParabolaFactory.create(p0, gdel, p1) + var denom = 2.0 * (flast - F0 - gdel * slam) / (slam * slam) + if (abs(denom) < prec.eps()) { + denom = -0.1 * gdel + slam = 1.0 + } + if (abs(denom) > prec.eps()) { + slam = -gdel / denom + } + if (slam < 0.0) { + slam = slamax + } + if (slam > slamax) { + slam = slamax + } + if (slam < toler8) { + slam = toler8 + } + if (slam < slamin) { + return MnParabolaPoint(xvmin, fvmin) + } + if (abs(slam - 1.0) < toler8 && p1.y() < p0.y()) { + return MnParabolaPoint(xvmin, fvmin) + } + if (abs(slam - 1.0) < toler8) { + slam = 1.0 + toler8 + } + F2 = fcn.value(MnUtils.add(st.vec(), MnUtils.mul(step, slam))) + if (F2 < fvmin) { + fvmin = F2 + xvmin = slam + } + if (p0.y() - prec.eps() < fvmin && fvmin < p0.y() + prec.eps()) { + iterate = true + flast = F2 + toler8 = toler * slam + overal = slam - toler8 + slamax = overal + p1 = MnParabolaPoint(slam, flast) + niter++ + } + } while (iterate && niter < maxiter) + if (niter >= maxiter) { + // exhausted max number of iterations + return MnParabolaPoint(xvmin, fvmin) + } + var p2 = MnParabolaPoint(slam, F2) + do { + slamax = max(slamax, alpha * abs(xvmin)) + val pb: MnParabola = MnParabolaFactory.create(p0, p1, p2) + if (pb.a() < prec.eps2()) { + val slopem: Double = 2.0 * pb.a() * xvmin + pb.b() + slam = if (slopem < 0.0) { + xvmin + slamax + } else { + xvmin - slamax + } + } else { + slam = pb.min() + if (slam > xvmin + slamax) { + slam = xvmin + slamax + } + if (slam < xvmin - slamax) { + slam = xvmin - slamax + } + } + if (slam > 0.0) { + if (slam > overal) { + slam = overal + } + } else { + if (slam < undral) { + slam = undral + } + } + var F3 = 0.0 + do { + iterate = false + val toler9: Double = max(toler8, abs(toler8 * slam)) + // min. of parabola at one point + if (abs(p0.x() - slam) < toler9 || abs(p1.x() - slam) < toler9 || abs( + p2.x() - slam) < toler9 + ) { + return MnParabolaPoint(xvmin, fvmin) + } + F3 = fcn.value(MnUtils.add(st.vec(), MnUtils.mul(step, slam))) + // if latest point worse than all three previous, cut step + if (F3 > p0.y() && F3 > p1.y() && F3 > p2.y()) { + if (slam > xvmin) { + overal = min(overal, slam - toler8) + } + if (slam < xvmin) { + undral = max(undral, slam + toler8) + } + slam = 0.5 * (slam + xvmin) + iterate = true + niter++ + } + } while (iterate && niter < maxiter) + if (niter >= maxiter) { + // exhausted max number of iterations + return MnParabolaPoint(xvmin, fvmin) + } + + // find worst previous point out of three and replace + val p3 = MnParabolaPoint(slam, F3) + if (p0.y() > p1.y() && p0.y() > p2.y()) { + p0 = p3 + } else if (p1.y() > p0.y() && p1.y() > p2.y()) { + p1 = p3 + } else { + p2 = p3 + } + if (F3 < fvmin) { + fvmin = F3 + xvmin = slam + } else { + if (slam > xvmin) { + overal = min(overal, slam - toler8) + } + if (slam < xvmin) { + undral = max(undral, slam + toler8) + } + } + niter++ + } while (niter < maxiter) + return MnParabolaPoint(xvmin, fvmin) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnMachinePrecision.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnMachinePrecision.kt new file mode 100644 index 000000000..161ee0c0a --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnMachinePrecision.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * Determines the relative floating point arithmetic precision. The + * setPrecision() method can be used to override Minuit's own determination, + * when the user knows that the {FCN} function value is not calculated to the + * nominal machine accuracy. + * + * @version $Id$ + * @author Darksnake + */ +class MnMachinePrecision internal constructor() { + private var theEpsMa2 = 0.0 + private var theEpsMac = 0.0 + + /** + * eps returns the smallest possible number so that 1.+eps > 1. + * @return + */ + fun eps(): Double { + return theEpsMac + } + + /** + * eps2 returns 2*sqrt(eps) + * @return + */ + fun eps2(): Double { + return theEpsMa2 + } + + /** + * override Minuit's own determination + * + * @param prec a double. + */ + fun setPrecision(prec: Double) { + theEpsMac = prec + theEpsMa2 = 2.0 * sqrt(theEpsMac) + } + + init { + setPrecision(4.0E-7) + var epstry = 0.5 + val one = 1.0 + for (i in 0..99) { + epstry *= 0.5 + val epsp1 = one + epstry + val epsbak = epsp1 - one + if (epsbak < epstry) { + setPrecision(8.0 * epstry) + break + } + } + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnMigrad.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnMigrad.kt new file mode 100644 index 000000000..22616a1a6 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnMigrad.kt @@ -0,0 +1,136 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction + +/** + * MnMigrad provides minimization of the function by the method of MIGRAD, the + * most efficient and complete single method, recommended for general functions, + * and the functionality for parameters interaction. It also retains the result + * from the last minimization in case the user may want to do subsequent + * minimization steps with parameter interactions in between the minimization + * requests. The minimization produces as a by-product the error matrix of the + * parameters, which is usually reliable unless warning messages are produced. + * + * @version $Id$ + * @author Darksnake + */ +class MnMigrad +/** + * construct from MultiFunction + MnUserParameterState + MnStrategy + * + * @param str a [hep.dataforge.MINUIT.MnStrategy] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameterState] object. + * @param fcn a [MultiFunction] object. + */ + (fcn: MultiFunction?, par: MnUserParameterState, str: MnStrategy) : MnApplication(fcn, par, str) { + private val theMinimizer: VariableMetricMinimizer = VariableMetricMinimizer() + + /** + * construct from MultiFunction + double[] for parameters and errors + * with default strategy + * + * @param err an array of double. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray) : this(fcn, par, err, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and errors + * + * @param stra a int. + * @param err an array of double. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray, stra: Int) : this(fcn, + MnUserParameterState(par, err), + MnStrategy(stra)) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance) : this(fcn, par, cov, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters with default + * strategy + * + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters) : this(fcn, par, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + * + * @param stra a int. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, stra: Int) : this(fcn, + MnUserParameterState(par), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance) : this(fcn, + par, + cov, + DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + override fun minimizer(): ModularFunctionMinimizer { + return theMinimizer + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnMinimize.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnMinimize.kt new file mode 100644 index 000000000..ea14a5453 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnMinimize.kt @@ -0,0 +1,133 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction + +/** + * Causes minimization of the function by the method of MIGRAD, as does the + * MnMigrad class, but switches to the SIMPLEX method if MIGRAD fails to + * converge. Constructor arguments, methods arguments and names of methods are + * the same as for MnMigrad or MnSimplex. + * + * @version $Id$ + * @author Darksnake + */ +class MnMinimize +/** + * construct from MultiFunction + MnUserParameterState + MnStrategy + * + * @param str a [hep.dataforge.MINUIT.MnStrategy] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameterState] object. + * @param fcn a [MultiFunction] object. + */ + (fcn: MultiFunction?, par: MnUserParameterState, str: MnStrategy) : MnApplication(fcn, par, str) { + private val theMinimizer: CombinedMinimizer = CombinedMinimizer() + + /** + * construct from MultiFunction + double[] for parameters and errors + * with default strategy + * + * @param err an array of double. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray) : this(fcn, par, err, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and errors + * + * @param stra a int. + * @param err an array of double. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray, stra: Int) : this(fcn, + MnUserParameterState(par, err), + MnStrategy(stra)) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance) : this(fcn, par, cov, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters with default + * strategy + * + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters) : this(fcn, par, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + * + * @param stra a int. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, stra: Int) : this(fcn, + MnUserParameterState(par), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance) : this(fcn, + par, + cov, + DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + override fun minimizer(): ModularFunctionMinimizer { + return theMinimizer + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnMinos.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnMinos.kt new file mode 100644 index 000000000..d49379b3b --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnMinos.kt @@ -0,0 +1,379 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* +import kotlin.jvm.JvmOverloads + +/** + * API class for Minos error analysis (asymmetric errors). Minimization has to + * be done before and minimum must be valid; possibility to ask only for one + * side of the Minos error; + * + * @version $Id$ + * @author Darksnake + */ +class MnMinos(fcn: MultiFunction?, min: FunctionMinimum?, stra: MnStrategy?) { + private var theFCN: MultiFunction? = null + private var theMinimum: FunctionMinimum? = null + private var theStrategy: MnStrategy? = null + + /** + * construct from FCN + minimum + * + * @param fcn a [MultiFunction] object. + * @param min a [hep.dataforge.MINUIT.FunctionMinimum] object. + */ + constructor(fcn: MultiFunction?, min: FunctionMinimum?) : this(fcn, min, MnApplication.DEFAULT_STRATEGY) + + /** + * construct from FCN + minimum + strategy + * + * @param stra a int. + * @param min a [hep.dataforge.MINUIT.FunctionMinimum] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, min: FunctionMinimum?, stra: Int) : this(fcn, min, MnStrategy(stra)) + // public MnMinos(MultiFunction fcn, MnUserParameterState state, double errDef, MnStrategy stra) { + // theFCN = fcn; + // theStrategy = stra; + // + // MinimumState minState = null; + // + // MnUserTransformation transformation = state.getTransformation(); + // + // MinimumSeed seed = new MinimumSeed(minState, transformation); + // + // theMinimum = new FunctionMinimum(seed,errDef); + // } + /** + * + * loval. + * + * @param par a int. + * @return a [hep.dataforge.MINUIT.MnCross] object. + */ + fun loval(par: Int): MnCross { + return loval(par, 1.0) + } + + /** + * + * loval. + * + * @param par a int. + * @param errDef a double. + * @return a [hep.dataforge.MINUIT.MnCross] object. + */ + fun loval(par: Int, errDef: Double): MnCross { + return loval(par, errDef, MnApplication.DEFAULT_MAXFCN) + } + + /** + * + * loval. + * + * @param par a int. + * @param errDef a double. + * @param maxcalls a int. + * @return a [hep.dataforge.MINUIT.MnCross] object. + */ + fun loval(par: Int, errDef: Double, maxcalls: Int): MnCross { + var errDef = errDef + var maxcalls = maxcalls + errDef *= theMinimum!!.errorDef() + assert(theMinimum!!.isValid()) + assert(!theMinimum!!.userState().parameter(par).isFixed()) + assert(!theMinimum!!.userState().parameter(par).isConst()) + if (maxcalls == 0) { + val nvar: Int = theMinimum!!.userState().variableParameters() + maxcalls = 2 * (nvar + 1) * (200 + 100 * nvar + 5 * nvar * nvar) + } + val para = intArrayOf(par) + val upar: MnUserParameterState = theMinimum!!.userState().copy() + val err: Double = upar.error(par) + val `val`: Double = upar.value(par) - err + val xmid = doubleArrayOf(`val`) + val xdir = doubleArrayOf(-err) + val ind: Int = upar.intOfExt(par) + val m: MnAlgebraicSymMatrix = theMinimum!!.error().matrix() + val xunit: Double = sqrt(errDef / err) + for (i in 0 until m.nrow()) { + if (i == ind) { + continue + } + val xdev: Double = xunit * m[ind, i] + val ext: Int = upar.extOfInt(i) + upar.setValue(ext, upar.value(ext) - xdev) + } + upar.fix(par) + upar.setValue(par, `val`) + val toler = 0.1 + val cross = MnFunctionCross(theFCN, upar, theMinimum!!.fval(), theStrategy, errDef) + val aopt: MnCross = cross.cross(para, xmid, xdir, toler, maxcalls) + if (aopt.atLimit()) { + MINUITPlugin.logStatic("MnMinos parameter $par is at lower limit.") + } + if (aopt.atMaxFcn()) { + MINUITPlugin.logStatic("MnMinos maximum number of function calls exceeded for parameter $par") + } + if (aopt.newMinimum()) { + MINUITPlugin.logStatic("MnMinos new minimum found while looking for parameter $par") + } + if (!aopt.isValid()) { + MINUITPlugin.logStatic("MnMinos could not find lower value for parameter $par.") + } + return aopt + } + /** + * calculate one side (negative or positive error) of the parameter + * + * @param maxcalls a int. + * @param par a int. + * @param errDef a double. + * @return a double. + */ + /** + * + * lower. + * + * @param par a int. + * @param errDef a double. + * @return a double. + */ + /** + * + * lower. + * + * @param par a int. + * @return a double. + */ + @JvmOverloads + fun lower(par: Int, errDef: Double = 1.0, maxcalls: Int = MnApplication.DEFAULT_MAXFCN): Double { + val upar: MnUserParameterState = theMinimum!!.userState() + val err: Double = theMinimum!!.userState().error(par) + val aopt: MnCross = loval(par, errDef, maxcalls) + return if (aopt.isValid()) -1.0 * err * (1.0 + aopt.value()) else if (aopt.atLimit()) upar.parameter(par) + .lowerLimit() else upar.value(par) + } + + /** + * + * minos. + * + * @param par a int. + * @return a [hep.dataforge.MINUIT.MinosError] object. + */ + fun minos(par: Int): MinosError { + return minos(par, 1.0) + } + + /** + * + * minos. + * + * @param par a int. + * @param errDef a double. + * @return a [hep.dataforge.MINUIT.MinosError] object. + */ + fun minos(par: Int, errDef: Double): MinosError { + return minos(par, errDef, MnApplication.DEFAULT_MAXFCN) + } + + /** + * Causes a MINOS error analysis to be performed on the parameter whose + * number is specified. MINOS errors may be expensive to calculate, but are + * very reliable since they take account of non-linearities in the problem + * as well as parameter correlations, and are in general asymmetric. + * + * @param maxcalls Specifies the (approximate) maximum number of function + * calls per parameter requested, after which the calculation will be + * stopped for that parameter. + * @param errDef a double. + * @param par a int. + * @return a [hep.dataforge.MINUIT.MinosError] object. + */ + fun minos(par: Int, errDef: Double, maxcalls: Int): MinosError { + assert(theMinimum!!.isValid()) + assert(!theMinimum!!.userState().parameter(par).isFixed()) + assert(!theMinimum!!.userState().parameter(par).isConst()) + val up: MnCross = upval(par, errDef, maxcalls) + val lo: MnCross = loval(par, errDef, maxcalls) + return MinosError(par, theMinimum!!.userState().value(par), lo, up) + } + + /** + * + * range. + * + * @param par a int. + * @return + */ + fun range(par: Int): Range { + return range(par, 1.0) + } + + /** + * + * range. + * + * @param par a int. + * @param errDef a double. + * @return + */ + fun range(par: Int, errDef: Double): Range { + return range(par, errDef, MnApplication.DEFAULT_MAXFCN) + } + + /** + * Causes a MINOS error analysis for external parameter n. + * + * @param maxcalls a int. + * @param errDef a double. + * @return The lower and upper bounds of parameter + * @param par a int. + */ + fun range(par: Int, errDef: Double, maxcalls: Int): Range { + val mnerr: MinosError = minos(par, errDef, maxcalls) + return mnerr.range() + } + /** + * + * upper. + * + * @param par a int. + * @param errDef a double. + * @param maxcalls a int. + * @return a double. + */ + /** + * + * upper. + * + * @param par a int. + * @param errDef a double. + * @return a double. + */ + /** + * + * upper. + * + * @param par a int. + * @return a double. + */ + @JvmOverloads + fun upper(par: Int, errDef: Double = 1.0, maxcalls: Int = MnApplication.DEFAULT_MAXFCN): Double { + val upar: MnUserParameterState = theMinimum!!.userState() + val err: Double = theMinimum!!.userState().error(par) + val aopt: MnCross = upval(par, errDef, maxcalls) + return if (aopt.isValid()) err * (1.0 + aopt.value()) else if (aopt.atLimit()) upar.parameter(par) + .upperLimit() else upar.value(par) + } + + /** + * + * upval. + * + * @param par a int. + * @return a [hep.dataforge.MINUIT.MnCross] object. + */ + fun upval(par: Int): MnCross { + return upval(par, 1.0) + } + + /** + * + * upval. + * + * @param par a int. + * @param errDef a double. + * @return a [hep.dataforge.MINUIT.MnCross] object. + */ + fun upval(par: Int, errDef: Double): MnCross { + return upval(par, errDef, MnApplication.DEFAULT_MAXFCN) + } + + /** + * + * upval. + * + * @param par a int. + * @param errDef a double. + * @param maxcalls a int. + * @return a [hep.dataforge.MINUIT.MnCross] object. + */ + fun upval(par: Int, errDef: Double, maxcalls: Int): MnCross { + var errDef = errDef + var maxcalls = maxcalls + errDef *= theMinimum!!.errorDef() + assert(theMinimum!!.isValid()) + assert(!theMinimum!!.userState().parameter(par).isFixed()) + assert(!theMinimum!!.userState().parameter(par).isConst()) + if (maxcalls == 0) { + val nvar: Int = theMinimum!!.userState().variableParameters() + maxcalls = 2 * (nvar + 1) * (200 + 100 * nvar + 5 * nvar * nvar) + } + val para = intArrayOf(par) + val upar: MnUserParameterState = theMinimum!!.userState().copy() + val err: Double = upar.error(par) + val `val`: Double = upar.value(par) + err + val xmid = doubleArrayOf(`val`) + val xdir = doubleArrayOf(err) + val ind: Int = upar.intOfExt(par) + val m: MnAlgebraicSymMatrix = theMinimum!!.error().matrix() + val xunit: Double = sqrt(errDef / err) + for (i in 0 until m.nrow()) { + if (i == ind) { + continue + } + val xdev: Double = xunit * m[ind, i] + val ext: Int = upar.extOfInt(i) + upar.setValue(ext, upar.value(ext) + xdev) + } + upar.fix(par) + upar.setValue(par, `val`) + val toler = 0.1 + val cross = MnFunctionCross(theFCN, upar, theMinimum!!.fval(), theStrategy, errDef) + val aopt: MnCross = cross.cross(para, xmid, xdir, toler, maxcalls) + if (aopt.atLimit()) { + MINUITPlugin.logStatic("MnMinos parameter $par is at upper limit.") + } + if (aopt.atMaxFcn()) { + MINUITPlugin.logStatic("MnMinos maximum number of function calls exceeded for parameter $par") + } + if (aopt.newMinimum()) { + MINUITPlugin.logStatic("MnMinos new minimum found while looking for parameter $par") + } + if (!aopt.isValid()) { + MINUITPlugin.logStatic("MnMinos could not find upper value for parameter $par.") + } + return aopt + } + + /** + * construct from FCN + minimum + strategy + * + * @param stra a [hep.dataforge.MINUIT.MnStrategy] object. + * @param min a [hep.dataforge.MINUIT.FunctionMinimum] object. + * @param fcn a [MultiFunction] object. + */ + init { + theFCN = fcn + theMinimum = min + theStrategy = stra + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnParabola.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnParabola.kt new file mode 100644 index 000000000..a0a56dedd --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnParabola.kt @@ -0,0 +1,55 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * parabola = a*xx + b*x + c + * + * @version $Id$ + */ +internal class MnParabola(private val theA: Double, private val theB: Double, private val theC: Double) { + fun a(): Double { + return theA + } + + fun b(): Double { + return theB + } + + fun c(): Double { + return theC + } + + fun min(): Double { + return -theB / (2.0 * theA) + } + + fun x_neg(y: Double): Double { + return -sqrt(y / theA + min() * min() - theC / theA) + min() + } + + fun x_pos(y: Double): Double { + return sqrt(y / theA + min() * min() - theC / theA) + min() + } + + fun y(x: Double): Double { + return theA * x * x + theB * x + theC + } + + fun ymin(): Double { + return -theB * theB / (4.0 * theA) + theC + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnParabolaFactory.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnParabolaFactory.kt new file mode 100644 index 000000000..f45d2b9c9 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnParabolaFactory.kt @@ -0,0 +1,58 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal object MnParabolaFactory { + fun create(p1: MnParabolaPoint, p2: MnParabolaPoint, p3: MnParabolaPoint): MnParabola { + var x1: Double = p1.x() + var x2: Double = p2.x() + var x3: Double = p3.x() + val dx12 = x1 - x2 + val dx13 = x1 - x3 + val dx23 = x2 - x3 + val xm = (x1 + x2 + x3) / 3.0 + x1 -= xm + x2 -= xm + x3 -= xm + val y1: Double = p1.y() + val y2: Double = p2.y() + val y3: Double = p3.y() + val a = y1 / (dx12 * dx13) - y2 / (dx12 * dx23) + y3 / (dx13 * dx23) + var b = -y1 * (x2 + x3) / (dx12 * dx13) + y2 * (x1 + x3) / (dx12 * dx23) - y3 * (x1 + x2) / (dx13 * dx23) + var c = y1 - a * x1 * x1 - b * x1 + c += xm * (xm * a - b) + b -= 2.0 * xm * a + return MnParabola(a, b, c) + } + + fun create(p1: MnParabolaPoint, dxdy1: Double, p2: MnParabolaPoint): MnParabola { + val x1: Double = p1.x() + val xx1 = x1 * x1 + val x2: Double = p2.x() + val xx2 = x2 * x2 + val y1: Double = p1.y() + val y12: Double = p1.y() - p2.y() + val det = xx1 - xx2 - 2.0 * x1 * (x1 - x2) + val a = -(y12 + (x2 - x1) * dxdy1) / det + val b = -(-2.0 * x1 * y12 + (xx1 - xx2) * dxdy1) / det + val c = y1 - a * xx1 - b * x1 + return MnParabola(a, b, c) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnParabolaPoint.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnParabolaPoint.kt new file mode 100644 index 000000000..858e010e6 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnParabolaPoint.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class MnParabolaPoint(private val theX: Double, private val theY: Double) { + fun x(): Double { + return theX + } + + fun y(): Double { + return theY + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnParameterScan.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnParameterScan.kt new file mode 100644 index 000000000..7791c20e8 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnParameterScan.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction + +/** + * Scans the values of FCN as a function of one parameter and retains the best + * function and parameter values found + * + * @version $Id$ + */ +internal class MnParameterScan { + private var theAmin: Double + private var theFCN: MultiFunction? + private var theParameters: MnUserParameters + + constructor(fcn: MultiFunction, par: MnUserParameters) { + theFCN = fcn + theParameters = par + theAmin = fcn.value(par.params()) + } + + constructor(fcn: MultiFunction?, par: MnUserParameters, fval: Double) { + theFCN = fcn + theParameters = par + theAmin = fval + } + + fun fval(): Double { + return theAmin + } + + fun parameters(): MnUserParameters { + return theParameters + } + + fun scan(par: Int): List { + return scan(par, 41) + } + + fun scan(par: Int, maxsteps: Int): List { + return scan(par, maxsteps, 0.0, 0.0) + } + + /** + * returns pairs of (x,y) points, x=parameter value, y=function value of FCN + * @param high + * @return + */ + fun scan(par: Int, maxsteps: Int, low: Double, high: Double): List { + var maxsteps = maxsteps + var low = low + var high = high + if (maxsteps > 101) { + maxsteps = 101 + } + val result: MutableList = java.util.ArrayList(maxsteps + 1) + val params: DoubleArray = theParameters.params() + result.add(Range(params[par], theAmin)) + if (low > high) { + return result + } + if (maxsteps < 2) { + return result + } + if (low == 0.0 && high == 0.0) { + low = params[par] - 2.0 * theParameters.error(par) + high = params[par] + 2.0 * theParameters.error(par) + } + if (low == 0.0 && high == 0.0 && theParameters.parameter(par).hasLimits()) { + if (theParameters.parameter(par).hasLowerLimit()) { + low = theParameters.parameter(par).lowerLimit() + } + if (theParameters.parameter(par).hasUpperLimit()) { + high = theParameters.parameter(par).upperLimit() + } + } + if (theParameters.parameter(par).hasLimits()) { + if (theParameters.parameter(par).hasLowerLimit()) { + low = max(low, theParameters.parameter(par).lowerLimit()) + } + if (theParameters.parameter(par).hasUpperLimit()) { + high = min(high, theParameters.parameter(par).upperLimit()) + } + } + val x0 = low + val stp = (high - low) / (maxsteps - 1.0) + for (i in 0 until maxsteps) { + params[par] = x0 + i.toDouble() * stp + val fval: Double = theFCN.value(params) + if (fval < theAmin) { + theParameters.setValue(par, params[par]) + theAmin = fval + } + result.add(Range(params[par], fval)) + } + return result + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnPlot.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnPlot.kt new file mode 100644 index 000000000..656dd8d35 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnPlot.kt @@ -0,0 +1,438 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import java.lang.StringBuffer +import kotlin.jvm.JvmOverloads + +/** + * MnPlot produces a text-screen graphical output of (x,y) points. E.g. from + * Scan or Contours. + * + * @version $Id$ + * @author Darksnake + */ +class MnPlot @JvmOverloads constructor(private val thePageWidth: Int = 80, private val thePageLength: Int = 30) { + private var bh = 0.0 + private var bl = 0.0 + private var bwid = 0.0 + private var nb = 0 + fun length(): Int { + return thePageLength + } + + private fun mnbins(a1: Double, a2: Double, naa: Int) { + + //*-*-*-*-*-*-*-*-*-*-*Compute reasonable histogram intervals*-*-*-*-*-*-*-*-* + //*-* ====================================== + //*-* Function TO DETERMINE REASONABLE HISTOGRAM INTERVALS + //*-* GIVEN ABSOLUTE UPPER AND LOWER BOUNDS A1 AND A2 + //*-* AND DESIRED MAXIMUM NUMBER OF BINS NAA + //*-* PROGRAM MAKES REASONABLE BINNING FROM BL TO BH OF WIDTH BWID + //*-* F. JAMES, AUGUST, 1974 , stolen for Minuit, 1988 + //*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* + + /* Local variables */ + var awid: Double + var ah: Double + var sigfig: Double + var sigrnd: Double + var alb: Double + var kwid: Int + var lwid: Int + var na = 0 + var log_: Int + val al: Double = if (a1 < a2) a1 else a2 + ah = if (a1 > a2) a1 else a2 + if (al == ah) { + ah = al + 1 + } + + //*-*- IF NAA .EQ. -1 , PROGRAM USES BWID INPUT FROM CALLING ROUTINE + var skip = naa == -1 && bwid > 0 + if (!skip) { + na = naa - 1 + if (na < 1) { + na = 1 + } + } + while (true) { + if (!skip) { + //*-*- GET NOMINAL BIN WIDTH IN EXPON FORM + awid = (ah - al) / na.toDouble() + log_ = log10(awid) + if (awid <= 1) { + --log_ + } + sigfig = awid * pow(10.0, -log_.toDouble()) + //*-*- ROUND MANTISSA UP TO 2, 2.5, 5, OR 10 + if (sigfig <= 2) { + sigrnd = 2.0 + } else if (sigfig <= 2.5) { + sigrnd = 2.5 + } else if (sigfig <= 5) { + sigrnd = 5.0 + } else { + sigrnd = 1.0 + ++log_ + } + bwid = sigrnd * pow(10.0, log_.toDouble()) + } + alb = al / bwid + lwid = alb.toInt() + if (alb < 0) { + --lwid + } + bl = bwid * lwid.toDouble() + alb = ah / bwid + 1 + kwid = alb.toInt() + if (alb < 0) { + --kwid + } + bh = bwid * kwid.toDouble() + nb = kwid - lwid + if (naa <= 5) { + if (naa == -1) { + return + } + //*-*- REQUEST FOR ONE BIN IS DIFFICULT CASE + if (naa > 1 || nb == 1) { + return + } + bwid *= 2.0 + nb = 1 + return + } + if (nb shl 1 != naa) { + return + } + ++na + skip = false + continue + } + } + + private fun mnplot(xpt: DoubleArray, ypt: DoubleArray, chpt: StringBuffer, nxypt: Int, npagwd: Int, npagln: Int) { + //*-*-*-*Plots points in array xypt onto one page with labelled axes*-*-*-*-* + //*-* =========================================================== + //*-* NXYPT is the number of points to be plotted + //*-* XPT(I) = x-coord. of ith point + //*-* YPT(I) = y-coord. of ith point + //*-* CHPT(I) = character to be plotted at this position + //*-* the input point arrays XPT, YPT, CHPT are destroyed. + //*-* + //*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* + + /* Local variables */ + var xmin: Double + var xmax: Double + var ymax: Double + var savx: Double + var savy: Double + var yprt: Double + var xbest: Double + var ybest: Double + val xvalus = DoubleArray(12) + val any: Double + val iten: Int + var j: Int + var k: Int + var maxnx: Int + var maxny: Int + var iquit: Int + var ni: Int + var linodd: Int + var ibk: Int + var isp1: Int + var ks: Int + var ix: Int + var overpr: Boolean + val cline = StringBuffer(npagwd) + for (ii in 0 until npagwd) { + cline.append(' ') + } + var chsav: Char + val chbest: Char + + /* Function Body */ + //*-* Computing MIN + maxnx = if (npagwd - 20 < 100) npagwd - 20 else 100 + if (maxnx < 10) { + maxnx = 10 + } + maxny = npagln + if (maxny < 10) { + maxny = 10 + } + if (nxypt <= 1) { + return + } + xbest = xpt[0] + ybest = ypt[0] + chbest = chpt.get(0) + //*-*- order the points by decreasing y + val km1: Int = nxypt - 1 + var i: Int = 1 + while (i <= km1) { + iquit = 0 + ni = nxypt - i + j = 1 + while (j <= ni) { + if (ypt[j - 1] > ypt[j]) { + ++j + continue + } + savx = xpt[j - 1] + xpt[j - 1] = xpt[j] + xpt[j] = savx + savy = ypt[j - 1] + ypt[j - 1] = ypt[j] + ypt[j] = savy + chsav = chpt.get(j - 1) + chpt.setCharAt(j - 1, chpt.get(j)) + chpt.setCharAt(j, chsav) + iquit = 1 + ++j + } + if (iquit == 0) { + break + } + ++i + } + //*-*- find extreme values + xmax = xpt[0] + xmin = xmax + i = 1 + while (i <= nxypt) { + if (xpt[i - 1] > xmax) { + xmax = xpt[i - 1] + } + if (xpt[i - 1] < xmin) { + xmin = xpt[i - 1] + } + ++i + } + val dxx: Double = (xmax - xmin) * .001 + xmax += dxx + xmin -= dxx + mnbins(xmin, xmax, maxnx) + xmin = bl + xmax = bh + var nx: Int = nb + val bwidx: Double = bwid + ymax = ypt[0] + var ymin: Double = ypt[nxypt - 1] + if (ymax == ymin) { + ymax = ymin + 1 + } + val dyy: Double = (ymax - ymin) * .001 + ymax += dyy + ymin -= dyy + mnbins(ymin, ymax, maxny) + ymin = bl + ymax = bh + var ny: Int = nb + val bwidy: Double = bwid + any = ny.toDouble() + //*-*- if first point is blank, it is an 'origin' + if (chbest != ' ') { + xbest = (xmax + xmin) * .5 + ybest = (ymax + ymin) * .5 + } + //*-*- find scale constants + val ax: Double = 1 / bwidx + val ay: Double = 1 / bwidy + val bx: Double = -ax * xmin + 2 + val by: Double = -ay * ymin - 2 + //*-*- convert points to grid positions + i = 1 + while (i <= nxypt) { + xpt[i - 1] = ax * xpt[i - 1] + bx + ypt[i - 1] = any - ay * ypt[i - 1] - by + ++i + } + val nxbest: Int = (ax * xbest + bx).toInt() + val nybest: Int = (any - ay * ybest - by).toInt() + //*-*- print the points + ny += 2 + nx += 2 + isp1 = 1 + linodd = 1 + overpr = false + i = 1 + while (i <= ny) { + ibk = 1 + while (ibk <= nx) { + cline.setCharAt(ibk - 1, ' ') + ++ibk + } + // cline.setCharAt(nx,'\0'); + // cline.setCharAt(nx+1,'\0'); + cline.setCharAt(0, '.') + cline.setCharAt(nx - 1, '.') + cline.setCharAt(nxbest - 1, '.') + if (i == 1 || i == nybest || i == ny) { + j = 1 + while (j <= nx) { + cline.setCharAt(j - 1, '.') + ++j + } + } + yprt = ymax - (i - 1.0) * bwidy + var isplset = false + if (isp1 <= nxypt) { + //*-*- find the points to be plotted on this line + k = isp1 + while (k <= nxypt) { + ks = ypt[k - 1].toInt() + if (ks > i) { + isp1 = k + isplset = true + break + } + ix = xpt[k - 1].toInt() + if (cline.get(ix - 1) != '.' && cline.get(ix - 1) != ' ') { + if (cline.get(ix - 1) == chpt.get(k - 1)) { + ++k + continue + } + overpr = true + //*-*- OVERPR is true if one or more positions contains more than + //*-*- one point + cline.setCharAt(ix - 1, '&') + ++k + continue + } + cline.setCharAt(ix - 1, chpt.get(k - 1)) + ++k + } + if (!isplset) { + isp1 = nxypt + 1 + } + } + if (linodd != 1 && i != ny) { + linodd = 1 + java.lang.System.out.printf(" %s", cline.substring(0, 60)) + } else { + java.lang.System.out.printf(" %14.7g ..%s", yprt, cline.substring(0, 60)) + linodd = 0 + } + println() + ++i + } + //*-*- print labels on x-axis every ten columns + ibk = 1 + while (ibk <= nx) { + cline.setCharAt(ibk - 1, ' ') + if (ibk % 10 == 1) { + cline.setCharAt(ibk - 1, '/') + } + ++ibk + } + java.lang.System.out.printf(" %s", cline) + java.lang.System.out.printf("\n") + ibk = 1 + while (ibk <= 12) { + xvalus[ibk - 1] = xmin + (ibk - 1.0) * 10 * bwidx + ++ibk + } + java.lang.System.out.printf(" ") + iten = (nx + 9) / 10 + ibk = 1 + while (ibk <= iten) { + java.lang.System.out.printf(" %9.4g", xvalus[ibk - 1]) + ++ibk + } + java.lang.System.out.printf("\n") + if (overpr) { + val chmess = " Overprint character is &" + java.lang.System.out.printf(" ONE COLUMN=%13.7g%s", bwidx, chmess) + } else { + val chmess = " " + java.lang.System.out.printf(" ONE COLUMN=%13.7g%s", bwidx, chmess) + } + println() + } + + /** + * + * plot. + * + * @param points a [List] object. + */ + fun plot(points: List) { + val x = DoubleArray(points.size) + val y = DoubleArray(points.size) + val chpt = StringBuffer(points.size) + for ((i, ipoint) in points.withIndex()) { + x[i] = ipoint.getFirst() + y[i] = ipoint.getSecond() + chpt.append('*') + } + mnplot(x, y, chpt, points.size, width(), length()) + } + + /** + * + * plot. + * + * @param xmin a double. + * @param ymin a double. + * @param points a [List] object. + */ + fun plot(xmin: Double, ymin: Double, points: List) { + val x = DoubleArray(points.size + 2) + x[0] = xmin + x[1] = xmin + val y = DoubleArray(points.size + 2) + y[0] = ymin + y[1] = ymin + val chpt = StringBuffer(points.size + 2) + chpt.append(' ') + chpt.append('X') + var i = 2 + for (ipoint in points) { + x[i] = ipoint.getFirst() + y[i] = ipoint.getSecond() + chpt.append('*') + i++ + } + mnplot(x, y, chpt, points.size + 2, width(), length()) + } + + fun width(): Int { + return thePageWidth + } + /** + * + * Constructor for MnPlot. + * + * @param thePageWidth a int. + * @param thePageLength a int. + */ + /** + * + * Constructor for MnPlot. + */ + init { + if (thePageWidth > 120) { + thePageWidth = 120 + } + if (thePageLength > 56) { + thePageLength = 56 + } + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnPosDef.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnPosDef.kt new file mode 100644 index 000000000..f94e387d9 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnPosDef.kt @@ -0,0 +1,89 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MINUITPlugin + +/** + * + * @version $Id$ + */ +internal object MnPosDef { + fun test(st: MinimumState, prec: MnMachinePrecision): MinimumState { + val err: MinimumError = test(st.error(), prec) + return MinimumState(st.parameters(), err, st.gradient(), st.edm(), st.nfcn()) + } + + fun test(e: MinimumError, prec: MnMachinePrecision): MinimumError { + val err: MnAlgebraicSymMatrix = e.invHessian().copy() + if (err.size() === 1 && err[0, 0] < prec.eps()) { + err[0, 0] = 1.0 + return MinimumError(err, MnMadePosDef()) + } + if (err.size() === 1 && err[0, 0] > prec.eps()) { + return e + } + // std::cout<<"MnPosDef init matrix= "< 0.0) { + os.printf(" limited || %10g", ipar.value()) + if (abs(ipar.value() - ipar.lowerLimit()) < par.precision().eps2()) { + os.print("* ") + atLoLim = true + } + if (abs(ipar.value() - ipar.upperLimit()) < par.precision().eps2()) { + os.print("**") + atHiLim = true + } + os.printf(" || %10g\n", ipar.error()) + } else { + os.printf(" free || %10g || no\n", ipar.value()) + } + } else { + if (ipar.error() > 0.0) { + os.printf(" free || %10g || %10g\n", ipar.value(), ipar.error()) + } else { + os.printf(" free || %10g || no\n", ipar.value()) + } + } + } + os.println() + if (atLoLim) { + os.print("* parameter is at lower limit") + } + if (atHiLim) { + os.print("** parameter is at upper limit") + } + os.println() + } + + /** + * + * print. + * + * @param os a [PrintWriter] object. + * @param matrix a [hep.dataforge.MINUIT.MnUserCovariance] object. + */ + fun print(os: PrintWriter, matrix: MnUserCovariance) { + os.println() + os.println("MnUserCovariance: ") + run { + os.println() + val n: Int = matrix.nrow() + for (i in 0 until n) { + for (j in 0 until n) { + os.printf("%10g ", matrix[i, j]) + } + os.println() + } + } + os.println() + os.println("MnUserCovariance parameter correlations: ") + run { + os.println() + val n: Int = matrix.nrow() + for (i in 0 until n) { + val di: Double = matrix[i, i] + for (j in 0 until n) { + val dj: Double = matrix[j, j] + os.printf("%g ", matrix[i, j] / sqrt(abs(di * dj))) + } + os.println() + } + } + } + + /** + * + * print. + * + * @param os a [PrintWriter] object. + * @param coeff a [hep.dataforge.MINUIT.MnGlobalCorrelationCoeff] object. + */ + fun print(os: PrintWriter, coeff: MnGlobalCorrelationCoeff) { + os.println() + os.println("MnGlobalCorrelationCoeff: ") + run { + os.println() + for (i in 0 until coeff.globalCC().length) { + os.printf("%g\n", coeff.globalCC()[i]) + } + } + } + + /** + * + * print. + * + * @param os a [PrintWriter] object. + * @param state a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun print(os: PrintWriter, state: MnUserParameterState) { + os.println() + if (!state.isValid()) { + os.println() + os.println("WARNING: MnUserParameterState is not valid.") + os.println() + } + os.println("# of function calls: " + state.nfcn()) + os.println("function value: " + state.fval()) + os.println("expected distance to the minimum (edm): " + state.edm()) + os.println("external parameters: " + state.parameters()) + if (state.hasCovariance()) { + os.println("covariance matrix: " + state.covariance()) + } + if (state.hasGlobalCC()) { + os.println("global correlation coefficients : " + state.globalCC()) + } + if (!state.isValid()) { + os.println("WARNING: MnUserParameterState is not valid.") + } + os.println() + } + + /** + * + * print. + * + * @param os a [PrintWriter] object. + * @param me a [hep.dataforge.MINUIT.MinosError] object. + */ + fun print(os: PrintWriter, me: MinosError) { + os.println() + os.printf("Minos # of function calls: %d\n", me.nfcn()) + if (!me.isValid()) { + os.println("Minos error is not valid.") + } + if (!me.lowerValid()) { + os.println("lower Minos error is not valid.") + } + if (!me.upperValid()) { + os.println("upper Minos error is not valid.") + } + if (me.atLowerLimit()) { + os.println("Minos error is lower limit of parameter " + me.parameter()) + } + if (me.atUpperLimit()) { + os.println("Minos error is upper limit of parameter " + me.parameter()) + } + if (me.atLowerMaxFcn()) { + os.println("Minos number of function calls for lower error exhausted.") + } + if (me.atUpperMaxFcn()) { + os.println("Minos number of function calls for upper error exhausted.") + } + if (me.lowerNewMin()) { + os.println("Minos found a new minimum in negative direction.") + os.println(me.lowerState()) + } + if (me.upperNewMin()) { + os.println("Minos found a new minimum in positive direction.") + os.println(me.upperState()) + } + os.println("# ext. || name || value@min || negative || positive ") + os.printf("%4d||%10s||%10g||%10g||%10g\n", + me.parameter(), + me.lowerState().name(me.parameter()), + me.min(), + me.lower(), + me.upper()) + os.println() + } + + /** + * + * print. + * + * @param os a [PrintWriter] object. + * @param ce a [hep.dataforge.MINUIT.ContoursError] object. + */ + fun print(os: PrintWriter, ce: ContoursError) { + os.println() + os.println("Contours # of function calls: " + ce.nfcn()) + os.println("MinosError in x: ") + os.println(ce.xMinosError()) + os.println("MinosError in y: ") + os.println(ce.yMinosError()) + val plot = MnPlot() + plot.plot(ce.xmin(), ce.ymin(), ce.points()) + for ((i, ipoint) in ce.points().withIndex()) { + os.printf("%d %10g %10g\n", i, ipoint.getFirst(), ipoint.getSecond()) + } + os.println() + } + + fun toString(x: RealVector): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(x: MnAlgebraicSymMatrix?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(min: FunctionMinimum?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, min) } + return writer.toString() + } + + fun toString(x: MinimumState?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(x: MnUserParameters?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(x: MnUserCovariance?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(x: MnGlobalCorrelationCoeff?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(x: MnUserParameterState?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(x: MinosError?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } + + fun toString(x: ContoursError?): String { + val writer: java.io.StringWriter = java.io.StringWriter() + PrintWriter(writer).use { pw -> print(pw, x) } + return writer.toString() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnScan.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnScan.kt new file mode 100644 index 000000000..63e565b4f --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnScan.kt @@ -0,0 +1,181 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* + +/** + * MnScan scans the value of the user function by varying one parameter. It is + * sometimes useful for debugging the user function or finding a reasonable + * starting point. + * construct from MultiFunction + MnUserParameterState + MnStrategy + * + * @param str a [hep.dataforge.MINUIT.MnStrategy] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameterState] object. + * @param fcn a [MultiFunction] object. + * @version $Id$ + * @author Darksnake + */ +class MnScan(fcn: MultiFunction?, par: MnUserParameterState, str: MnStrategy) : MnApplication(fcn, par, str) { + private val theMinimizer: ScanMinimizer = ScanMinimizer() + + /** + * construct from MultiFunction + double[] for parameters and errors + * with default strategy + * + * @param err an array of double. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray) : this(fcn, par, err, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and errors + * + * @param stra a int. + * @param err an array of double. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray, stra: Int) : this(fcn, + MnUserParameterState(par, err), + MnStrategy(stra)) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance) : this(fcn, par, cov, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters with default + * strategy + * + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters) : this(fcn, par, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + * + * @param stra a int. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, stra: Int) : this(fcn, + MnUserParameterState(par), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance) : this(fcn, + par, + cov, + DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + override fun minimizer(): ModularFunctionMinimizer { + return theMinimizer + } + + /** + * + * scan. + * + * @param par a int. + * @return a [List] object. + */ + fun scan(par: Int): List { + return scan(par, 41) + } + + /** + * + * scan. + * + * @param par a int. + * @param maxsteps a int. + * @return a [List] object. + */ + fun scan(par: Int, maxsteps: Int): List { + return scan(par, maxsteps, 0.0, 0.0) + } + + /** + * Scans the value of the user function by varying parameter number par, + * leaving all other parameters fixed at the current value. If par is not + * specified, all variable parameters are scanned in sequence. The number of + * points npoints in the scan is 40 by default, and cannot exceed 100. The + * range of the scan is by default 2 standard deviations on each side of the + * current best value, but can be specified as from low to high. After each + * scan, if a new minimum is found, the best parameter values are retained + * as start values for future scans or minimizations. The curve resulting + * from each scan can be plotted on the output terminal using MnPlot in + * order to show the approximate behaviour of the function. + * + * @param high a double. + * @param par a int. + * @param maxsteps a int. + * @param low a double. + * @return a [List] object. + */ + fun scan(par: Int, maxsteps: Int, low: Double, high: Double): List { + val scan = MnParameterScan(theFCN, theState.parameters()) + var amin: Double = scan.fval() + val result: List = scan.scan(par, maxsteps, low, high) + if (scan.fval() < amin) { + theState.setValue(par, scan.parameters().value(par)) + amin = scan.fval() + } + return result + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnSeedGenerator.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnSeedGenerator.kt new file mode 100644 index 000000000..a42edf4f1 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnSeedGenerator.kt @@ -0,0 +1,108 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MINUITPlugin +import ru.inr.mass.minuit.* +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * + * @version $Id$ + */ +internal class MnSeedGenerator : MinimumSeedGenerator { + /** {@inheritDoc} */ + fun generate(fcn: MnFcn, gc: GradientCalculator, st: MnUserParameterState, stra: MnStrategy): MinimumSeed { + val n: Int = st.variableParameters() + val prec: MnMachinePrecision = st.precision() + + // initial starting values + val x: RealVector = ArrayRealVector(n) + for (i in 0 until n) { + x.setEntry(i, st.intParameters()[i]) + } + val fcnmin: Double = fcn.value(x) + val pa = MinimumParameters(x, fcnmin) + val dgrad: FunctionGradient + if (gc is AnalyticalGradientCalculator) { + val igc = InitialGradientCalculator(fcn, st.getTransformation(), stra) + val tmp: FunctionGradient = igc.gradient(pa) + val grd: FunctionGradient = gc.gradient(pa) + dgrad = FunctionGradient(grd.getGradient(), tmp.getGradientDerivative(), tmp.getStep()) + if (gc.checkGradient()) { + val good = true + val hgc = HessianGradientCalculator(fcn, st.getTransformation(), MnStrategy(2)) + val hgrd: Pair = hgc.deltaGradient(pa, dgrad) + for (i in 0 until n) { + val provided: Double = grd.getGradient().getEntry(i) + val calculated: Double = hgrd.getFirst().getGradient().getEntry(i) + val delta: Double = hgrd.getSecond().getEntry(i) + if (abs(calculated - provided) > delta) { + MINUITPlugin.logStatic("" + + "gradient discrepancy of external parameter \"%d\" " + + "(internal parameter \"%d\") too large. Expected: \"%f\", provided: \"%f\"", + st.getTransformation().extOfInt(i), i, provided, calculated) + +// +// MINUITPlugin.logStatic("gradient discrepancy of external parameter " +// + st.getTransformation().extOfInt(i) +// + " (internal parameter " + i + ") too large."); +// good = false; + } + } + if (!good) { + MINUITPlugin.logStatic("Minuit does not accept user specified gradient.") + // assert(good); + } + } + } else { + dgrad = gc.gradient(pa) + } + val mat = MnAlgebraicSymMatrix(n) + var dcovar = 1.0 + if (st.hasCovariance()) { + for (i in 0 until n) { + for (j in i until n) { + mat[i, j] = st.intCovariance()[i, j] + } + } + dcovar = 0.0 + } else { + for (i in 0 until n) { + mat[i, i] = if (abs(dgrad.getGradientDerivative() + .getEntry(i)) > prec.eps2() + ) 1.0 / dgrad.getGradientDerivative().getEntry(i) else 1.0 + } + } + val err = MinimumError(mat, dcovar) + val edm: Double = VariableMetricEDMEstimator().estimate(dgrad, err) + var state = MinimumState(pa, err, dgrad, edm, fcn.numOfCalls()) + if (NegativeG2LineSearch.hasNegativeG2(dgrad, prec)) { + state = if (gc is AnalyticalGradientCalculator) { + val ngc = Numerical2PGradientCalculator(fcn, st.getTransformation(), stra) + NegativeG2LineSearch.search(fcn, state, ngc, prec) + } else { + NegativeG2LineSearch.search(fcn, state, gc, prec) + } + } + if (stra.strategy() === 2 && !st.hasCovariance()) { + //calculate full 2nd derivative + val tmp: MinimumState = MnHesse(stra).calculate(fcn, state, st.getTransformation(), 0) + return MinimumSeed(tmp, st.getTransformation()) + } + return MinimumSeed(state, st.getTransformation()) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnSimplex.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnSimplex.kt new file mode 100644 index 000000000..b00745f26 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnSimplex.kt @@ -0,0 +1,138 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* + +/** + * SIMPLEX is a function minimization method using the simplex method of Nelder + * and Mead. MnSimplex provides minimization of the function by the method of + * SIMPLEX and the functionality for parameters interaction. It also retains the + * result from the last minimization in case the user may want to do subsequent + * minimization steps with parameter interactions in between the minimization + * requests. As SIMPLEX is a stepping method it does not produce a covariance + * matrix. + * + * @version $Id$ + * @author Darksnake + */ +class MnSimplex +/** + * construct from MultiFunction + MnUserParameterState + MnStrategy + * + * @param str a [hep.dataforge.MINUIT.MnStrategy] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameterState] object. + * @param fcn a [MultiFunction] object. + */ + (fcn: MultiFunction?, par: MnUserParameterState, str: MnStrategy) : MnApplication(fcn, par, str) { + private val theMinimizer: SimplexMinimizer = SimplexMinimizer() + + /** + * construct from MultiFunction + double[] for parameters and errors + * with default strategy + * + * @param err an array of double. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray) : this(fcn, par, err, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and errors + * + * @param stra a int. + * @param err an array of double. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, err: DoubleArray, stra: Int) : this(fcn, + MnUserParameterState(par, err), + MnStrategy(stra)) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par an array of double. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance) : this(fcn, par, cov, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + double[] for parameters and + * MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par an array of double. + */ + constructor(fcn: MultiFunction?, par: DoubleArray, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters with default + * strategy + * + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters) : this(fcn, par, DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + * + * @param stra a int. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, stra: Int) : this(fcn, + MnUserParameterState(par), + MnStrategy(stra)) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * with default strategy + * + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + * @param fcn a [MultiFunction] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance) : this(fcn, + par, + cov, + DEFAULT_STRATEGY) + + /** + * construct from MultiFunction + MnUserParameters + MnUserCovariance + * + * @param stra a int. + * @param cov a [hep.dataforge.MINUIT.MnUserCovariance] object. + * @param fcn a [MultiFunction] object. + * @param par a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + constructor(fcn: MultiFunction?, par: MnUserParameters, cov: MnUserCovariance, stra: Int) : this(fcn, + MnUserParameterState(par, cov), + MnStrategy(stra)) + + /** {@inheritDoc} */ + override fun minimizer(): ModularFunctionMinimizer { + return theMinimizer + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnStrategy.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnStrategy.kt new file mode 100644 index 000000000..31b894665 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnStrategy.kt @@ -0,0 +1,310 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * API class for defining three levels of strategies: low (0), medium (1), high + * (2). + * + * + * At many places in the analysis of the FCN (the user provided function), + * MINUIT must decide whether to be safe and waste a few function calls + * in order to know where it is, or to be fast and attempt to get the + * requested results with the fewest possible calls at a certain risk of not + * obtaining the precision desired by the user. In order to allow the user to + * infuence these decisions, the MnStrategy class allows the user to control + * different settings. MnStrategy can be instantiated with three different + * minimization quality levels for low (0), medium (1) and high (2) quality. + * Default settings for iteration cycles and tolerances are initialized then. + * + * + * The default setting is set for medium quality. Value 0 (low) indicates to + * MINUIT that it should economize function calls; it is intended for cases + * where there are many variable parameters and/or the function takes a long + * time to calculate and/or the user is not interested in very precise values + * for parameter errors. On the other hand, value 2 (high) indicates that MINUIT + * is allowed to waste function calls in order to be sure that all values are + * precise; it is it is intended for cases where the function is evaluated in a + * relatively short time and/or where the parameter errors must be calculated + * reliably. + * + * In addition all constants set in MnStrategy can be changed individually by + * the user, e.g. the number of iteration cycles in the numerical gradient. + * + * + * + * + * Acts on: Migrad (behavioural), Minos (lowers strategy by 1 for Minos-own + * minimization), Hesse (iterations), Numerical2PDerivative (iterations) + * + * @author Darksnake + * @version $Id$ + */ +class MnStrategy { + private var theGradNCyc = 0 + private var theGradTlr = 0.0 + private var theGradTlrStp = 0.0 + private var theHessGradNCyc = 0 + + //default strategy + private var theHessNCyc = 0 + private var theHessTlrG2 = 0.0 + private var theHessTlrStp = 0.0 + private var theStrategy = 0 + + /** + * Creates a MnStrategy object with the default strategy (medium) + */ + constructor() { + setMediumStrategy() + } + //user defined strategy (0, 1, >=2) + /** + * Creates a MnStrategy object with the user specified strategy. + * + * @param stra The use defined strategy, 0=low, 1 medium, 2=high. + */ + constructor(stra: Int) { + if (stra == 0) { + setLowStrategy() + } else if (stra == 1) { + setMediumStrategy() + } else { + setHighStrategy() + } + } + + /** + * + * gradientNCycles. + * + * @return a int. + */ + fun gradientNCycles(): Int { + return theGradNCyc + } + + /** + * + * gradientStepTolerance. + * + * @return a double. + */ + fun gradientStepTolerance(): Double { + return theGradTlrStp + } + + /** + * + * gradientTolerance. + * + * @return a double. + */ + fun gradientTolerance(): Double { + return theGradTlr + } + + /** + * + * hessianG2Tolerance. + * + * @return a double. + */ + fun hessianG2Tolerance(): Double { + return theHessTlrG2 + } + + /** + * + * hessianGradientNCycles. + * + * @return a int. + */ + fun hessianGradientNCycles(): Int { + return theHessGradNCyc + } + + /** + * + * hessianNCycles. + * + * @return a int. + */ + fun hessianNCycles(): Int { + return theHessNCyc + } + + /** + * + * hessianStepTolerance. + * + * @return a double. + */ + fun hessianStepTolerance(): Double { + return theHessTlrStp + } + + /** + * + * isHigh. + * + * @return a boolean. + */ + fun isHigh(): Boolean { + return theStrategy >= 2 + } + + /** + * + * isLow. + * + * @return a boolean. + */ + fun isLow(): Boolean { + return theStrategy <= 0 + } + + /** + * + * isMedium. + * + * @return a boolean. + */ + fun isMedium(): Boolean { + return theStrategy == 1 + } + + /** + * + * setGradientNCycles. + * + * @param n a int. + */ + fun setGradientNCycles(n: Int) { + theGradNCyc = n + } + + /** + * + * setGradientStepTolerance. + * + * @param stp a double. + */ + fun setGradientStepTolerance(stp: Double) { + theGradTlrStp = stp + } + + /** + * + * setGradientTolerance. + * + * @param toler a double. + */ + fun setGradientTolerance(toler: Double) { + theGradTlr = toler + } + + /** + * + * setHessianG2Tolerance. + * + * @param toler a double. + */ + fun setHessianG2Tolerance(toler: Double) { + theHessTlrG2 = toler + } + + /** + * + * setHessianGradientNCycles. + * + * @param n a int. + */ + fun setHessianGradientNCycles(n: Int) { + theHessGradNCyc = n + } + + /** + * + * setHessianNCycles. + * + * @param n a int. + */ + fun setHessianNCycles(n: Int) { + theHessNCyc = n + } + + /** + * + * setHessianStepTolerance. + * + * @param stp a double. + */ + fun setHessianStepTolerance(stp: Double) { + theHessTlrStp = stp + } + + fun setHighStrategy() { + theStrategy = 2 + setGradientNCycles(5) + setGradientStepTolerance(0.1) + setGradientTolerance(0.02) + setHessianNCycles(7) + setHessianStepTolerance(0.1) + setHessianG2Tolerance(0.02) + setHessianGradientNCycles(6) + } + + /** + * + * setLowStrategy. + */ + fun setLowStrategy() { + theStrategy = 0 + setGradientNCycles(2) + setGradientStepTolerance(0.5) + setGradientTolerance(0.1) + setHessianNCycles(3) + setHessianStepTolerance(0.5) + setHessianG2Tolerance(0.1) + setHessianGradientNCycles(1) + } + + /** + * + * setMediumStrategy. + */ + fun setMediumStrategy() { + theStrategy = 1 + setGradientNCycles(3) + setGradientStepTolerance(0.3) + setGradientTolerance(0.05) + setHessianNCycles(5) + setHessianStepTolerance(0.3) + setHessianG2Tolerance(0.05) + setHessianGradientNCycles(2) + } + + /** + * + * strategy. + * + * @return a int. + */ + fun strategy(): Int { + return theStrategy + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnUserCovariance.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnUserCovariance.kt new file mode 100644 index 000000000..297588f8e --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnUserCovariance.kt @@ -0,0 +1,147 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * MnUserCovariance is the external covariance matrix designed for the + * interaction of the user. The result of the minimization (internal covariance + * matrix) is converted into the user representable format. It can also be used + * as input prior to the minimization. The size of the covariance matrix is + * according to the number of variable parameters (free and limited). + * + * @version $Id$ + * @author Darksnake + */ +class MnUserCovariance { + private var theData: DoubleArray + private var theNRow: Int + + private constructor(other: MnUserCovariance) { + theData = other.theData.clone() + theNRow = other.theNRow + } + + internal constructor() { + theData = DoubleArray(0) + theNRow = 0 + } + + /* + * covariance matrix is stored in upper triangular packed storage format, + * e.g. the elements in the array are arranged like + * {a(0,0), a(0,1), a(1,1), a(0,2), a(1,2), a(2,2), ...}, + * the size is nrow*(nrow+1)/2. + */ + internal constructor(data: DoubleArray, nrow: Int) { + require(data.size == nrow * (nrow + 1) / 2) { "Inconsistent arguments" } + theData = data + theNRow = nrow + } + + /** + * + * Constructor for MnUserCovariance. + * + * @param nrow a int. + */ + constructor(nrow: Int) { + theData = DoubleArray(nrow * (nrow + 1) / 2) + theNRow = nrow + } + + /** + * + * copy. + * + * @return a [hep.dataforge.MINUIT.MnUserCovariance] object. + */ + fun copy(): MnUserCovariance { + return MnUserCovariance(this) + } + + fun data(): DoubleArray { + return theData + } + + /** + * + * get. + * + * @param row a int. + * @param col a int. + * @return a double. + */ + operator fun get(row: Int, col: Int): Double { + require(!(row >= theNRow || col >= theNRow)) + return if (row > col) { + theData[col + row * (row + 1) / 2] + } else { + theData[row + col * (col + 1) / 2] + } + } + + /** + * + * ncol. + * + * @return a int. + */ + fun ncol(): Int { + return theNRow + } + + /** + * + * nrow. + * + * @return a int. + */ + fun nrow(): Int { + return theNRow + } + + fun scale(f: Double) { + for (i in theData.indices) { + theData[i] *= f + } + } + + /** + * + * set. + * + * @param row a int. + * @param col a int. + * @param value a double. + */ + operator fun set(row: Int, col: Int, value: Double) { + require(!(row >= theNRow || col >= theNRow)) + if (row > col) { + theData[col + row * (row + 1) / 2] = value + } else { + theData[row + col * (col + 1) / 2] = value + } + } + + fun size(): Int { + return theData.size + } + + /** {@inheritDoc} */ + override fun toString(): String { + return MnPrint.toString(this) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnUserFcn.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnUserFcn.kt new file mode 100644 index 000000000..8198a41ab --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnUserFcn.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction + +/** + * + * @version $Id$ + */ +internal class MnUserFcn(fcn: MultiFunction?, errDef: Double, trafo: MnUserTransformation) : MnFcn(fcn, errDef) { + private val theTransform: MnUserTransformation = trafo + override fun value(v: RealVector): Double { + return super.value(theTransform.transform(v)) + } + +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnUserParameterState.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnUserParameterState.kt new file mode 100644 index 000000000..e80dd60a1 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnUserParameterState.kt @@ -0,0 +1,756 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.minuit.* + +/** + * The class MnUserParameterState contains the MnUserParameters and the + * MnUserCovariance. It can be created on input by the user, or by MINUIT itself + * as user representable format of the result of the minimization. + * + * @version $Id$ + * @author Darksnake + */ +class MnUserParameterState { + private var theCovariance: MnUserCovariance + private var theCovarianceValid = false + private var theEDM = 0.0 + private var theFVal = 0.0 + private var theGCCValid = false + private var theGlobalCC: MnGlobalCorrelationCoeff? = null + private var theIntCovariance: MnUserCovariance + private var theIntParameters: MutableList + private var theNFcn = 0 + private var theParameters: MnUserParameters + private var theValid: Boolean + + internal constructor() { + theValid = false + theCovarianceValid = false + theParameters = MnUserParameters() + theCovariance = MnUserCovariance() + theIntParameters = java.util.ArrayList() + theIntCovariance = MnUserCovariance() + } + + private constructor(other: MnUserParameterState) { + theValid = other.theValid + theCovarianceValid = other.theCovarianceValid + theGCCValid = other.theGCCValid + theFVal = other.theFVal + theEDM = other.theEDM + theNFcn = other.theNFcn + theParameters = other.theParameters.copy() + theCovariance = other.theCovariance + theGlobalCC = other.theGlobalCC + theIntParameters = java.util.ArrayList(other.theIntParameters) + theIntCovariance = other.theIntCovariance.copy() + } + + /** + * construct from user parameters (before minimization) + * @param par + * @param err + */ + internal constructor(par: DoubleArray, err: DoubleArray) { + theValid = true + theParameters = MnUserParameters(par, err) + theCovariance = MnUserCovariance() + theGlobalCC = MnGlobalCorrelationCoeff() + theIntParameters = java.util.ArrayList(par.size) + for (i in par.indices) { + theIntParameters.add(par[i]) + } + theIntCovariance = MnUserCovariance() + } + + internal constructor(par: MnUserParameters) { + theValid = true + theParameters = par + theCovariance = MnUserCovariance() + theGlobalCC = MnGlobalCorrelationCoeff() + theIntParameters = java.util.ArrayList(par.variableParameters()) + theIntCovariance = MnUserCovariance() + val i = 0 + for (ipar in par.parameters()) { + if (ipar.isConst() || ipar.isFixed()) { + continue + } + if (ipar.hasLimits()) { + theIntParameters.add(ext2int(ipar.number(), ipar.value())) + } else { + theIntParameters.add(ipar.value()) + } + } + } + + /** + * construct from user parameters + covariance (before minimization) + * @param nrow + * @param cov + */ + internal constructor(par: DoubleArray, cov: DoubleArray, nrow: Int) { + theValid = true + theCovarianceValid = true + theCovariance = MnUserCovariance(cov, nrow) + theGlobalCC = MnGlobalCorrelationCoeff() + theIntParameters = java.util.ArrayList(par.size) + theIntCovariance = MnUserCovariance(cov, nrow) + val err = DoubleArray(par.size) + for (i in par.indices) { + assert(theCovariance[i, i] > 0.0) + err[i] = sqrt(theCovariance[i, i]) + theIntParameters.add(par[i]) + } + theParameters = MnUserParameters(par, err) + assert(theCovariance.nrow() === variableParameters()) + } + + internal constructor(par: DoubleArray, cov: MnUserCovariance) { + theValid = true + theCovarianceValid = true + theCovariance = cov + theGlobalCC = MnGlobalCorrelationCoeff() + theIntParameters = java.util.ArrayList(par.size) + theIntCovariance = cov.copy() + require(!(theCovariance.nrow() !== variableParameters())) { "Bad covariance size" } + val err = DoubleArray(par.size) + for (i in par.indices) { + require(theCovariance[i, i] > 0.0) { "Bad covariance" } + err[i] = sqrt(theCovariance[i, i]) + theIntParameters.add(par[i]) + } + theParameters = MnUserParameters(par, err) + } + + internal constructor(par: MnUserParameters, cov: MnUserCovariance) { + theValid = true + theCovarianceValid = true + theParameters = par + theCovariance = cov + theGlobalCC = MnGlobalCorrelationCoeff() + theIntParameters = java.util.ArrayList() + theIntCovariance = cov.copy() + theIntCovariance.scale(0.5) + val i = 0 + for (ipar in par.parameters()) { + if (ipar.isConst() || ipar.isFixed()) { + continue + } + if (ipar.hasLimits()) { + theIntParameters.add(ext2int(ipar.number(), ipar.value())) + } else { + theIntParameters.add(ipar.value()) + } + } + assert(theCovariance.nrow() === variableParameters()) + } + + /** + * construct from internal parameters (after minimization) + * @param trafo + * @param up + */ + internal constructor(st: MinimumState, up: Double, trafo: MnUserTransformation) { + theValid = st.isValid() + theCovarianceValid = false + theGCCValid = false + theFVal = st.fval() + theEDM = st.edm() + theNFcn = st.nfcn() + theParameters = MnUserParameters() + theCovariance = MnUserCovariance() + theGlobalCC = MnGlobalCorrelationCoeff() + theIntParameters = java.util.ArrayList() + theIntCovariance = MnUserCovariance() + for (ipar in trafo.parameters()) { + if (ipar.isConst()) { + add(ipar.name(), ipar.value()) + } else if (ipar.isFixed()) { + add(ipar.name(), ipar.value(), ipar.error()) + if (ipar.hasLimits()) { + if (ipar.hasLowerLimit() && ipar.hasUpperLimit()) { + setLimits(ipar.name(), ipar.lowerLimit(), ipar.upperLimit()) + } else if (ipar.hasLowerLimit() && !ipar.hasUpperLimit()) { + setLowerLimit(ipar.name(), ipar.lowerLimit()) + } else { + setUpperLimit(ipar.name(), ipar.upperLimit()) + } + } + fix(ipar.name()) + } else if (ipar.hasLimits()) { + val i: Int = trafo.intOfExt(ipar.number()) + val err: Double = if (st.hasCovariance()) sqrt(2.0 * up * st.error().invHessian()[i, i]) else st.parameters().dirin().getEntry(i) + add(ipar.name(), + trafo.int2ext(i, st.vec().getEntry(i)), + trafo.int2extError(i, st.vec().getEntry(i), err)) + if (ipar.hasLowerLimit() && ipar.hasUpperLimit()) { + setLimits(ipar.name(), ipar.lowerLimit(), ipar.upperLimit()) + } else if (ipar.hasLowerLimit() && !ipar.hasUpperLimit()) { + setLowerLimit(ipar.name(), ipar.lowerLimit()) + } else { + setUpperLimit(ipar.name(), ipar.upperLimit()) + } + } else { + val i: Int = trafo.intOfExt(ipar.number()) + val err: Double = if (st.hasCovariance()) sqrt(2.0 * up * st.error().invHessian()[i, i]) else st.parameters().dirin().getEntry(i) + add(ipar.name(), st.vec().getEntry(i), err) + } + } + theCovarianceValid = st.error().isValid() + if (theCovarianceValid) { + theCovariance = trafo.int2extCovariance(st.vec(), st.error().invHessian()) + theIntCovariance = MnUserCovariance(st.error().invHessian().data().clone(), st.error().invHessian().nrow()) + theCovariance.scale(2.0 * up) + theGlobalCC = MnGlobalCorrelationCoeff(st.error().invHessian()) + theGCCValid = true + assert(theCovariance.nrow() === variableParameters()) + } + } + + /** + * add free parameter name, value, error + * + * @param err a double. + * @param val a double. + * @param name a [String] object. + */ + fun add(name: String, `val`: Double, err: Double) { + theParameters.add(name, `val`, err) + theIntParameters.add(`val`) + theCovarianceValid = false + theGCCValid = false + theValid = true + } + + /** + * add limited parameter name, value, lower bound, upper bound + * + * @param name a [String] object. + * @param val a double. + * @param low a double. + * @param err a double. + * @param up a double. + */ + fun add(name: String, `val`: Double, err: Double, low: Double, up: Double) { + theParameters.add(name, `val`, err, low, up) + theCovarianceValid = false + theIntParameters.add(ext2int(index(name), `val`)) + theGCCValid = false + theValid = true + } + + /** + * add const parameter name, value + * + * @param name a [String] object. + * @param val a double. + */ + fun add(name: String, `val`: Double) { + theParameters.add(name, `val`) + theValid = true + } + + /** + * + * copy. + * + * @return a [hep.dataforge.MINUIT.MnUserParameterState] object. + */ + fun copy(): MnUserParameterState { + return MnUserParameterState(this) + } + + /** + * Covariance matrix in the external representation + * + * @return a [hep.dataforge.MINUIT.MnUserCovariance] object. + */ + fun covariance(): MnUserCovariance { + return theCovariance + } + + /** + * Returns the expected vertival distance to the minimum (EDM) + * + * @return a double. + */ + fun edm(): Double { + return theEDM + } + + /** + * + * error. + * + * @param index a int. + * @return a double. + */ + fun error(index: Int): Double { + return theParameters.error(index) + } + + /** + * + * error. + * + * @param name a [String] object. + * @return a double. + */ + fun error(name: String?): Double { + return error(index(name)) + } + + /** + * + * errors. + * + * @return an array of double. + */ + fun errors(): DoubleArray { + return theParameters.errors() + } + + fun ext2int(i: Int, `val`: Double): Double { + return theParameters.trafo().ext2int(i, `val`) + } + + /** + * + * extOfInt. + * + * @param internal a int. + * @return a int. + */ + fun extOfInt(internal: Int): Int { + return theParameters.trafo().extOfInt(internal) + } + /// interaction via external number of parameter + /** + * + * fix. + * + * @param e a int. + */ + fun fix(e: Int) { + val i = intOfExt(e) + if (theCovarianceValid) { + theCovariance = MnCovarianceSqueeze.squeeze(theCovariance, i) + theIntCovariance = MnCovarianceSqueeze.squeeze(theIntCovariance, i) + } + theIntParameters.removeAt(i) + theParameters.fix(e) + theGCCValid = false + } + /// interaction via name of parameter + /** + * + * fix. + * + * @param name a [String] object. + */ + fun fix(name: String?) { + fix(index(name)) + } + + /** + * returns the function value at the minimum + * + * @return a double. + */ + fun fval(): Double { + return theFVal + } + + /** + * transformation internal <-> external + * @return + */ + fun getTransformation(): MnUserTransformation { + return theParameters.trafo() + } + + fun globalCC(): MnGlobalCorrelationCoeff? { + return theGlobalCC + } + + /** + * Returns + * true if the the state has a valid covariance, + * false otherwise. + * + * @return a boolean. + */ + fun hasCovariance(): Boolean { + return theCovarianceValid + } + + /** + * + * hasGlobalCC. + * + * @return a boolean. + */ + fun hasGlobalCC(): Boolean { + return theGCCValid + } + + /** + * convert name into external number of parameter + * + * @param name a [String] object. + * @return a int. + */ + fun index(name: String?): Int { + return theParameters.index(name) + } + + // transformation internal <-> external + fun int2ext(i: Int, `val`: Double): Double { + return theParameters.trafo().int2ext(i, `val`) + } + + fun intCovariance(): MnUserCovariance { + return theIntCovariance + } + + fun intOfExt(ext: Int): Int { + return theParameters.trafo().intOfExt(ext) + } + + /** + * Minuit internal representation + * @return + */ + fun intParameters(): List { + return theIntParameters + } + + /** + * Returns + * true if the the state is valid, + * false if not + * + * @return a boolean. + */ + fun isValid(): Boolean { + return theValid + } + + // facade: forward interface of MnUserParameters and MnUserTransformation + fun minuitParameters(): List { + return theParameters.parameters() + } + + /** + * convert external number into name of parameter + * + * @param index a int. + * @return a [String] object. + */ + fun name(index: Int): String { + return theParameters.name(index) + } + + /** + * Returns the number of function calls during the minimization. + * + * @return a int. + */ + fun nfcn(): Int { + return theNFcn + } + + fun parameter(i: Int): MinuitParameter { + return theParameters.parameter(i) + } + + //user external representation + fun parameters(): MnUserParameters { + return theParameters + } + + /** + * access to parameters and errors in column-wise representation + * + * @return an array of double. + */ + fun params(): DoubleArray { + return theParameters.params() + } + + /** + * + * precision. + * + * @return a [hep.dataforge.MINUIT.MnMachinePrecision] object. + */ + fun precision(): MnMachinePrecision { + return theParameters.precision() + } + + /** + * + * release. + * + * @param e a int. + */ + fun release(e: Int) { + theParameters.release(e) + theCovarianceValid = false + theGCCValid = false + val i = intOfExt(e) + if (parameter(e).hasLimits()) { + theIntParameters.add(i, ext2int(e, parameter(e).value())) + } else { + theIntParameters.add(i, parameter(e).value()) + } + } + + /** + * + * release. + * + * @param name a [String] object. + */ + fun release(name: String?) { + release(index(name)) + } + + /** + * + * removeLimits. + * + * @param e a int. + */ + fun removeLimits(e: Int) { + theParameters.removeLimits(e) + theCovarianceValid = false + theGCCValid = false + if (!parameter(e).isFixed() && !parameter(e).isConst()) { + theIntParameters[intOfExt(e)] = value(e) + } + } + + /** + * + * removeLimits. + * + * @param name a [String] object. + */ + fun removeLimits(name: String?) { + removeLimits(index(name)) + } + + /** + * + * setError. + * + * @param e a int. + * @param err a double. + * @param err a double. + */ + fun setError(e: Int, err: Double) { + theParameters.setError(e, err) + } + + /** + * + * setError. + * + * @param name a [String] object. + * @param err a double. + */ + fun setError(name: String?, err: Double) { + setError(index(name), err) + } + + /** + * + * setLimits. + * + * @param e a int. + * @param low a double. + * @param up a double. + */ + fun setLimits(e: Int, low: Double, up: Double) { + theParameters.setLimits(e, low, up) + theCovarianceValid = false + theGCCValid = false + if (!parameter(e).isFixed() && !parameter(e).isConst()) { + val i = intOfExt(e) + if (low < theIntParameters[i] && theIntParameters[i] < up) { + theIntParameters[i] = ext2int(e, theIntParameters[i]) + } else { + theIntParameters[i] = ext2int(e, 0.5 * (low + up)) + } + } + } + + /** + * + * setLimits. + * + * @param name a [String] object. + * @param low a double. + * @param up a double. + */ + fun setLimits(name: String?, low: Double, up: Double) { + setLimits(index(name), low, up) + } + + /** + * + * setLowerLimit. + * + * @param e a int. + * @param low a double. + */ + fun setLowerLimit(e: Int, low: Double) { + theParameters.setLowerLimit(e, low) + theCovarianceValid = false + theGCCValid = false + if (!parameter(e).isFixed() && !parameter(e).isConst()) { + val i = intOfExt(e) + if (low < theIntParameters[i]) { + theIntParameters[i] = ext2int(e, theIntParameters[i]) + } else { + theIntParameters[i] = ext2int(e, low + 0.5 * abs(low + 1.0)) + } + } + } + + /** + * + * setLowerLimit. + * + * @param name a [String] object. + * @param low a double. + */ + fun setLowerLimit(name: String?, low: Double) { + setLowerLimit(index(name), low) + } + + /** + * + * setPrecision. + * + * @param eps a double. + */ + fun setPrecision(eps: Double) { + theParameters.setPrecision(eps) + } + + /** + * + * setUpperLimit. + * + * @param e a int. + * @param up a double. + */ + fun setUpperLimit(e: Int, up: Double) { + theParameters.setUpperLimit(e, up) + theCovarianceValid = false + theGCCValid = false + if (!parameter(e).isFixed() && !parameter(e).isConst()) { + val i = intOfExt(e) + if (theIntParameters[i] < up) { + theIntParameters[i] = ext2int(e, theIntParameters[i]) + } else { + theIntParameters[i] = ext2int(e, up - 0.5 * abs(up + 1.0)) + } + } + } + + /** + * + * setUpperLimit. + * + * @param name a [String] object. + * @param up a double. + */ + fun setUpperLimit(name: String?, up: Double) { + setUpperLimit(index(name), up) + } + + /** + * + * setValue. + * + * @param e a int. + * @param val a double. + */ + fun setValue(e: Int, `val`: Double) { + theParameters.setValue(e, `val`) + if (!parameter(e).isFixed() && !parameter(e).isConst()) { + val i = intOfExt(e) + if (parameter(e).hasLimits()) { + theIntParameters[i] = ext2int(e, `val`) + } else { + theIntParameters[i] = `val` + } + } + } + + /** + * + * setValue. + * + * @param name a [String] object. + * @param val a double. + */ + fun setValue(name: String?, `val`: Double) { + setValue(index(name), `val`) + } + + /** {@inheritDoc} */ + override fun toString(): String { + return MnPrint.toString(this) + } + + /** + * + * value. + * + * @param index a int. + * @return a double. + */ + fun value(index: Int): Double { + return theParameters.value(index) + } + + /** + * + * value. + * + * @param name a [String] object. + * @return a double. + */ + fun value(name: String?): Double { + return value(index(name)) + } + + /** + * + * variableParameters. + * + * @return a int. + */ + fun variableParameters(): Int { + return theParameters.variableParameters() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnUserParameters.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnUserParameters.kt new file mode 100644 index 000000000..9bac54b25 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnUserParameters.kt @@ -0,0 +1,402 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * API class for the user interaction with the parameters. Serves as input to + * the minimizer as well as output from it; users can interact: fix/release + * parameters, set values and errors, etc.; parameters can be accessed via their + * parameter number or via their user-specified name. + * + * @version $Id$ + * @author Darksnake + */ +class MnUserParameters { + private var theTransformation: MnUserTransformation + + /** + * Creates a new instance of MnUserParameters + */ + constructor() { + theTransformation = MnUserTransformation() + } + + /** + * + * Constructor for MnUserParameters. + * + * @param par an array of double. + * @param err an array of double. + */ + constructor(par: DoubleArray, err: DoubleArray) { + theTransformation = MnUserTransformation(par, err) + } + + private constructor(other: MnUserParameters) { + theTransformation = other.theTransformation.copy() + } + + /** + * Add free parameter name, value, error + * + * + * When adding parameters, MINUIT assigns indices to each parameter which + * will be the same as in the double[] in the + * MultiFunction.valueOf(). That means the first parameter the user + * adds gets index 0, the second index 1, and so on. When calculating the + * function value inside FCN, MINUIT will call + * MultiFunction.valueOf() with the elements at their respective + * positions. + * + * @param err a double. + * @param val a double. + * @param name a [String] object. + */ + fun add(name: String, `val`: Double, err: Double) { + theTransformation.add(name, `val`, err) + } + + /** + * Add limited parameter name, value, lower bound, upper bound + * + * @param up a double. + * @param low a double. + * @param name a [String] object. + * @param val a double. + * @param err a double. + */ + fun add(name: String, `val`: Double, err: Double, low: Double, up: Double) { + theTransformation.add(name, `val`, err, low, up) + } + + /** + * Add const parameter name, value + * + * @param name a [String] object. + * @param val a double. + */ + fun add(name: String, `val`: Double) { + theTransformation.add(name, `val`) + } + + /** + * + * copy. + * + * @return a [hep.dataforge.MINUIT.MnUserParameters] object. + */ + fun copy(): MnUserParameters { + return MnUserParameters(this) + } + + /** + * + * error. + * + * @param index a int. + * @return a double. + */ + fun error(index: Int): Double { + return theTransformation.error(index) + } + + /** + * + * error. + * + * @param name a [String] object. + * @return a double. + */ + fun error(name: String?): Double { + return theTransformation.error(name) + } + + fun errors(): DoubleArray { + return theTransformation.errors() + } + /// interaction via external number of parameter + /** + * Fixes the specified parameter (so that the minimizer will no longer vary + * it) + * + * @param index a int. + */ + fun fix(index: Int) { + theTransformation.fix(index) + } + /// interaction via name of parameter + /** + * Fixes the specified parameter (so that the minimizer will no longer vary + * it) + * + * @param name a [String] object. + */ + fun fix(name: String?) { + theTransformation.fix(name) + } + + /** + * convert name into external number of parameter + * @param name + * @return + */ + fun index(name: String?): Int { + return theTransformation.index(name) + } + + /** + * convert external number into name of parameter + * @param index + * @return + */ + fun name(index: Int): String { + return theTransformation.name(index) + } + + /** + * access to single parameter + * @param index + * @return + */ + fun parameter(index: Int): MinuitParameter { + return theTransformation.parameter(index) + } + + /** + * access to parameters (row-wise) + * @return + */ + fun parameters(): List { + return theTransformation.parameters() + } + + /** + * access to parameters and errors in column-wise representation + * @return + */ + fun params(): DoubleArray { + return theTransformation.params() + } + + /** + * + * precision. + * + * @return a [hep.dataforge.MINUIT.MnMachinePrecision] object. + */ + fun precision(): MnMachinePrecision { + return theTransformation.precision() + } + + /** + * Releases the specified parameter (so that the minimizer can vary it) + * + * @param index a int. + */ + fun release(index: Int) { + theTransformation.release(index) + } + + /** + * Releases the specified parameter (so that the minimizer can vary it) + * + * @param name a [String] object. + */ + fun release(name: String?) { + theTransformation.release(name) + } + + /** + * + * removeLimits. + * + * @param index a int. + */ + fun removeLimits(index: Int) { + theTransformation.removeLimits(index) + } + + /** + * + * removeLimits. + * + * @param name a [String] object. + */ + fun removeLimits(name: String?) { + theTransformation.removeLimits(name) + } + + /** + * + * setError. + * + * @param index a int. + * @param err a double. + */ + fun setError(index: Int, err: Double) { + theTransformation.setError(index, err) + } + + /** + * + * setError. + * + * @param name a [String] object. + * @param err a double. + */ + fun setError(name: String?, err: Double) { + theTransformation.setError(name, err) + } + + /** + * Set the lower and upper bound on the specified variable. + * + * @param up a double. + * @param low a double. + * @param index a int. + */ + fun setLimits(index: Int, low: Double, up: Double) { + theTransformation.setLimits(index, low, up) + } + + /** + * Set the lower and upper bound on the specified variable. + * + * @param up a double. + * @param low a double. + * @param name a [String] object. + */ + fun setLimits(name: String?, low: Double, up: Double) { + theTransformation.setLimits(name, low, up) + } + + /** + * + * setLowerLimit. + * + * @param index a int. + * @param low a double. + */ + fun setLowerLimit(index: Int, low: Double) { + theTransformation.setLowerLimit(index, low) + } + + /** + * + * setLowerLimit. + * + * @param name a [String] object. + * @param low a double. + */ + fun setLowerLimit(name: String?, low: Double) { + theTransformation.setLowerLimit(name, low) + } + + /** + * + * setPrecision. + * + * @param eps a double. + */ + fun setPrecision(eps: Double) { + theTransformation.setPrecision(eps) + } + + /** + * + * setUpperLimit. + * + * @param index a int. + * @param up a double. + */ + fun setUpperLimit(index: Int, up: Double) { + theTransformation.setUpperLimit(index, up) + } + + /** + * + * setUpperLimit. + * + * @param name a [String] object. + * @param up a double. + */ + fun setUpperLimit(name: String?, up: Double) { + theTransformation.setUpperLimit(name, up) + } + + /** + * Set the value of parameter. The parameter in question may be variable, + * fixed, or constant, but must be defined. + * + * @param index a int. + * @param val a double. + */ + fun setValue(index: Int, `val`: Double) { + theTransformation.setValue(index, `val`) + } + + /** + * Set the value of parameter. The parameter in question may be variable, + * fixed, or constant, but must be defined. + * + * @param name a [String] object. + * @param val a double. + */ + fun setValue(name: String?, `val`: Double) { + theTransformation.setValue(name, `val`) + } + + /** {@inheritDoc} */ + override fun toString(): String { + return MnPrint.toString(this) + } + + fun trafo(): MnUserTransformation { + return theTransformation + } + + /** + * + * value. + * + * @param index a int. + * @return a double. + */ + fun value(index: Int): Double { + return theTransformation.value(index) + } + + /** + * + * value. + * + * @param name a [String] object. + * @return a double. + */ + fun value(name: String?): Double { + return theTransformation.value(name) + } + + /** + * + * variableParameters. + * + * @return a int. + */ + fun variableParameters(): Int { + return theTransformation.variableParameters() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnUserTransformation.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnUserTransformation.kt new file mode 100644 index 000000000..1066ac2da --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnUserTransformation.kt @@ -0,0 +1,390 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector + +/** + * knows how to andThen between user specified parameters (external) and + * internal parameters used for minimization + * + * Жуткий октопус, который занимается преобразованием внешних параметров во внутренние + * TODO по возможности отказаться от использования этого монстра + * @version $Id$ + */ +class MnUserTransformation { + private val nameMap: MutableMap = HashMap() + private var theCache: MutableList + private var theExtOfInt: MutableList + private var theParameters: MutableList + private var thePrecision: MnMachinePrecision + + constructor() { + thePrecision = MnMachinePrecision() + theParameters = java.util.ArrayList() + theExtOfInt = java.util.ArrayList() + theCache = java.util.ArrayList(0) + } + + private constructor(other: MnUserTransformation) { + thePrecision = other.thePrecision + theParameters = java.util.ArrayList(other.theParameters.size) + for (par in other.theParameters) { + theParameters.add(par.copy()) + } + theExtOfInt = java.util.ArrayList(other.theExtOfInt) + theCache = java.util.ArrayList(other.theCache) + } + + constructor(par: DoubleArray, err: DoubleArray) { + thePrecision = MnMachinePrecision() + theParameters = java.util.ArrayList(par.size) + theExtOfInt = java.util.ArrayList(par.size) + theCache = java.util.ArrayList(par.size) + for (i in par.indices) { + add("p$i", par[i], err[i]) + } + } + + /** + * add free parameter + * @param err + * @param val + */ + fun add(name: String, `val`: Double, err: Double) { + require(!nameMap.containsKey(name)) { "duplicate name: $name" } + nameMap[name] = theParameters.size + theExtOfInt.add(theParameters.size) + theCache.add(`val`) + theParameters.add(MinuitParameter(theParameters.size, name, `val`, err)) + } + + /** + * add limited parameter + * @param up + * @param low + */ + fun add(name: String, `val`: Double, err: Double, low: Double, up: Double) { + require(!nameMap.containsKey(name)) { "duplicate name: $name" } + nameMap[name] = theParameters.size + theExtOfInt.add(theParameters.size) + theCache.add(`val`) + theParameters.add(MinuitParameter(theParameters.size, name, `val`, err, low, up)) + } + + /** + * add parameter + * @param name + * @param val + */ + fun add(name: String, `val`: Double) { + require(!nameMap.containsKey(name)) { "duplicate name: $name" } + nameMap[name] = theParameters.size + theCache.add(`val`) + theParameters.add(MinuitParameter(theParameters.size, name, `val`)) + } + + /** + * + * copy. + * + * @return a [hep.dataforge.MINUIT.MnUserTransformation] object. + */ + fun copy(): MnUserTransformation { + return MnUserTransformation(this) + } + + fun dInt2Ext(i: Int, `val`: Double): Double { + var dd = 1.0 + val parm: MinuitParameter = theParameters[theExtOfInt[i]] + if (parm.hasLimits()) { + dd = if (parm.hasUpperLimit() && parm.hasLowerLimit()) { + theDoubleLimTrafo.dInt2Ext(`val`, + parm.upperLimit(), + parm.lowerLimit()) + } else if (parm.hasUpperLimit() && !parm.hasLowerLimit()) { + theUpperLimTrafo.dInt2Ext(`val`, parm.upperLimit()) + } else { + theLowerLimTrafo.dInt2Ext(`val`, parm.lowerLimit()) + } + } + return dd + } + + fun error(index: Int): Double { + return theParameters[index].error() + } + + fun error(name: String?): Double { + return error(index(name)) + } + + fun errors(): DoubleArray { + val result = DoubleArray(theParameters.size) + var i = 0 + for (parameter in theParameters) { + result[i++] = parameter.error() + } + return result + } + + fun ext2int(i: Int, `val`: Double): Double { + val parm: MinuitParameter = theParameters[i] + return if (parm.hasLimits()) { + if (parm.hasUpperLimit() && parm.hasLowerLimit()) { + theDoubleLimTrafo.ext2int(`val`, + parm.upperLimit(), + parm.lowerLimit(), + precision()) + } else if (parm.hasUpperLimit() && !parm.hasLowerLimit()) { + theUpperLimTrafo.ext2int(`val`, + parm.upperLimit(), + precision()) + } else { + theLowerLimTrafo.ext2int(`val`, + parm.lowerLimit(), + precision()) + } + } else `val` + } + + fun extOfInt(internal: Int): Int { + return theExtOfInt[internal] + } + + /** + * interaction via external number of parameter + * @param index + */ + fun fix(index: Int) { + val iind = intOfExt(index) + theExtOfInt.removeAt(iind) + theParameters[index].fix() + } + + /** + * interaction via name of parameter + * @param name + */ + fun fix(name: String?) { + fix(index(name)) + } + + /** + * convert name into external number of parameter + * @param name + * @return + */ + fun index(name: String?): Int { + return nameMap[name]!! + } + + fun int2ext(i: Int, `val`: Double): Double { + val parm: MinuitParameter = theParameters[theExtOfInt[i]] + return if (parm.hasLimits()) { + if (parm.hasUpperLimit() && parm.hasLowerLimit()) { + theDoubleLimTrafo.int2ext(`val`, + parm.upperLimit(), + parm.lowerLimit()) + } else if (parm.hasUpperLimit() && !parm.hasLowerLimit()) { + theUpperLimTrafo.int2ext(`val`, parm.upperLimit()) + } else { + theLowerLimTrafo.int2ext(`val`, parm.lowerLimit()) + } + } else `val` + } + + fun int2extCovariance(vec: RealVector, cov: MnAlgebraicSymMatrix): MnUserCovariance { + val result = MnUserCovariance(cov.nrow()) + for (i in 0 until vec.getDimension()) { + var dxdi = 1.0 + if (theParameters[theExtOfInt[i]].hasLimits()) { + dxdi = dInt2Ext(i, vec.getEntry(i)) + } + for (j in i until vec.getDimension()) { + var dxdj = 1.0 + if (theParameters[theExtOfInt[j]].hasLimits()) { + dxdj = dInt2Ext(j, vec.getEntry(j)) + } + result[i, j] = dxdi * cov[i, j] * dxdj + } + } + return result + } + + fun int2extError(i: Int, `val`: Double, err: Double): Double { + var dx = err + val parm: MinuitParameter = theParameters[theExtOfInt[i]] + if (parm.hasLimits()) { + val ui = int2ext(i, `val`) + var du1 = int2ext(i, `val` + dx) - ui + val du2 = int2ext(i, `val` - dx) - ui + if (parm.hasUpperLimit() && parm.hasLowerLimit()) { + if (dx > 1.0) { + du1 = parm.upperLimit() - parm.lowerLimit() + } + dx = 0.5 * (abs(du1) + abs(du2)) + } else { + dx = 0.5 * (abs(du1) + abs(du2)) + } + } + return dx + } + + fun intOfExt(ext: Int): Int { + for (iind in theExtOfInt.indices) { + if (ext == theExtOfInt[iind]) { + return iind + } + } + throw IllegalArgumentException("ext=$ext") + } + + /** + * convert external number into name of parameter + * @param index + * @return + */ + fun name(index: Int): String { + return theParameters[index].name() + } + + /** + * access to single parameter + * @param index + * @return + */ + fun parameter(index: Int): MinuitParameter { + return theParameters[index] + } + + fun parameters(): List { + return theParameters + } + + //access to parameters and errors in column-wise representation + fun params(): DoubleArray { + val result = DoubleArray(theParameters.size) + var i = 0 + for (parameter in theParameters) { + result[i++] = parameter.value() + } + return result + } + + fun precision(): MnMachinePrecision { + return thePrecision + } + + fun release(index: Int) { + require(!theExtOfInt.contains(index)) { "index=$index" } + theExtOfInt.add(index) + Collections.sort(theExtOfInt) + theParameters[index].release() + } + + fun release(name: String?) { + release(index(name)) + } + + fun removeLimits(index: Int) { + theParameters[index].removeLimits() + } + + fun removeLimits(name: String?) { + removeLimits(index(name)) + } + + fun setError(index: Int, err: Double) { + theParameters[index].setError(err) + } + + fun setError(name: String?, err: Double) { + setError(index(name), err) + } + + fun setLimits(index: Int, low: Double, up: Double) { + theParameters[index].setLimits(low, up) + } + + fun setLimits(name: String?, low: Double, up: Double) { + setLimits(index(name), low, up) + } + + fun setLowerLimit(index: Int, low: Double) { + theParameters[index].setLowerLimit(low) + } + + fun setLowerLimit(name: String?, low: Double) { + setLowerLimit(index(name), low) + } + + fun setPrecision(eps: Double) { + thePrecision.setPrecision(eps) + } + + fun setUpperLimit(index: Int, up: Double) { + theParameters[index].setUpperLimit(up) + } + + fun setUpperLimit(name: String?, up: Double) { + setUpperLimit(index(name), up) + } + + fun setValue(index: Int, `val`: Double) { + theParameters[index].setValue(`val`) + theCache[index] = `val` + } + + fun setValue(name: String?, `val`: Double) { + setValue(index(name), `val`) + } + + fun transform(pstates: RealVector): ArrayRealVector { + // FixMe: Worry about efficiency here + val result = ArrayRealVector(theCache.size) + for (i in 0 until result.getDimension()) { + result.setEntry(i, theCache[i]) + } + for (i in 0 until pstates.getDimension()) { + if (theParameters[theExtOfInt[i]].hasLimits()) { + result.setEntry(theExtOfInt[i], int2ext(i, pstates.getEntry(i))) + } else { + result.setEntry(theExtOfInt[i], pstates.getEntry(i)) + } + } + return result + } + + //forwarded interface + fun value(index: Int): Double { + return theParameters[index].value() + } + + fun value(name: String?): Double { + return value(index(name)) + } + + fun variableParameters(): Int { + return theExtOfInt.size + } + + companion object { + private val theDoubleLimTrafo: SinParameterTransformation = SinParameterTransformation() + private val theLowerLimTrafo: SqrtLowParameterTransformation = SqrtLowParameterTransformation() + private val theUpperLimTrafo: SqrtUpParameterTransformation = SqrtUpParameterTransformation() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/MnUtils.kt b/kmath-optimization/src/commonMain/tmp/minuit/MnUtils.kt new file mode 100644 index 000000000..d9f3e1bd5 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/MnUtils.kt @@ -0,0 +1,147 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector + +/** + * Utilities for operating on vectors and matrices + * + * @version $Id$ + */ +internal object MnUtils { + fun absoluteSumOfElements(m: MnAlgebraicSymMatrix): Double { + val data: DoubleArray = m.data() + var result = 0.0 + for (i in data.indices) { + result += abs(data[i]) + } + return result + } + + fun add(v1: RealVector, v2: RealVector?): RealVector { + return v1.add(v2) + } + + fun add(m1: MnAlgebraicSymMatrix, m2: MnAlgebraicSymMatrix): MnAlgebraicSymMatrix { + require(!(m1.size() !== m2.size())) { "Incompatible matrices" } + val result: MnAlgebraicSymMatrix = m1.copy() + val a: DoubleArray = result.data() + val b: DoubleArray = m2.data() + for (i in a.indices) { + a[i] += b[i] + } + return result + } + + fun div(m: MnAlgebraicSymMatrix?, scale: Double): MnAlgebraicSymMatrix { + return mul(m, 1 / scale) + } + + fun div(m: RealVector?, scale: Double): RealVector { + return mul(m, 1 / scale) + } + + fun innerProduct(v1: RealVector, v2: RealVector): Double { + require(!(v1.getDimension() !== v2.getDimension())) { "Incompatible vectors" } + var total = 0.0 + for (i in 0 until v1.getDimension()) { + total += v1.getEntry(i) * v2.getEntry(i) + } + return total + } + + fun mul(v1: RealVector, scale: Double): RealVector { + return v1.mapMultiply(scale) + } + + fun mul(m1: MnAlgebraicSymMatrix, scale: Double): MnAlgebraicSymMatrix { + val result: MnAlgebraicSymMatrix = m1.copy() + val a: DoubleArray = result.data() + for (i in a.indices) { + a[i] *= scale + } + return result + } + + fun mul(m1: MnAlgebraicSymMatrix, v1: RealVector): ArrayRealVector { + require(!(m1.nrow() !== v1.getDimension())) { "Incompatible arguments" } + val result = ArrayRealVector(m1.nrow()) + for (i in 0 until result.getDimension()) { + var total = 0.0 + for (k in 0 until result.getDimension()) { + total += m1[i, k] * v1.getEntry(k) + } + result.setEntry(i, total) + } + return result + } + + fun mul(m1: MnAlgebraicSymMatrix, m2: MnAlgebraicSymMatrix): MnAlgebraicSymMatrix { + require(!(m1.size() !== m2.size())) { "Incompatible matrices" } + val n: Int = m1.nrow() + val result = MnAlgebraicSymMatrix(n) + for (i in 0 until n) { + for (j in 0..i) { + var total = 0.0 + for (k in 0 until n) { + total += m1[i, k] * m2[k, j] + } + result[i, j] = total + } + } + return result + } + + fun outerProduct(v2: RealVector): MnAlgebraicSymMatrix { + // Fixme: check this. I am assuming this is just an outer-product of vector + // with itself. + val n: Int = v2.getDimension() + val result = MnAlgebraicSymMatrix(n) + val data: DoubleArray = v2.toArray() + for (i in 0 until n) { + for (j in 0..i) { + result[i, j] = data[i] * data[j] + } + } + return result + } + + fun similarity(avec: RealVector, mat: MnAlgebraicSymMatrix): Double { + val n: Int = avec.getDimension() + val tmp: RealVector = mul(mat, avec) + var result = 0.0 + for (i in 0 until n) { + result += tmp.getEntry(i) * avec.getEntry(i) + } + return result + } + + fun sub(v1: RealVector, v2: RealVector?): RealVector { + return v1.subtract(v2) + } + + fun sub(m1: MnAlgebraicSymMatrix, m2: MnAlgebraicSymMatrix): MnAlgebraicSymMatrix { + require(!(m1.size() !== m2.size())) { "Incompatible matrices" } + val result: MnAlgebraicSymMatrix = m1.copy() + val a: DoubleArray = result.data() + val b: DoubleArray = m2.data() + for (i in a.indices) { + a[i] -= b[i] + } + return result + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/ModularFunctionMinimizer.kt b/kmath-optimization/src/commonMain/tmp/minuit/ModularFunctionMinimizer.kt new file mode 100644 index 000000000..84130d24f --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/ModularFunctionMinimizer.kt @@ -0,0 +1,73 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import ru.inr.mass.maths.MultiFunction +import ru.inr.mass.minuit.* +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * + * @version $Id$ + */ +abstract class ModularFunctionMinimizer { + abstract fun builder(): MinimumBuilder + fun minimize( + fcn: MultiFunction?, + st: MnUserParameterState, + strategy: MnStrategy, + maxfcn: Int, + toler: Double, + errorDef: Double, + useAnalyticalGradient: Boolean, + checkGradient: Boolean + ): FunctionMinimum { + var maxfcn = maxfcn + val mfcn = MnUserFcn(fcn, errorDef, st.getTransformation()) + val gc: GradientCalculator + var providesAllDerivs = true + /* + * Проверяем в явном виде, что все аналитические производные присутствуют + * TODO сделать возможность того, что часть производных задается аналитически, а часть численно + */for (i in 0 until fcn.getDimension()) { + if (!fcn.providesDeriv(i)) providesAllDerivs = false + } + gc = if (providesAllDerivs && useAnalyticalGradient) { + AnalyticalGradientCalculator(fcn, st.getTransformation(), checkGradient) + } else { + Numerical2PGradientCalculator(mfcn, st.getTransformation(), strategy) + } + val npar: Int = st.variableParameters() + if (maxfcn == 0) { + maxfcn = 200 + 100 * npar + 5 * npar * npar + } + val mnseeds: MinimumSeed = seedGenerator().generate(mfcn, gc, st, strategy) + return minimize(mfcn, gc, mnseeds, strategy, maxfcn, toler) + } + + fun minimize( + mfcn: MnFcn, + gc: GradientCalculator?, + seed: MinimumSeed?, + strategy: MnStrategy?, + maxfcn: Int, + toler: Double + ): FunctionMinimum { + return builder().minimum(mfcn, gc, seed, strategy, maxfcn, toler * mfcn.errorDef()) + } + + abstract fun seedGenerator(): MinimumSeedGenerator +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/NegativeG2LineSearch.kt b/kmath-optimization/src/commonMain/tmp/minuit/NegativeG2LineSearch.kt new file mode 100644 index 000000000..2e9ce5813 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/NegativeG2LineSearch.kt @@ -0,0 +1,80 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector +import ru.inr.mass.minuit.* + +/** + * In case that one of the components of the second derivative g2 calculated by + * the numerical gradient calculator is negative, a 1dim line search in the + * direction of that component is done in order to find a better position where + * g2 is again positive. + * + * @version $Id$ + */ +internal object NegativeG2LineSearch { + fun hasNegativeG2(grad: FunctionGradient, prec: MnMachinePrecision): Boolean { + for (i in 0 until grad.getGradient().getDimension()) { + if (grad.getGradientDerivative().getEntry(i) < prec.eps2()) { + return true + } + } + return false + } + + fun search(fcn: MnFcn, st: MinimumState, gc: GradientCalculator, prec: MnMachinePrecision): MinimumState { + val negG2 = hasNegativeG2(st.gradient(), prec) + if (!negG2) { + return st + } + val n: Int = st.parameters().vec().getDimension() + var dgrad: FunctionGradient = st.gradient() + var pa: MinimumParameters = st.parameters() + var iterate = false + var iter = 0 + do { + iterate = false + for (i in 0 until n) { + if (dgrad.getGradientDerivative().getEntry(i) < prec.eps2()) { + // do line search if second derivative negative + var step: RealVector = ArrayRealVector(n) + step.setEntry(i, dgrad.getStep().getEntry(i) * dgrad.getGradient().getEntry(i)) + if (abs(dgrad.getGradient().getEntry(i)) > prec.eps2()) { + step.setEntry(i, + step.getEntry(i) * (-1.0 / abs(dgrad.getGradient().getEntry(i)))) + } + val gdel: Double = step.getEntry(i) * dgrad.getGradient().getEntry(i) + val pp: MnParabolaPoint = MnLineSearch.search(fcn, pa, step, gdel, prec) + step = MnUtils.mul(step, pp.x()) + pa = MinimumParameters(MnUtils.add(pa.vec(), step), pp.y()) + dgrad = gc.gradient(pa, dgrad) + iterate = true + break + } + } + } while (iter++ < 2 * n && iterate) + val mat = MnAlgebraicSymMatrix(n) + for (i in 0 until n) { + mat[i, i] = if (abs(dgrad.getGradientDerivative() + .getEntry(i)) > prec.eps2() + ) 1.0 / dgrad.getGradientDerivative().getEntry(i) else 1.0 + } + val err = MinimumError(mat, 1.0) + val edm: Double = VariableMetricEDMEstimator().estimate(dgrad, err) + return MinimumState(pa, err, dgrad, edm, fcn.numOfCalls()) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/Numerical2PGradientCalculator.kt b/kmath-optimization/src/commonMain/tmp/minuit/Numerical2PGradientCalculator.kt new file mode 100644 index 000000000..efa1d57af --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/Numerical2PGradientCalculator.kt @@ -0,0 +1,122 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.RealVector +import ru.inr.mass.minuit.* + +/** + * + * @version $Id$ + */ +internal class Numerical2PGradientCalculator(fcn: MnFcn, par: MnUserTransformation, stra: MnStrategy) : + GradientCalculator { + private val theFcn: MnFcn = fcn + private val theStrategy: MnStrategy + private val theTransformation: MnUserTransformation + fun fcn(): MnFcn { + return theFcn + } + + fun gradTolerance(): Double { + return strategy().gradientTolerance() + } + + /** {@inheritDoc} */ + fun gradient(par: MinimumParameters): FunctionGradient { + val gc = InitialGradientCalculator(theFcn, theTransformation, theStrategy) + val gra: FunctionGradient = gc.gradient(par) + return gradient(par, gra) + } + + /** {@inheritDoc} */ + fun gradient(par: MinimumParameters, gradient: FunctionGradient): FunctionGradient { + require(par.isValid()) { "Parameters are invalid" } + val x: RealVector = par.vec().copy() + val fcnmin: Double = par.fval() + val dfmin: Double = 8.0 * precision().eps2() * (abs(fcnmin) + theFcn.errorDef()) + val vrysml: Double = 8.0 * precision().eps() * precision().eps() + val n: Int = x.getDimension() + val grd: RealVector = gradient.getGradient().copy() + val g2: RealVector = gradient.getGradientDerivative().copy() + val gstep: RealVector = gradient.getStep().copy() + for (i in 0 until n) { + val xtf: Double = x.getEntry(i) + val epspri: Double = precision().eps2() + abs(grd.getEntry(i) * precision().eps2()) + var stepb4 = 0.0 + for (j in 0 until ncycle()) { + val optstp: Double = sqrt(dfmin / (abs(g2.getEntry(i)) + epspri)) + var step: Double = max(optstp, abs(0.1 * gstep.getEntry(i))) + if (trafo().parameter(trafo().extOfInt(i)).hasLimits()) { + if (step > 0.5) { + step = 0.5 + } + } + val stpmax: Double = 10.0 * abs(gstep.getEntry(i)) + if (step > stpmax) { + step = stpmax + } + val stpmin: Double = + max(vrysml, 8.0 * abs(precision().eps2() * x.getEntry(i))) + if (step < stpmin) { + step = stpmin + } + if (abs((step - stepb4) / step) < stepTolerance()) { + break + } + gstep.setEntry(i, step) + stepb4 = step + x.setEntry(i, xtf + step) + val fs1: Double = theFcn.value(x) + x.setEntry(i, xtf - step) + val fs2: Double = theFcn.value(x) + x.setEntry(i, xtf) + val grdb4: Double = grd.getEntry(i) + grd.setEntry(i, 0.5 * (fs1 - fs2) / step) + g2.setEntry(i, (fs1 + fs2 - 2.0 * fcnmin) / step / step) + if (abs(grdb4 - grd.getEntry(i)) / (abs(grd.getEntry(i)) + dfmin / step) < gradTolerance()) { + break + } + } + } + return FunctionGradient(grd, g2, gstep) + } + + fun ncycle(): Int { + return strategy().gradientNCycles() + } + + fun precision(): MnMachinePrecision { + return theTransformation.precision() + } + + fun stepTolerance(): Double { + return strategy().gradientStepTolerance() + } + + fun strategy(): MnStrategy { + return theStrategy + } + + fun trafo(): MnUserTransformation { + return theTransformation + } + + init { + theTransformation = par + theStrategy = stra + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/ScanBuilder.kt b/kmath-optimization/src/commonMain/tmp/minuit/ScanBuilder.kt new file mode 100644 index 000000000..57f910a26 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/ScanBuilder.kt @@ -0,0 +1,59 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector +import ru.inr.mass.minuit.* +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * Performs a minimization using the simplex method of Nelder and Mead (ref. + * Comp. J. 7, 308 (1965)). + * + * @version $Id$ + */ +internal class ScanBuilder : MinimumBuilder { + /** {@inheritDoc} */ + fun minimum( + mfcn: MnFcn, + gc: GradientCalculator?, + seed: MinimumSeed, + stra: MnStrategy?, + maxfcn: Int, + toler: Double + ): FunctionMinimum { + val x: RealVector = seed.parameters().vec().copy() + val upst = MnUserParameterState(seed.state(), mfcn.errorDef(), seed.trafo()) + val scan = MnParameterScan(mfcn.fcn(), upst.parameters(), seed.fval()) + var amin: Double = scan.fval() + val n: Int = seed.trafo().variableParameters() + val dirin: RealVector = ArrayRealVector(n) + for (i in 0 until n) { + val ext: Int = seed.trafo().extOfInt(i) + scan.scan(ext) + if (scan.fval() < amin) { + amin = scan.fval() + x.setEntry(i, seed.trafo().ext2int(ext, scan.parameters().value(ext))) + } + dirin.setEntry(i, sqrt(2.0 * mfcn.errorDef() * seed.error().invHessian()[i, i])) + } + val mp = MinimumParameters(x, dirin, amin) + val st = MinimumState(mp, 0.0, mfcn.numOfCalls()) + val states: MutableList = java.util.ArrayList(1) + states.add(st) + return FunctionMinimum(seed, states, mfcn.errorDef()) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/ScanMinimizer.kt b/kmath-optimization/src/commonMain/tmp/minuit/ScanMinimizer.kt new file mode 100644 index 000000000..e39a49c0d --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/ScanMinimizer.kt @@ -0,0 +1,36 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class ScanMinimizer : ModularFunctionMinimizer() { + private val theBuilder: ScanBuilder + private val theSeedGenerator: SimplexSeedGenerator = SimplexSeedGenerator() + override fun builder(): MinimumBuilder { + return theBuilder + } + + override fun seedGenerator(): MinimumSeedGenerator { + return theSeedGenerator + } + + init { + theBuilder = ScanBuilder() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/SimplexBuilder.kt b/kmath-optimization/src/commonMain/tmp/minuit/SimplexBuilder.kt new file mode 100644 index 000000000..0b10155ff --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/SimplexBuilder.kt @@ -0,0 +1,180 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MINUITPlugin +import ru.inr.mass.minuit.* +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * + * @version $Id$ + */ +internal class SimplexBuilder : MinimumBuilder { + /** {@inheritDoc} */ + fun minimum( + mfcn: MnFcn, + gc: GradientCalculator?, + seed: MinimumSeed, + strategy: MnStrategy?, + maxfcn: Int, + minedm: Double + ): FunctionMinimum { + val prec: MnMachinePrecision = seed.precision() + val x: RealVector = seed.parameters().vec().copy() + val step: RealVector = MnUtils.mul(seed.gradient().getStep(), 10.0) + val n: Int = x.getDimension() + val wg = 1.0 / n + val alpha = 1.0 + val beta = 0.5 + val gamma = 2.0 + val rhomin = 4.0 + val rhomax = 8.0 + val rho1 = 1.0 + alpha + val rho2 = 1.0 + alpha * gamma + val simpl: MutableList> = java.util.ArrayList>(n + 1) + simpl.add(Pair(seed.fval(), x.copy())) + var jl = 0 + var jh = 0 + var amin: Double = seed.fval() + var aming: Double = seed.fval() + for (i in 0 until n) { + val dmin: Double = 8.0 * prec.eps2() * (abs(x.getEntry(i)) + prec.eps2()) + if (step.getEntry(i) < dmin) { + step.setEntry(i, dmin) + } + x.setEntry(i, x.getEntry(i) + step.getEntry(i)) + val tmp: Double = mfcn.value(x) + if (tmp < amin) { + amin = tmp + jl = i + 1 + } + if (tmp > aming) { + aming = tmp + jh = i + 1 + } + simpl.add(Pair(tmp, x.copy())) + x.setEntry(i, x.getEntry(i) - step.getEntry(i)) + } + val simplex = SimplexParameters(simpl, jh, jl) + do { + amin = simplex[jl].getFirst() + jl = simplex.jl() + jh = simplex.jh() + var pbar: RealVector = ArrayRealVector(n) + for (i in 0 until n + 1) { + if (i == jh) { + continue + } + pbar = MnUtils.add(pbar, MnUtils.mul(simplex[i].getSecond(), wg)) + } + val pstar: RealVector = + MnUtils.sub(MnUtils.mul(pbar, 1.0 + alpha), MnUtils.mul(simplex[jh].getSecond(), alpha)) + val ystar: Double = mfcn.value(pstar) + if (ystar > amin) { + if (ystar < simplex[jh].getFirst()) { + simplex.update(ystar, pstar) + if (jh != simplex.jh()) { + continue + } + } + val pstst: RealVector = + MnUtils.add(MnUtils.mul(simplex[jh].getSecond(), beta), MnUtils.mul(pbar, 1.0 - beta)) + val ystst: Double = mfcn.value(pstst) + if (ystst > simplex[jh].getFirst()) { + break + } + simplex.update(ystst, pstst) + continue + } + var pstst: RealVector = MnUtils.add(MnUtils.mul(pstar, gamma), MnUtils.mul(pbar, 1.0 - gamma)) + var ystst: Double = mfcn.value(pstst) + val y1: Double = (ystar - simplex[jh].getFirst()) * rho2 + val y2: Double = (ystst - simplex[jh].getFirst()) * rho1 + var rho = 0.5 * (rho2 * y1 - rho1 * y2) / (y1 - y2) + if (rho < rhomin) { + if (ystst < simplex[jl].getFirst()) { + simplex.update(ystst, pstst) + } else { + simplex.update(ystar, pstar) + } + continue + } + if (rho > rhomax) { + rho = rhomax + } + val prho: RealVector = + MnUtils.add(MnUtils.mul(pbar, rho), MnUtils.mul(simplex[jh].getSecond(), 1.0 - rho)) + val yrho: Double = mfcn.value(prho) + if (yrho < simplex[jl].getFirst() && yrho < ystst) { + simplex.update(yrho, prho) + continue + } + if (ystst < simplex[jl].getFirst()) { + simplex.update(ystst, pstst) + continue + } + if (yrho > simplex[jl].getFirst()) { + if (ystst < simplex[jl].getFirst()) { + simplex.update(ystst, pstst) + } else { + simplex.update(ystar, pstar) + } + continue + } + if (ystar > simplex[jh].getFirst()) { + pstst = MnUtils.add(MnUtils.mul(simplex[jh].getSecond(), beta), MnUtils.mul(pbar, 1 - beta)) + ystst = mfcn.value(pstst) + if (ystst > simplex[jh].getFirst()) { + break + } + simplex.update(ystst, pstst) + } + } while (simplex.edm() > minedm && mfcn.numOfCalls() < maxfcn) + amin = simplex[jl].getFirst() + jl = simplex.jl() + jh = simplex.jh() + var pbar: RealVector = ArrayRealVector(n) + for (i in 0 until n + 1) { + if (i == jh) { + continue + } + pbar = MnUtils.add(pbar, MnUtils.mul(simplex[i].getSecond(), wg)) + } + var ybar: Double = mfcn.value(pbar) + if (ybar < amin) { + simplex.update(ybar, pbar) + } else { + pbar = simplex[jl].getSecond() + ybar = simplex[jl].getFirst() + } + var dirin: RealVector = simplex.dirin() + // scale to sigmas on parameters werr^2 = dirin^2 * (up/edm) + dirin = MnUtils.mul(dirin, sqrt(mfcn.errorDef() / simplex.edm())) + val st = MinimumState(MinimumParameters(pbar, dirin, ybar), simplex.edm(), mfcn.numOfCalls()) + val states: MutableList = java.util.ArrayList(1) + states.add(st) + if (mfcn.numOfCalls() > maxfcn) { + MINUITPlugin.logStatic("Simplex did not converge, #fcn calls exhausted.") + return FunctionMinimum(seed, states, mfcn.errorDef(), MnReachedCallLimit()) + } + if (simplex.edm() > minedm) { + MINUITPlugin.logStatic("Simplex did not converge, edm > minedm.") + return FunctionMinimum(seed, states, mfcn.errorDef(), MnAboveMaxEdm()) + } + return FunctionMinimum(seed, states, mfcn.errorDef()) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/SimplexMinimizer.kt b/kmath-optimization/src/commonMain/tmp/minuit/SimplexMinimizer.kt new file mode 100644 index 000000000..f4bbcc320 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/SimplexMinimizer.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class SimplexMinimizer : ModularFunctionMinimizer() { + private val theBuilder: SimplexBuilder + private val theSeedGenerator: SimplexSeedGenerator = SimplexSeedGenerator() + + /** {@inheritDoc} */ + override fun builder(): MinimumBuilder { + return theBuilder + } + + /** {@inheritDoc} */ + override fun seedGenerator(): MinimumSeedGenerator { + return theSeedGenerator + } + + /** + * + * Constructor for SimplexMinimizer. + */ + init { + theBuilder = SimplexBuilder() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/SimplexParameters.kt b/kmath-optimization/src/commonMain/tmp/minuit/SimplexParameters.kt new file mode 100644 index 000000000..fef6e2010 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/SimplexParameters.kt @@ -0,0 +1,85 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector + +/** + * + * @version $Id$ + */ +internal class SimplexParameters(simpl: MutableList>, jh: Int, jl: Int) { + private var theJHigh: Int + private var theJLow: Int + private val theSimplexParameters: MutableList> + fun dirin(): ArrayRealVector { + val dirin = ArrayRealVector(theSimplexParameters.size - 1) + for (i in 0 until theSimplexParameters.size - 1) { + var pbig: Double = theSimplexParameters[0].getSecond().getEntry(i) + var plit = pbig + for (theSimplexParameter in theSimplexParameters) { + if (theSimplexParameter.getSecond().getEntry(i) < plit) { + plit = theSimplexParameter.getSecond().getEntry(i) + } + if (theSimplexParameter.getSecond().getEntry(i) > pbig) { + pbig = theSimplexParameter.getSecond().getEntry(i) + } + } + dirin.setEntry(i, pbig - plit) + } + return dirin + } + + fun edm(): Double { + return theSimplexParameters[jh()].getFirst() - theSimplexParameters[jl()].getFirst() + } + + operator fun get(i: Int): Pair { + return theSimplexParameters[i] + } + + fun jh(): Int { + return theJHigh + } + + fun jl(): Int { + return theJLow + } + + fun simplex(): List> { + return theSimplexParameters + } + + fun update(y: Double, p: RealVector?) { + theSimplexParameters.set(jh(), Pair(y, p)) + if (y < theSimplexParameters[jl()].getFirst()) { + theJLow = jh() + } + var jh = 0 + for (i in 1 until theSimplexParameters.size) { + if (theSimplexParameters[i].getFirst() > theSimplexParameters[jh].getFirst()) { + jh = i + } + } + theJHigh = jh + } + + init { + theSimplexParameters = simpl + theJHigh = jh + theJLow = jl + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/SimplexSeedGenerator.kt b/kmath-optimization/src/commonMain/tmp/minuit/SimplexSeedGenerator.kt new file mode 100644 index 000000000..577545fc3 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/SimplexSeedGenerator.kt @@ -0,0 +1,53 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import org.apache.commons.math3.linear.ArrayRealVector +import ru.inr.mass.minuit.* +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * + * @version $Id$ + */ +internal class SimplexSeedGenerator : MinimumSeedGenerator { + /** {@inheritDoc} */ + fun generate(fcn: MnFcn, gc: GradientCalculator?, st: MnUserParameterState, stra: MnStrategy): MinimumSeed { + val n: Int = st.variableParameters() + val prec: MnMachinePrecision = st.precision() + + // initial starting values + val x: RealVector = ArrayRealVector(n) + for (i in 0 until n) { + x.setEntry(i, st.intParameters()[i]) + } + val fcnmin: Double = fcn.value(x) + val pa = MinimumParameters(x, fcnmin) + val igc = InitialGradientCalculator(fcn, st.getTransformation(), stra) + val dgrad: FunctionGradient = igc.gradient(pa) + val mat = MnAlgebraicSymMatrix(n) + val dcovar = 1.0 + for (i in 0 until n) { + mat[i, i] = if (abs(dgrad.getGradientDerivative() + .getEntry(i)) > prec.eps2() + ) 1.0 / dgrad.getGradientDerivative().getEntry(i) else 1.0 + } + val err = MinimumError(mat, dcovar) + val edm: Double = VariableMetricEDMEstimator().estimate(dgrad, err) + val state = MinimumState(pa, err, dgrad, edm, fcn.numOfCalls()) + return MinimumSeed(state, st.getTransformation()) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/SinParameterTransformation.kt b/kmath-optimization/src/commonMain/tmp/minuit/SinParameterTransformation.kt new file mode 100644 index 000000000..821addef7 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/SinParameterTransformation.kt @@ -0,0 +1,48 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class SinParameterTransformation { + fun dInt2Ext(value: Double, upper: Double, lower: Double): Double { + return 0.5 * abs((upper - lower) * cos(value)) + } + + fun ext2int(value: Double, upper: Double, lower: Double, prec: MnMachinePrecision): Double { + val piby2: Double = 2.0 * atan(1.0) + val distnn: Double = 8.0 * sqrt(prec.eps2()) + val vlimhi = piby2 - distnn + val vlimlo = -piby2 + distnn + val yy = 2.0 * (value - lower) / (upper - lower) - 1.0 + val yy2 = yy * yy + return if (yy2 > 1.0 - prec.eps2()) { + if (yy < 0.0) { + vlimlo + } else { + vlimhi + } + } else { + asin(yy) + } + } + + fun int2ext(value: Double, upper: Double, lower: Double): Double { + return lower + 0.5 * (upper - lower) * (sin(value) + 1.0) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/SqrtLowParameterTransformation.kt b/kmath-optimization/src/commonMain/tmp/minuit/SqrtLowParameterTransformation.kt new file mode 100644 index 000000000..444b63847 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/SqrtLowParameterTransformation.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class SqrtLowParameterTransformation { + // derivative of transformation from internal to external + fun dInt2Ext(value: Double, lower: Double): Double { + return value / sqrt(value * value + 1.0) + } + + // transformation from external to internal + fun ext2int(value: Double, lower: Double, prec: MnMachinePrecision): Double { + val yy = value - lower + 1.0 + val yy2 = yy * yy + return if (yy2 < 1.0 + prec.eps2()) { + 8 * sqrt(prec.eps2()) + } else { + sqrt(yy2 - 1) + } + } + + // transformation from internal to external + fun int2ext(value: Double, lower: Double): Double { + return lower - 1.0 + sqrt(value * value + 1.0) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/SqrtUpParameterTransformation.kt b/kmath-optimization/src/commonMain/tmp/minuit/SqrtUpParameterTransformation.kt new file mode 100644 index 000000000..5774848bd --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/SqrtUpParameterTransformation.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class SqrtUpParameterTransformation { + // derivative of transformation from internal to external + fun dInt2Ext(value: Double, upper: Double): Double { + return -value / sqrt(value * value + 1.0) + } + + // transformation from external to internal + fun ext2int(value: Double, upper: Double, prec: MnMachinePrecision): Double { + val yy = upper - value + 1.0 + val yy2 = yy * yy + return if (yy2 < 1.0 + prec.eps2()) { + 8 * sqrt(prec.eps2()) + } else { + sqrt(yy2 - 1) + } + } + + // transformation from internal to external + fun int2ext(value: Double, upper: Double): Double { + return upper + 1.0 - sqrt(value * value + 1.0) + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricBuilder.kt b/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricBuilder.kt new file mode 100644 index 000000000..edc6783b6 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricBuilder.kt @@ -0,0 +1,138 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +import space.kscience.kmath.optimization.minuit.MINUITPlugin +import ru.inr.mass.minuit.* +import space.kscience.kmath.optimization.minuit.MinimumSeed + +/** + * + * @version $Id$ + */ +internal class VariableMetricBuilder : MinimumBuilder { + private val theErrorUpdator: DavidonErrorUpdator + private val theEstimator: VariableMetricEDMEstimator = VariableMetricEDMEstimator() + fun errorUpdator(): DavidonErrorUpdator { + return theErrorUpdator + } + + fun estimator(): VariableMetricEDMEstimator { + return theEstimator + } + + /** {@inheritDoc} */ + fun minimum( + fcn: MnFcn, + gc: GradientCalculator, + seed: MinimumSeed, + strategy: MnStrategy, + maxfcn: Int, + edmval: Double + ): FunctionMinimum { + val min: FunctionMinimum = minimum(fcn, gc, seed, maxfcn, edmval) + if (strategy.strategy() === 2 || strategy.strategy() === 1 && min.error().dcovar() > 0.05) { + val st: MinimumState = MnHesse(strategy).calculate(fcn, min.state(), min.seed().trafo(), 0) + min.add(st) + } + if (!min.isValid()) { + MINUITPlugin.logStatic("FunctionMinimum is invalid.") + } + return min + } + + fun minimum(fcn: MnFcn, gc: GradientCalculator, seed: MinimumSeed, maxfcn: Int, edmval: Double): FunctionMinimum { + var edmval = edmval + edmval *= 0.0001 + if (seed.parameters().vec().getDimension() === 0) { + return FunctionMinimum(seed, fcn.errorDef()) + } + val prec: MnMachinePrecision = seed.precision() + val result: MutableList = java.util.ArrayList(8) + var edm: Double = seed.state().edm() + if (edm < 0.0) { + MINUITPlugin.logStatic("VariableMetricBuilder: initial matrix not pos.def.") + if (seed.error().isPosDef()) { + throw RuntimeException("Something is wrong!") + } + return FunctionMinimum(seed, fcn.errorDef()) + } + result.add(seed.state()) + + // iterate until edm is small enough or max # of iterations reached + edm *= 1.0 + 3.0 * seed.error().dcovar() + var step: RealVector // = new ArrayRealVector(seed.gradient().getGradient().getDimension()); + do { + var s0: MinimumState = result[result.size - 1] + step = MnUtils.mul(MnUtils.mul(s0.error().invHessian(), s0.gradient().getGradient()), -1) + var gdel: Double = MnUtils.innerProduct(step, s0.gradient().getGradient()) + if (gdel > 0.0) { + MINUITPlugin.logStatic("VariableMetricBuilder: matrix not pos.def.") + MINUITPlugin.logStatic("gdel > 0: $gdel") + s0 = MnPosDef.test(s0, prec) + step = MnUtils.mul(MnUtils.mul(s0.error().invHessian(), s0.gradient().getGradient()), -1) + gdel = MnUtils.innerProduct(step, s0.gradient().getGradient()) + MINUITPlugin.logStatic("gdel: $gdel") + if (gdel > 0.0) { + result.add(s0) + return FunctionMinimum(seed, result, fcn.errorDef()) + } + } + val pp: MnParabolaPoint = MnLineSearch.search(fcn, s0.parameters(), step, gdel, prec) + if (abs(pp.y() - s0.fval()) < prec.eps()) { + MINUITPlugin.logStatic("VariableMetricBuilder: no improvement") + break //no improvement + } + val p = MinimumParameters(MnUtils.add(s0.vec(), MnUtils.mul(step, pp.x())), pp.y()) + val g: FunctionGradient = gc.gradient(p, s0.gradient()) + edm = estimator().estimate(g, s0.error()) + if (edm < 0.0) { + MINUITPlugin.logStatic("VariableMetricBuilder: matrix not pos.def.") + MINUITPlugin.logStatic("edm < 0") + s0 = MnPosDef.test(s0, prec) + edm = estimator().estimate(g, s0.error()) + if (edm < 0.0) { + result.add(s0) + return FunctionMinimum(seed, result, fcn.errorDef()) + } + } + val e: MinimumError = errorUpdator().update(s0, p, g) + result.add(MinimumState(p, e, g, edm, fcn.numOfCalls())) + // result[0] = MinimumState(p, e, g, edm, fcn.numOfCalls()); + edm *= 1.0 + 3.0 * e.dcovar() + } while (edm > edmval && fcn.numOfCalls() < maxfcn) + if (fcn.numOfCalls() >= maxfcn) { + MINUITPlugin.logStatic("VariableMetricBuilder: call limit exceeded.") + return FunctionMinimum(seed, result, fcn.errorDef(), MnReachedCallLimit()) + } + return if (edm > edmval) { + if (edm < abs(prec.eps2() * result[result.size - 1].fval())) { + MINUITPlugin.logStatic("VariableMetricBuilder: machine accuracy limits further improvement.") + FunctionMinimum(seed, result, fcn.errorDef()) + } else if (edm < 10.0 * edmval) { + FunctionMinimum(seed, result, fcn.errorDef()) + } else { + MINUITPlugin.logStatic("VariableMetricBuilder: finishes without convergence.") + MINUITPlugin.logStatic("VariableMetricBuilder: edm= $edm requested: $edmval") + FunctionMinimum(seed, result, fcn.errorDef(), MnAboveMaxEdm()) + } + } else FunctionMinimum(seed, result, fcn.errorDef()) + } + + init { + theErrorUpdator = DavidonErrorUpdator() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricEDMEstimator.kt b/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricEDMEstimator.kt new file mode 100644 index 000000000..8fca4e6ee --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricEDMEstimator.kt @@ -0,0 +1,31 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @author tonyj + * @version $Id$ + */ +internal class VariableMetricEDMEstimator { + fun estimate(g: FunctionGradient, e: MinimumError): Double { + if (e.invHessian().size() === 1) { + return 0.5 * g.getGradient().getEntry(0) * g.getGradient().getEntry(0) * e.invHessian()[0, 0] + } + val rho: Double = MnUtils.similarity(g.getGradient(), e.invHessian()) + return 0.5 * rho + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricMinimizer.kt b/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricMinimizer.kt new file mode 100644 index 000000000..2a13a5fff --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/VariableMetricMinimizer.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + +/** + * + * @version $Id$ + */ +internal class VariableMetricMinimizer : ModularFunctionMinimizer() { + private val theMinBuilder: VariableMetricBuilder + private val theMinSeedGen: MnSeedGenerator = MnSeedGenerator() + + /** {@inheritDoc} */ + override fun builder(): MinimumBuilder { + return theMinBuilder + } + + /** {@inheritDoc} */ + override fun seedGenerator(): MinimumSeedGenerator { + return theMinSeedGen + } + + /** + * + * Constructor for VariableMetricMinimizer. + */ + init { + theMinBuilder = VariableMetricBuilder() + } +} \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/tmp/minuit/package-info.kt b/kmath-optimization/src/commonMain/tmp/minuit/package-info.kt new file mode 100644 index 000000000..22779da86 --- /dev/null +++ b/kmath-optimization/src/commonMain/tmp/minuit/package-info.kt @@ -0,0 +1,17 @@ +/* + * Copyright 2015 Alexander Nozik. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ru.inr.mass.minuit + diff --git a/kmath-stat/build.gradle.kts b/kmath-stat/build.gradle.kts index e8f629f7a..41a1666f8 100644 --- a/kmath-stat/build.gradle.kts +++ b/kmath-stat/build.gradle.kts @@ -1,6 +1,6 @@ plugins { - kotlin("multiplatform") - id("ru.mipt.npm.gradle.common") + id("ru.mipt.npm.gradle.mpp") + id("ru.mipt.npm.gradle.native") } kscience { diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt index e3adcdc44..298bbc858 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.distributions @@ -19,10 +19,10 @@ public interface Distribution : Sampler { */ public fun probability(arg: T): Double - public override fun sample(generator: RandomGenerator): Chain + override fun sample(generator: RandomGenerator): Chain /** - * An empty companion. Distribution factories should be written as its extensions + * An empty companion. Distribution factories should be written as its extensions. */ public companion object } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt index dde429244..067b47796 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.distributions @@ -10,12 +10,12 @@ import space.kscience.kmath.chains.SimpleChain import space.kscience.kmath.stat.RandomGenerator /** - * A multivariate distribution which takes a map of parameters + * A multivariate distribution that takes a map of parameters. */ public interface NamedDistribution : Distribution> /** - * A multivariate distribution that has independent distributions for separate axis + * A multivariate distribution that has independent distributions for separate axis. */ public class FactorizedDistribution(public val distributions: Collection>) : NamedDistribution { diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt index 04ec8b171..24429cf32 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.distributions @@ -23,14 +23,14 @@ public class NormalDistribution(public val sampler: GaussianSampler) : Univariat normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler, ) : this(GaussianSampler(mean, standardDeviation, normalized)) - public override fun probability(arg: Double): Double { + override fun probability(arg: Double): Double { val x1 = (arg - sampler.mean) / sampler.standardDeviation return exp(-0.5 * x1 * x1 - (ln(sampler.standardDeviation) + 0.5 * ln(2 * PI))) } - public override fun sample(generator: RandomGenerator): Chain = sampler.sample(generator) + override fun sample(generator: RandomGenerator): Chain = sampler.sample(generator) - public override fun cumulative(arg: Double): Double { + override fun cumulative(arg: Double): Double { val dev = arg - sampler.mean return when { diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalErf.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalErf.kt index 25668446c..5b3cb1859 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalErf.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalErf.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt index a584af4f9..18abd669f 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal @@ -110,7 +110,7 @@ internal object InternalGamma { x <= 8.0 -> { val n = floor(x - 1.5).toInt() - val prod = (1..n).fold(1.0, { prod, i -> prod * (x - i) }) + val prod = (1..n).fold(1.0) { prod, i -> prod * (x - i) } logGamma1p(x - (n + 1)) + ln(prod) } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt index 3997a77b3..77ba02a25 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.internal diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt deleted file mode 100644 index f54ba5723..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.optimization - -import space.kscience.kmath.expressions.* -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.indices - -/** - * A likelihood function optimization problem with provided derivatives - */ -public interface FunctionOptimization : Optimization { - /** - * The optimization direction. If true search for function maximum, if false, search for the minimum - */ - public var maximize: Boolean - - /** - * Define the initial guess for the optimization problem - */ - public fun initialGuess(map: Map) - - /** - * Set a differentiable expression as objective function as function and gradient provider - */ - public fun diffFunction(expression: DifferentiableExpression) - - public companion object { - /** - * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation - */ - public fun chiSquared( - autoDiff: AutoDiffProcessor>, - x: Buffer, - y: Buffer, - yErr: Buffer, - model: A.(I) -> I, - ): DifferentiableExpression where A : ExtendedField, A : ExpressionAlgebra { - require(x.size == y.size) { "X and y buffers should be of the same size" } - require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } - - return autoDiff.process { - var sum = zero - - x.indices.forEach { - val xValue = const(x[it]) - val yValue = const(y[it]) - val yErrValue = const(yErr[it]) - val modelValue = model(xValue) - sum += ((yValue - modelValue) / yErrValue).pow(2) - } - - sum - } - } - } -} - -/** - * Define a chi-squared-based objective function - */ -public fun FunctionOptimization.chiSquared( - autoDiff: AutoDiffProcessor>, - x: Buffer, - y: Buffer, - yErr: Buffer, - model: A.(I) -> I, -) where A : ExtendedField, A : ExpressionAlgebra { - val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model) - diffFunction(chiSquared) - maximize = false -} - -/** - * Optimize differentiable expression using specific [OptimizationProblemFactory] - */ -public fun > DifferentiableExpression.optimizeWith( - factory: OptimizationProblemFactory, - vararg symbols: Symbol, - configuration: F.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = factory(symbols.toList(), configuration) - problem.diffFunction(this) - return problem.optimize() -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/NoDerivFunctionOptimization.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/NoDerivFunctionOptimization.kt deleted file mode 100644 index 0f2167549..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/NoDerivFunctionOptimization.kt +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.optimization - -import space.kscience.kmath.expressions.Expression -import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.indices -import kotlin.math.pow - -/** - * A likelihood function optimization problem - */ -public interface NoDerivFunctionOptimization : Optimization { - /** - * The optimization direction. If true search for function maximum, if false, search for the minimum - */ - public var maximize: Boolean - - /** - * Define the initial guess for the optimization problem - */ - public fun initialGuess(map: Map) - - /** - * Set an objective function expression - */ - public fun function(expression: Expression) - - public companion object { - /** - * Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives - */ - public fun chiSquared( - x: Buffer, - y: Buffer, - yErr: Buffer, - model: Expression, - xSymbol: Symbol = Symbol.x, - ): Expression { - require(x.size == y.size) { "X and y buffers should be of the same size" } - require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } - - return Expression { arguments -> - x.indices.sumOf { - val xValue = x[it] - val yValue = y[it] - val yErrValue = yErr[it] - val modifiedArgs = arguments + (xSymbol to xValue) - val modelValue = model(modifiedArgs) - ((yValue - modelValue) / yErrValue).pow(2) - } - } - } - } -} - - -/** - * Optimize expression without derivatives using specific [OptimizationProblemFactory] - */ -public fun > Expression.noDerivOptimizeWith( - factory: OptimizationProblemFactory, - vararg symbols: Symbol, - configuration: F.() -> Unit, -): OptimizationResult { - require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = factory(symbols.toList(), configuration) - problem.function(this) - return problem.optimize() -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/Optimization.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/Optimization.kt deleted file mode 100644 index 4a1676412..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/Optimization.kt +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.optimization - -import space.kscience.kmath.expressions.Symbol - -public interface OptimizationFeature - -public class OptimizationResult( - public val point: Map, - public val value: T, - public val features: Set = emptySet(), -) { - override fun toString(): String { - return "OptimizationResult(point=$point, value=$value)" - } -} - -public operator fun OptimizationResult.plus( - feature: OptimizationFeature, -): OptimizationResult = OptimizationResult(point, value, features + feature) - -/** - * An optimization problem builder over [T] variables - */ -public interface Optimization { - - /** - * Update the problem from previous optimization run - */ - public fun update(result: OptimizationResult) - - /** - * Make an optimization run - */ - public fun optimize(): OptimizationResult -} - -public fun interface OptimizationProblemFactory> { - public fun build(symbols: List): P -} - -public operator fun > OptimizationProblemFactory.invoke( - symbols: List, - block: P.() -> Unit, -): P = build(symbols).apply(block) diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt deleted file mode 100644 index 70d7fdf79..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.optimization - -import space.kscience.kmath.data.ColumnarData -import space.kscience.kmath.expressions.* -import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.Field - -@UnstableKMathAPI -public interface XYFit : Optimization { - - public val algebra: Field - - /** - * Set X-Y data for this fit optionally including x and y errors - */ - public fun data( - dataSet: ColumnarData, - xSymbol: Symbol, - ySymbol: Symbol, - xErrSymbol: Symbol? = null, - yErrSymbol: Symbol? = null, - ) - - public fun model(model: (T) -> DifferentiableExpression) - - /** - * Set the differentiable model for this fit - */ - public fun model( - autoDiff: AutoDiffProcessor>, - modelFunction: A.(I) -> I, - ): Unit where A : ExtendedField, A : ExpressionAlgebra = model { arg -> - autoDiff.process { modelFunction(const(arg)) } - } -} \ No newline at end of file diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt index a231842df..5f923fe5f 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers @@ -24,7 +24,7 @@ public class AhrensDieterExponentialSampler(public val mean: Double) : Sampler 0) { "mean is not strictly positive: $mean" } } - public override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { + override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { override fun nextBlocking(): Double { // Step 1: var a = 0.0 diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt index 2f32eee85..063e055ce 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers @@ -14,9 +14,9 @@ import kotlin.math.* /** * Sampling from the [gamma distribution](http://mathworld.wolfram.com/GammaDistribution.html). - * - For 0 < alpha < 1: + * * For 0 < alpha < 1: * Ahrens, J. H. and Dieter, U., Computer methods for sampling from gamma, beta, Poisson and binomial distributions, Computing, 12, 223-246, 1974. - * - For alpha >= 1: + * * For alpha >= 1: * Marsaglia and Tsang, A Simple Method for Generating Gamma Variables. ACM Transactions on Mathematical Software, Volume 26 Issue 3, September, 2000. * * Based on Commons RNG implementation. @@ -113,8 +113,8 @@ public class AhrensDieterMarsagliaTsangGammaSampler private constructor( } } - public override fun sample(generator: RandomGenerator): Chain = delegate.sample(generator) - public override fun toString(): String = delegate.toString() + override fun sample(generator: RandomGenerator): Chain = delegate.sample(generator) + override fun toString(): String = delegate.toString() public companion object { public fun of( diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AliasMethodDiscreteSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AliasMethodDiscreteSampler.kt index db4f598b7..b00db5b30 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AliasMethodDiscreteSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AliasMethodDiscreteSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers @@ -20,7 +20,7 @@ import kotlin.math.min * implements Vose's algorithm. * * Vose, M.D., A linear algorithm for generating random numbers with a given distribution, IEEE Transactions on - * Software Engineering, 17, 972-975, 1991. he algorithm will sample values in O(1) time after a pre-processing step + * Software Engineering, 17, 972-975, 1991. The algorithm will sample values in O(1) time after a pre-processing step * of O(n) time. * * The alias tables are constructed using fraction probabilities with an assumed denominator of 253. In the generic @@ -76,8 +76,8 @@ public open class AliasMethodDiscreteSampler private constructor( } } - public override fun sample(generator: RandomGenerator): Chain = generator.chain { - // This implements the algorithm as per Vose (1991): + override fun sample(generator: RandomGenerator): Chain = generator.chain { + // This implements the algorithm in accordance with Vose (1991): // v = uniform() in [0, 1) // j = uniform(n) in [0, n) // if v < prob[j] then @@ -95,7 +95,7 @@ public open class AliasMethodDiscreteSampler private constructor( // p(j) == 1 => j // However it is assumed these edge cases are rare: // - // The probability table will be 1 for approximately 1/n samples, i.e. only the + // The probability table will be 1 for approximately 1/n samples i.e., only the // last unpaired probability. This is only worth checking for when the table size (n) // is small. But in that case the user should zero-pad the table for performance. // @@ -107,7 +107,7 @@ public open class AliasMethodDiscreteSampler private constructor( if (generator.nextLong() ushr 11 < probability[j]) j else alias[j] } - public override fun toString(): String = "Alias method" + override fun toString(): String = "Alias method" public companion object { private const val DEFAULT_ALPHA = 0 @@ -211,7 +211,7 @@ public open class AliasMethodDiscreteSampler private constructor( // c: 2=2/3; 6=1/3 (6 is the alias) // d: 1=1/3; 6=2/3 (6 is the alias) // - // The sample is obtained by randomly selecting a section, then choosing which category + // The sample is obtained by randomly selecting a section, then choosing, which category // from the pair based on a uniform random deviate. val sumProb = InternalUtils.validateProbabilities(probabilities) // Allow zero-padding @@ -241,9 +241,9 @@ public open class AliasMethodDiscreteSampler private constructor( val alias = IntArray(n) // This loop uses each large in turn to fill the alias table for small probabilities that - // do not reach the requirement to fill an entire section alone (i.e. p < mean). + // do not reach the requirement to fill an entire section alone (i.e., p < mean). // Since the sum of the small should be less than the sum of the large it should use up - // all the small first. However floating point round-off can result in + // all the small first. However, floating point round-off can result in // misclassification of items as small or large. The Vose algorithm handles this using // a while loop conditioned on the size of both sets and a subsequent loop to use // unpaired items. diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt index b3c014553..14aa26275 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ConstantSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ConstantSampler.kt deleted file mode 100644 index 0d38fe19b..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ConstantSampler.kt +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.samplers - -import space.kscience.kmath.chains.BlockingBufferChain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.structures.Buffer - -public class ConstantSampler(public val const: T) : Sampler { - override fun sample(generator: RandomGenerator): BlockingBufferChain = object : BlockingBufferChain { - override fun nextBufferBlocking(size: Int): Buffer = Buffer.boxing(size) { const } - override suspend fun fork(): BlockingBufferChain = this - } -} \ No newline at end of file diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt index d7d8e87b7..e5d1ecb49 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers @@ -28,7 +28,7 @@ public class GaussianSampler( require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" } } - public override fun sample(generator: RandomGenerator): BlockingDoubleChain = normalized + override fun sample(generator: RandomGenerator): BlockingDoubleChain = normalized .sample(generator) .map { standardDeviation * it + mean } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt index 9bb48fe4e..16f91570f 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers @@ -13,7 +13,7 @@ import kotlin.math.exp /** * Sampler for the Poisson distribution. - * - Kemp, A, W, (1981) Efficient Generation of Logarithmically Distributed Pseudo-Random Variables. Journal of the Royal Statistical Society. Vol. 30, No. 3, pp. 249-253. + * * Kemp, A, W, (1981) Efficient Generation of Logarithmically Distributed Pseudo-Random Variables. Journal of the Royal Statistical Society. Vol. 30, No. 3, pp. 249-253. * This sampler is suitable for mean < 40. For large means, LargeMeanPoissonSampler should be used instead. * * Note: The algorithm uses a recurrence relation to compute the Poisson probability and a rolling summation for the cumulative probability. When the mean is large the initial probability (Math.exp(-mean)) is zero and an exception is raised by the constructor. @@ -27,7 +27,7 @@ public class KempSmallMeanPoissonSampler internal constructor( private val p0: Double, private val mean: Double, ) : Sampler { - public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { + override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { override fun nextBlocking(): Int { //TODO move to nextBufferBlocking // Note on the algorithm: @@ -60,14 +60,13 @@ public class KempSmallMeanPoissonSampler internal constructor( override suspend fun fork(): BlockingIntChain = sample(generator.fork()) } - public override fun toString(): String = "Kemp Small Mean Poisson deviate" + override fun toString(): String = "Kemp Small Mean Poisson deviate" } public fun KempSmallMeanPoissonSampler(mean: Double): KempSmallMeanPoissonSampler { require(mean > 0) { "Mean is not strictly positive: $mean" } val p0 = exp(-mean) - // Probability must be positive. As mean increases then p(0) decreases. + // Probability must be positive. As mean increases, p(0) decreases. require(p0 > 0) { "No probability for mean: $mean" } return KempSmallMeanPoissonSampler(p0, mean) } - diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt index 0a68e5c88..5e636f246 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt index 83f87e832..ceb324e8d 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt index e95778b9e..d3ff05b06 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt @@ -1,12 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers import space.kscience.kmath.chains.BlockingIntChain import space.kscience.kmath.internal.InternalUtils +import space.kscience.kmath.misc.toIntExact import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.Sampler import space.kscience.kmath.structures.IntBuffer @@ -17,11 +18,11 @@ private const val PIVOT = 40.0 /** * Sampler for the Poisson distribution. - * - For small means, a Poisson process is simulated using uniform deviates, as described in + * * For small means, a Poisson process is simulated using uniform deviates, as described in * Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3 * Important integer-valued distributions: The Poisson distribution. Addison Wesley. * The Poisson process (and hence, the returned value) is bounded by 1000 * mean. - * - For large means, we use the rejection algorithm described in + * * For large means, we use the rejection algorithm described in * Devroye, Luc. (1981). The Computer Generation of Poisson Random Variables Computing vol. 26 pp. 197-207. * * Based on Commons RNG implementation. @@ -34,10 +35,10 @@ public fun PoissonSampler(mean: Double): Sampler { /** * Sampler for the Poisson distribution. - * - For small means, a Poisson process is simulated using uniform deviates, as described in + * * For small means, a Poisson process is simulated using uniform deviates, as described in * Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3 Important * integer-valued distributions: The Poisson distribution. Addison Wesley. - * - The Poisson process (and hence, the returned value) is bounded by 1000 * mean. + * * The Poisson process (and hence, the returned value) is bounded by 1000 * mean. * This sampler is suitable for mean < 40. For large means, [LargeMeanPoissonSampler] should be used instead. * * Based on Commons RNG implementation. @@ -58,7 +59,7 @@ public class SmallMeanPoissonSampler(public val mean: Double) : Sampler { throw IllegalArgumentException("No p(x=0) probability for mean: $mean") }.toInt() - public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { + override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { override fun nextBlocking(): Int { var n = 0 var r = 1.0 @@ -76,7 +77,7 @@ public class SmallMeanPoissonSampler(public val mean: Double) : Sampler { override suspend fun fork(): BlockingIntChain = sample(generator.fork()) } - public override fun toString(): String = "Small Mean Poisson deviate" + override fun toString(): String = "Small Mean Poisson deviate" } @@ -113,13 +114,13 @@ public class LargeMeanPoissonSampler(public val mean: Double) : Sampler { private val p1: Double = a1 / aSum private val p2: Double = a2 / aSum - public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { + override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { override fun nextBlocking(): Int { val exponential = AhrensDieterExponentialSampler(1.0).sample(generator) val gaussian = ZigguratNormalizedGaussianSampler.sample(generator) val smallMeanPoissonSampler = if (mean - lambda < Double.MIN_VALUE) { - null + null } else { KempSmallMeanPoissonSampler(mean - lambda).sample(generator) } @@ -188,7 +189,7 @@ public class LargeMeanPoissonSampler(public val mean: Double) : Sampler { } } - return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt() + return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toIntExact() } override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() } @@ -197,7 +198,7 @@ public class LargeMeanPoissonSampler(public val mean: Double) : Sampler { } private fun getFactorialLog(n: Int): Double = factorialLog.value(n) - public override fun toString(): String = "Large Mean Poisson deviate" + override fun toString(): String = "Large Mean Poisson deviate" public companion object { private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt index 24148271d..bda6f9819 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt @@ -1,11 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.samplers import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.misc.toIntExact import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.structures.DoubleBuffer import kotlin.math.* diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/MCScope.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/MCScope.kt index 9e5c70a26..5e1e577ba 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/MCScope.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/MCScope.kt @@ -1,11 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat import kotlinx.coroutines.* +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext import kotlin.coroutines.coroutineContext @@ -23,14 +25,18 @@ public class MCScope( /** * Launches a supervised Monte-Carlo scope */ -public suspend inline fun mcScope(generator: RandomGenerator, block: MCScope.() -> T): T = - MCScope(coroutineContext, generator).block() +public suspend inline fun mcScope(generator: RandomGenerator, block: MCScope.() -> T): T { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return MCScope(coroutineContext, generator).block() +} /** * Launch mc scope with a given seed */ -public suspend inline fun mcScope(seed: Long, block: MCScope.() -> T): T = - mcScope(RandomGenerator.default(seed), block) +public suspend inline fun mcScope(seed: Long, block: MCScope.() -> T): T { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + return mcScope(RandomGenerator.default(seed), block) +} /** * Specialized launch for [MCScope]. Behaves the same way as regular [CoroutineScope.launch], but also stores the generator fork. diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Mean.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Mean.kt index 9769146fb..2a9bd3cd4 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Mean.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Mean.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -27,24 +27,39 @@ public class Mean( override suspend fun evaluate(data: Buffer): T = super.evaluate(data) - public override suspend fun computeIntermediate(data: Buffer): Pair = - evaluateBlocking(data) to data.size + override suspend fun computeIntermediate(data: Buffer): Pair = group { + var res = zero + for (i in data.indices) { + res += data[i] + } + res to data.size + } - public override suspend fun composeIntermediate(first: Pair, second: Pair): Pair = + override suspend fun composeIntermediate(first: Pair, second: Pair): Pair = group { first.first + second.first } to (first.second + second.second) - public override suspend fun toResult(intermediate: Pair): T = group { + override suspend fun toResult(intermediate: Pair): T = group { division(intermediate.first, intermediate.second) } public companion object { - //TODO replace with optimized version which respects overflow + @Deprecated("Use Double.mean instead") public val double: Mean = Mean(DoubleField) { sum, count -> sum / count } + @Deprecated("Use Int.mean instead") public val int: Mean = Mean(IntRing) { sum, count -> sum / count } + @Deprecated("Use Long.mean instead") public val long: Mean = Mean(LongRing) { sum, count -> sum / count } - public fun evaluate(buffer: Buffer): Double = double.evaluateBlocking(buffer) - public fun evaluate(buffer: Buffer): Int = int.evaluateBlocking(buffer) - public fun evaluate(buffer: Buffer): Long = long.evaluateBlocking(buffer) + public fun evaluate(buffer: Buffer): Double = Double.mean.evaluateBlocking(buffer) + public fun evaluate(buffer: Buffer): Int = Int.mean.evaluateBlocking(buffer) + public fun evaluate(buffer: Buffer): Long = Long.mean.evaluateBlocking(buffer) } -} \ No newline at end of file +} + + +//TODO replace with optimized version which respects overflow +public val Double.Companion.mean: Mean get() = Mean(DoubleField) { sum, count -> sum / count } +public val Int.Companion.mean: Mean get() = Mean(IntRing) { sum, count -> sum / count } +public val Long.Companion.mean: Mean get() = Mean(LongRing) { sum, count -> sum / count } + + diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Median.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Median.kt index 70754eab7..664e4e8e7 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Median.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Median.kt @@ -1,18 +1,18 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat +import space.kscience.kmath.operations.asSequence import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.asSequence /** * Non-composable median */ public class Median(private val comparator: Comparator) : BlockingStatistic { - public override fun evaluateBlocking(data: Buffer): T = + override fun evaluateBlocking(data: Buffer): T = data.asSequence().sortedWith(comparator).toList()[data.size / 2] //TODO check if this is correct public companion object { diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt index 5041e7359..61e472334 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -31,7 +31,7 @@ public fun RandomGenerator.chain(generator: suspend RandomGenerator.() -> R) * A type-specific double chunk random chain */ public class UniformDoubleChain(public val generator: RandomGenerator) : BlockingDoubleChain { - public override fun nextBufferBlocking(size: Int): DoubleBuffer = generator.nextDoubleBuffer(size) + override fun nextBufferBlocking(size: Int): DoubleBuffer = generator.nextDoubleBuffer(size) override suspend fun nextBuffer(size: Int): DoubleBuffer = nextBufferBlocking(size) override suspend fun fork(): UniformDoubleChain = UniformDoubleChain(generator.fork()) diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt index 3ff12f383..98ee6402a 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -23,7 +23,7 @@ public interface RandomGenerator { public fun nextDouble(): Double /** - * A chunk of doubles of given [size] + * A chunk of doubles of given [size]. */ public fun nextDoubleBuffer(size: Int): DoubleBuffer = DoubleBuffer(size) { nextDouble() } @@ -57,7 +57,7 @@ public interface RandomGenerator { public fun nextLong(until: Long): Long /** - * Fills a subrange of the specified byte [array] starting from [fromIndex] inclusive and ending [toIndex] exclusive + * Fills a subrange with the specified byte [array] starting from [fromIndex] inclusive and ending [toIndex] exclusive * with random bytes. * * @return [array] with the subrange filled with random bytes. @@ -70,7 +70,7 @@ public interface RandomGenerator { public fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) } /** - * Create a new generator which is independent from current generator (operations on new generator do not affect this one + * Create a new generator that is independent of current generator (operations on new generator do not affect this one * and vise versa). The statistical properties of new generator should be the same as for this one. * For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run. * @@ -97,17 +97,17 @@ public interface RandomGenerator { * @property random the underlying [Random] object. */ public class DefaultGenerator(public val random: Random = Random) : RandomGenerator { - public override fun nextBoolean(): Boolean = random.nextBoolean() - public override fun nextDouble(): Double = random.nextDouble() - public override fun nextInt(): Int = random.nextInt() - public override fun nextInt(until: Int): Int = random.nextInt(until) - public override fun nextLong(): Long = random.nextLong() - public override fun nextLong(until: Long): Long = random.nextLong(until) + override fun nextBoolean(): Boolean = random.nextBoolean() + override fun nextDouble(): Double = random.nextDouble() + override fun nextInt(): Int = random.nextInt() + override fun nextInt(until: Int): Int = random.nextInt(until) + override fun nextLong(): Long = random.nextLong() + override fun nextLong(until: Long): Long = random.nextLong(until) - public override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { + override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { random.nextBytes(array, fromIndex, toIndex) } - public override fun nextBytes(size: Int): ByteArray = random.nextBytes(size) - public override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong()) + override fun nextBytes(size: Int): ByteArray = random.nextBytes(size) + override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong()) } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt index 51ae78d3d..0b3b52cab 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -12,7 +12,7 @@ import space.kscience.kmath.structures.* import kotlin.jvm.JvmName /** - * Sampler that generates chains of values of type [T] in a chain of type [C]. + * Sampler that generates chains of values of type [T]. */ public fun interface Sampler { /** diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt index b3d607ab0..e0be72d4b 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/SamplerAlgebra.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -18,8 +18,8 @@ import space.kscience.kmath.operations.invoke * * @property value the value to sample. */ -public class ConstantSampler(public val value: T) : Sampler { - public override fun sample(generator: RandomGenerator): Chain = ConstantChain(value) +public class ConstantSampler(public val value: T) : Sampler { + override fun sample(generator: RandomGenerator): Chain = ConstantChain(value) } /** @@ -27,29 +27,29 @@ public class ConstantSampler(public val value: T) : Sampler { * * @property chainBuilder the provider of [Chain]. */ -public class BasicSampler(public val chainBuilder: (RandomGenerator) -> Chain) : Sampler { - public override fun sample(generator: RandomGenerator): Chain = chainBuilder(generator) +public class BasicSampler(public val chainBuilder: (RandomGenerator) -> Chain) : Sampler { + override fun sample(generator: RandomGenerator): Chain = chainBuilder(generator) } /** - * A space of samplers. Allows to perform simple operations on distributions. + * A space of samplers. Allows performing simple operations on distributions. * * @property algebra the space to provide addition and scalar multiplication for [T]. */ -public class SamplerSpace(public val algebra: S) : Group>, +public class SamplerSpace(public val algebra: S) : Group>, ScaleOperations> where S : Group, S : ScaleOperations { - public override val zero: Sampler = ConstantSampler(algebra.zero) + override val zero: Sampler = ConstantSampler(algebra.zero) - public override fun add(a: Sampler, b: Sampler): Sampler = BasicSampler { generator -> - a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } } + override fun add(left: Sampler, right: Sampler): Sampler = BasicSampler { generator -> + left.sample(generator).zip(right.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } } } - public override fun scale(a: Sampler, value: Double): Sampler = BasicSampler { generator -> + override fun scale(a: Sampler, value: Double): Sampler = BasicSampler { generator -> a.sample(generator).map { a -> algebra { a * value } } } - public override fun Sampler.unaryMinus(): Sampler = scale(this, -1.0) + override fun Sampler.unaryMinus(): Sampler = scale(this, -1.0) } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt index 1b05aa9cd..107161514 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -8,7 +8,6 @@ package space.kscience.kmath.stat import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.runningReduce @@ -18,23 +17,31 @@ import space.kscience.kmath.structures.Buffer /** * A function, that transforms a buffer of random quantities to some resulting value */ -public interface Statistic { +public fun interface Statistic { public suspend fun evaluate(data: Buffer): R } -public interface BlockingStatistic: Statistic{ +public suspend operator fun Statistic.invoke(data: Buffer): R = evaluate(data) + +/** + * A statistic that is computed in a synchronous blocking mode + */ +public fun interface BlockingStatistic : Statistic { public fun evaluateBlocking(data: Buffer): R - override suspend fun evaluate(data: Buffer): R = evaluateBlocking(data) + override suspend fun evaluate(data: Buffer): R = evaluateBlocking(data) } +public operator fun BlockingStatistic.invoke(data: Buffer): R = evaluateBlocking(data) + /** * A statistic tha could be computed separately on different blocks of data and then composed - * @param T - source type - * @param I - intermediate block type - * @param R - result type + * + * @param T the source type. + * @param I the intermediate block type. + * @param R the result type. */ -public interface ComposableStatistic : Statistic { +public interface ComposableStatistic : Statistic { //compute statistic on a single block public suspend fun computeIntermediate(data: Buffer): I @@ -44,11 +51,13 @@ public interface ComposableStatistic : Statistic { //Transform block to result public suspend fun toResult(intermediate: I): R - public override suspend fun evaluate(data: Buffer): R = toResult(computeIntermediate(data)) + override suspend fun evaluate(data: Buffer): R = toResult(computeIntermediate(data)) } -@FlowPreview -@ExperimentalCoroutinesApi +/** + * Flow intermediate state of the [ComposableStatistic] + */ +@OptIn(ExperimentalCoroutinesApi::class) private fun ComposableStatistic.flowIntermediate( flow: Flow>, dispatcher: CoroutineDispatcher = Dispatchers.Default, @@ -63,7 +72,7 @@ private fun ComposableStatistic.flowIntermediate( * * The resulting flow contains values that include the whole previous statistics, not only the last chunk. */ -@OptIn(FlowPreview::class, ExperimentalCoroutinesApi::class) +@OptIn(ExperimentalCoroutinesApi::class) public fun ComposableStatistic.flow( flow: Flow>, dispatcher: CoroutineDispatcher = Dispatchers.Default, diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt index 970a3aab5..4c0d08720 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat diff --git a/kmath-stat/src/jvmMain/kotlin/space/kscience/kmath/stat/RandomSourceGenerator.kt b/kmath-stat/src/jvmMain/kotlin/space/kscience/kmath/stat/RandomSourceGenerator.kt index 1ff6481ac..a8e6a3362 100644 --- a/kmath-stat/src/jvmMain/kotlin/space/kscience/kmath/stat/RandomSourceGenerator.kt +++ b/kmath-stat/src/jvmMain/kotlin/space/kscience/kmath/stat/RandomSourceGenerator.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat @@ -17,19 +17,19 @@ public class RandomSourceGenerator internal constructor(public val source: Rando internal val random: UniformRandomProvider = seed?.let { RandomSource.create(source, seed) } ?: RandomSource.create(source) - public override fun nextBoolean(): Boolean = random.nextBoolean() - public override fun nextDouble(): Double = random.nextDouble() - public override fun nextInt(): Int = random.nextInt() - public override fun nextInt(until: Int): Int = random.nextInt(until) - public override fun nextLong(): Long = random.nextLong() - public override fun nextLong(until: Long): Long = random.nextLong(until) + override fun nextBoolean(): Boolean = random.nextBoolean() + override fun nextDouble(): Double = random.nextDouble() + override fun nextInt(): Int = random.nextInt() + override fun nextInt(until: Int): Int = random.nextInt(until) + override fun nextLong(): Long = random.nextLong() + override fun nextLong(until: Long): Long = random.nextLong(until) - public override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { + override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { require(toIndex > fromIndex) random.nextBytes(array, fromIndex, toIndex - fromIndex) } - public override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) + override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) } /** @@ -43,23 +43,23 @@ public class RandomGeneratorProvider(public val generator: RandomGenerator) : Un * * @return the next random value. */ - public override fun nextBoolean(): Boolean = generator.nextBoolean() + override fun nextBoolean(): Boolean = generator.nextBoolean() /** * Generates a [Float] value between 0 and 1. * * @return the next random value between 0 and 1. */ - public override fun nextFloat(): Float = generator.nextDouble().toFloat() + override fun nextFloat(): Float = generator.nextDouble().toFloat() /** * Generates [Byte] values and places them into a user-supplied array. * - * The number of random bytes produced is equal to the length of the the byte array. + * The number of random bytes produced is equal to the length of the byte array. * * @param bytes byte array in which to put the random bytes. */ - public override fun nextBytes(bytes: ByteArray): Unit = generator.fillBytes(bytes) + override fun nextBytes(bytes: ByteArray): Unit = generator.fillBytes(bytes) /** * Generates [Byte] values and places them into a user-supplied array. @@ -71,7 +71,7 @@ public class RandomGeneratorProvider(public val generator: RandomGenerator) : Un * @param start the index at which to start inserting the generated bytes. * @param len the number of bytes to insert. */ - public override fun nextBytes(bytes: ByteArray, start: Int, len: Int) { + override fun nextBytes(bytes: ByteArray, start: Int, len: Int) { generator.fillBytes(bytes, start, start + len) } @@ -80,7 +80,7 @@ public class RandomGeneratorProvider(public val generator: RandomGenerator) : Un * * @return the next random value. */ - public override fun nextInt(): Int = generator.nextInt() + override fun nextInt(): Int = generator.nextInt() /** * Generates an [Int] value between 0 (inclusive) and the specified value (exclusive). @@ -88,21 +88,21 @@ public class RandomGeneratorProvider(public val generator: RandomGenerator) : Un * @param n the bound on the random number to be returned. Must be positive. * @return a random integer between 0 (inclusive) and [n] (exclusive). */ - public override fun nextInt(n: Int): Int = generator.nextInt(n) + override fun nextInt(n: Int): Int = generator.nextInt(n) /** * Generates a [Double] value between 0 and 1. * * @return the next random value between 0 and 1. */ - public override fun nextDouble(): Double = generator.nextDouble() + override fun nextDouble(): Double = generator.nextDouble() /** * Generates a [Long] value. * * @return the next random value. */ - public override fun nextLong(): Long = generator.nextLong() + override fun nextLong(): Long = generator.nextLong() /** * Generates a [Long] value between 0 (inclusive) and the specified value (exclusive). @@ -110,7 +110,7 @@ public class RandomGeneratorProvider(public val generator: RandomGenerator) : Un * @param n Bound on the random number to be returned. Must be positive. * @return a random long value between 0 (inclusive) and [n] (exclusive). */ - public override fun nextLong(n: Long): Long = generator.nextLong(n) + override fun nextLong(n: Long): Long = generator.nextLong(n) } /** diff --git a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt index 19c01e099..2b6b1ca60 100644 --- a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat diff --git a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/MCScopeTest.kt b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/MCScopeTest.kt index 075d7f3e5..0c3d9cb0d 100644 --- a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/MCScopeTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/MCScopeTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat diff --git a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/SamplerTest.kt b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/SamplerTest.kt index 1dbbf591b..4060c0505 100644 --- a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/SamplerTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/SamplerTest.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat diff --git a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt index 1dd5c5161..777b93c29 100644 --- a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt @@ -1,15 +1,17 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.stat -import kotlinx.coroutines.flow.drop import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.last +import kotlinx.coroutines.flow.take import kotlinx.coroutines.runBlocking import space.kscience.kmath.streaming.chunked import kotlin.test.Test +import kotlin.test.assertEquals internal class StatisticTest { //create a random number generator. @@ -22,12 +24,27 @@ internal class StatisticTest { val chunked = data.chunked(1000) @Test - fun testParallelMean() = runBlocking { - val average = Mean.double - .flow(chunked) //create a flow with results - .drop(99) // Skip first 99 values and use one with total data - .first() //get 1e5 data samples average - - println(average) + fun singleBlockingMean() { + val first = runBlocking { chunked.first()} + val res = Double.mean(first) + assertEquals(0.5,res, 1e-1) } + + @Test + fun singleSuspendMean() = runBlocking { + val first = runBlocking { chunked.first()} + val res = Double.mean(first) + assertEquals(0.5,res, 1e-1) + } + + @Test + fun parallelMean() = runBlocking { + val average = Double.mean + .flow(chunked) //create a flow from evaluated results + .take(100) // Take 100 data chunks from the source and accumulate them + .last() //get 1e5 data samples average + + assertEquals(0.5,average, 1e-2) + } + } diff --git a/kmath-symja/build.gradle.kts b/kmath-symja/build.gradle.kts index f305c03b8..65c329d52 100644 --- a/kmath-symja/build.gradle.kts +++ b/kmath-symja/build.gradle.kts @@ -34,7 +34,7 @@ dependencies { api("org.hipparchus:hipparchus-stat:1.8") api(project(":kmath-core")) - testImplementation("org.slf4j:slf4j-simple:1.7.30") + testImplementation("org.slf4j:slf4j-simple:1.7.31") } readme { diff --git a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/SymjaExpression.kt b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/SymjaExpression.kt index a6773c709..a343256fa 100644 --- a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/SymjaExpression.kt +++ b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/SymjaExpression.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.symja @@ -14,7 +14,8 @@ import space.kscience.kmath.expressions.interpret import space.kscience.kmath.operations.NumericAlgebra /** - * Represents [MST] based [DifferentiableExpression] relying on [Symja](https://github.com/axkr/symja_android_library). + * Represents [MST] based [space.kscience.kmath.expressions.DifferentiableExpression] relying on + * [Symja](https://github.com/axkr/symja_android_library). * * The principle of this API is converting the [mst] to an [org.matheclipse.core.interfaces.IExpr], differentiating it * with Symja's [F.D], then converting [org.matheclipse.core.interfaces.IExpr] back to [MST]. @@ -29,9 +30,9 @@ public class SymjaExpression>( public val mst: MST, public val evaluator: ExprEvaluator = DEFAULT_EVALUATOR, ) : SpecialDifferentiableExpression> { - public override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments) + override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments) - public override fun derivativeOrNull(symbols: List): SymjaExpression = SymjaExpression( + override fun derivativeOrNull(symbols: List): SymjaExpression = SymjaExpression( algebra, symbols.map(Symbol::toIExpr).fold(mst.toIExpr(), F::D).toMst(evaluator), evaluator, diff --git a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt index 95dd1ebbf..a7ca298a9 100644 --- a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt +++ b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.symja @@ -64,8 +64,8 @@ public fun MST.toIExpr(): IExpr = when (this) { } is MST.Unary -> when (operation) { - GroupOperations.PLUS_OPERATION -> value.toIExpr() - GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr()) + GroupOps.PLUS_OPERATION -> value.toIExpr() + GroupOps.MINUS_OPERATION -> F.Negate(value.toIExpr()) TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr()) TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr()) TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr()) @@ -85,10 +85,10 @@ public fun MST.toIExpr(): IExpr = when (this) { } is MST.Binary -> when (operation) { - GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr() - GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr() - RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr() - FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr()) + GroupOps.PLUS_OPERATION -> left.toIExpr() + right.toIExpr() + GroupOps.MINUS_OPERATION -> left.toIExpr() - right.toIExpr() + RingOps.TIMES_OPERATION -> left.toIExpr() * right.toIExpr() + FieldOps.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr()) PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value)) else -> error("Binary operation $operation not defined in $this") } diff --git a/kmath-tensors/README.md b/kmath-tensors/README.md index 6b991d5df..b19a55381 100644 --- a/kmath-tensors/README.md +++ b/kmath-tensors/README.md @@ -3,7 +3,7 @@ Common linear algebra operations on tensors. - [tensor algebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.) - - [tensor algebra with broadcasting](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting. + - [tensor algebra with broadcasting](src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting. - [linear algebra operations](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc. diff --git a/kmath-tensors/build.gradle.kts b/kmath-tensors/build.gradle.kts index 92f87b927..3ba57ca7b 100644 --- a/kmath-tensors/build.gradle.kts +++ b/kmath-tensors/build.gradle.kts @@ -1,11 +1,14 @@ plugins { - id("ru.mipt.npm.gradle.mpp") + kotlin("multiplatform") + id("ru.mipt.npm.gradle.common") + id("ru.mipt.npm.gradle.native") } kotlin.sourceSets { all { languageSettings.useExperimentalAnnotation("space.kscience.kmath.misc.UnstableKMathAPI") } + commonMain { dependencies { api(project(":kmath-core")) @@ -14,30 +17,22 @@ kotlin.sourceSets { } } -tasks.dokkaHtml { - dependsOn(tasks.build) -} - readme { maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) feature( id = "tensor algebra", - description = "Basic linear algebra operations on tensors (plus, dot, etc.)", ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt" - ) + ) { "Basic linear algebra operations on tensors (plus, dot, etc.)" } feature( id = "tensor algebra with broadcasting", - description = "Basic linear algebra operations implemented with broadcasting.", - ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt" - ) + ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt" + ) { "Basic linear algebra operations implemented with broadcasting." } feature( id = "linear algebra operations", - description = "Advanced linear algebra operations like LU decomposition, SVD, etc.", ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt" - ) - -} \ No newline at end of file + ) { "Advanced linear algebra operations like LU decomposition, SVD, etc." } +} diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index e58af14db..debfb3ef0 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -1,22 +1,25 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.api +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Field + /** * Analytic operations on [Tensor]. * * @param T the type of items closed under analytic functions in the tensors. */ -public interface AnalyticTensorAlgebra : TensorPartialDivisionAlgebra { +public interface AnalyticTensorAlgebra> : TensorPartialDivisionAlgebra { /** * @return the mean of all elements in the input tensor. */ - public fun Tensor.mean(): T + public fun StructureND.mean(): T /** * Returns the mean of each row of the input tensor in the given dimension [dim]. @@ -29,12 +32,12 @@ public interface AnalyticTensorAlgebra : TensorPartialDivisionAlgebra { * @param keepDim whether the output tensor has [dim] retained or not. * @return the mean of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.mean(dim: Int, keepDim: Boolean): Tensor /** * @return the standard deviation of all elements in the input tensor. */ - public fun Tensor.std(): T + public fun StructureND.std(): T /** * Returns the standard deviation of each row of the input tensor in the given dimension [dim]. @@ -47,12 +50,12 @@ public interface AnalyticTensorAlgebra : TensorPartialDivisionAlgebra { * @param keepDim whether the output tensor has [dim] retained or not. * @return the standard deviation of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.std(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.std(dim: Int, keepDim: Boolean): Tensor /** * @return the variance of all elements in the input tensor. */ - public fun Tensor.variance(): T + public fun StructureND.variance(): T /** * Returns the variance of each row of the input tensor in the given dimension [dim]. @@ -65,57 +68,57 @@ public interface AnalyticTensorAlgebra : TensorPartialDivisionAlgebra { * @param keepDim whether the output tensor has [dim] retained or not. * @return the variance of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.variance(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.variance(dim: Int, keepDim: Boolean): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.exp.html - public fun Tensor.exp(): Tensor + public fun StructureND.exp(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.log.html - public fun Tensor.ln(): Tensor + public fun StructureND.ln(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html - public fun Tensor.sqrt(): Tensor + public fun StructureND.sqrt(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos - public fun Tensor.cos(): Tensor + public fun StructureND.cos(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos - public fun Tensor.acos(): Tensor + public fun StructureND.acos(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh - public fun Tensor.cosh(): Tensor + public fun StructureND.cosh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh - public fun Tensor.acosh(): Tensor + public fun StructureND.acosh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin - public fun Tensor.sin(): Tensor + public fun StructureND.sin(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin - public fun Tensor.asin(): Tensor + public fun StructureND.asin(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh - public fun Tensor.sinh(): Tensor + public fun StructureND.sinh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh - public fun Tensor.asinh(): Tensor + public fun StructureND.asinh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan - public fun Tensor.tan(): Tensor + public fun StructureND.tan(): Tensor //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan - public fun Tensor.atan(): Tensor + public fun StructureND.atan(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh - public fun Tensor.tanh(): Tensor + public fun StructureND.tanh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh - public fun Tensor.atanh(): Tensor + public fun StructureND.atanh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil - public fun Tensor.ceil(): Tensor + public fun StructureND.ceil(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor - public fun Tensor.floor(): Tensor + public fun StructureND.floor(): Tensor } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index ad6cc9e78..ad1e3640d 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -1,16 +1,19 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.api +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Field + /** * Common linear algebra operations. Operates on [Tensor]. * * @param T the type of items closed under division in the tensors. */ -public interface LinearOpsTensorAlgebra : TensorPartialDivisionAlgebra { +public interface LinearOpsTensorAlgebra> : TensorPartialDivisionAlgebra { /** * Computes the determinant of a square matrix input, or of each square matrix in a batched input. @@ -18,17 +21,17 @@ public interface LinearOpsTensorAlgebra : TensorPartialDivisionAlgebra { * * @return the determinant. */ - public fun Tensor.det(): Tensor + public fun StructureND.det(): Tensor /** * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input. * Given a square matrix `A`, return the matrix `AInv` satisfying - * `A dot AInv = AInv dot A = eye(a.shape[0])`. + * `A dot AInv == AInv dot A == eye(a.shape[0])`. * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv * * @return the multiplicative inverse of a matrix. */ - public fun Tensor.inv(): Tensor + public fun StructureND.inv(): Tensor /** * Cholesky decomposition. @@ -36,27 +39,29 @@ public interface LinearOpsTensorAlgebra : TensorPartialDivisionAlgebra { * Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices) * positive-definite matrix or the Cholesky decompositions for a batch of such matrices. * Each decomposition has the form: - * Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`, - * where L is a lower-triangular matrix and L.H is the conjugate transpose of L, + * Given a tensor `input`, return the tensor `L` satisfying `input = L dot LH`, + * where `L` is a lower-triangular matrix and `LH` is the conjugate transpose of `L`, * which is just a transpose for the case of real-valued input matrices. * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky * - * @return the batch of L matrices. + * @receiver the `input`. + * @return the batch of `L` matrices. */ - public fun Tensor.cholesky(): Tensor + public fun StructureND.cholesky(): Tensor /** * QR decomposition. * - * Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors. - * Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``, + * Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `Q to R` of tensors. + * Given a tensor `input`, return tensors `Q to R` satisfying `input == Q dot R`, * with `Q` being an orthogonal matrix or batch of orthogonal matrices * and `R` being an upper triangular matrix or batch of upper triangular matrices. * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr * - * @return pair of Q and R tensors. + * @receiver the `input`. + * @return pair of `Q` and `R` tensors. */ - public fun Tensor.qr(): Pair, Tensor> + public fun StructureND.qr(): Pair, Tensor> /** * LUP decomposition @@ -69,31 +74,35 @@ public interface LinearOpsTensorAlgebra : TensorPartialDivisionAlgebra { * `L` being a lower triangular matrix or batch of matrices, * `U` being an upper triangular matrix or batch of matrices. * - * * @return triple of P, L and U tensors + * @receiver the `input`. + * @return triple of P, L and U tensors */ - public fun Tensor.lu(): Triple, Tensor, Tensor> + public fun StructureND.lu(): Triple, Tensor, Tensor> /** * Singular Value Decomposition. * * Computes the singular value decomposition of either a matrix or batch of matrices `input`. - * The singular value decomposition is represented as a triple `(U, S, V)`, - * such that `input = U dot diagonalEmbedding(S) dot V.H`, - * where V.H is the conjugate transpose of V. - * If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input. + * The singular value decomposition is represented as a triple `Triple(U, S, V)`, + * such that `input = U dot diagonalEmbedding(S) dot VH`, + * where `VH` is the conjugate transpose of V. + * If `input` is a batch of tensors, then `U`, `S`, and `VH` are also batched with the same batch dimensions as + * `input`. * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd * - * @return triple `(U, S, V)`. + * @receiver the `input`. + * @return triple `Triple(U, S, V)`. */ - public fun Tensor.svd(): Triple, Tensor, Tensor> + public fun StructureND.svd(): Triple, Tensor, Tensor> /** - * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, - * represented by a pair (eigenvalues, eigenvectors). + * Returns eigenvalues and eigenvectors of a real symmetric matrix `input` or a batch of real symmetric matrices, + * represented by a pair `eigenvalues to eigenvectors`. * For more information: https://pytorch.org/docs/stable/generated/torch.symeig.html * - * @return a pair (eigenvalues, eigenvectors) + * @receiver the `input`. + * @return a pair `eigenvalues to eigenvectors` */ - public fun Tensor.symEig(): Pair, Tensor> + public fun StructureND.symEig(): Pair, Tensor> } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/Tensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/Tensor.kt index 179787684..e0f296057 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/Tensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/Tensor.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.tensors.api import space.kscience.kmath.nd.MutableStructureND diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index de3b12fd6..60fc470fb 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -1,11 +1,13 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.api -import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.nd.RingOpsND +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Ring /** * Algebra over a ring on [Tensor]. @@ -13,49 +15,47 @@ import space.kscience.kmath.operations.Algebra * * @param T the type of items in the tensors. */ -public interface TensorAlgebra : Algebra> { - +public interface TensorAlgebra> : RingOpsND { /** - * For a tensor containing a single element, i.e. a scalar tensor, the value of that element is returned. - * You need to check if the implementation puts any restrictions on the shape of a scalar tensor. + * Returns a single tensor value of unit dimension if tensor shape equals to [1]. * * @return a nullable value of a potentially scalar tensor. */ - public fun Tensor.valueOrNull(): T? + public fun StructureND.valueOrNull(): T? /** - * Unsafe version of [valueOrNull] + * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. * * @return the value of a scalar tensor. */ - public fun Tensor.value(): T = valueOrNull() - ?: throw IllegalArgumentException("Illegal value call for non scalar tensor") + public fun StructureND.value(): T = + valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") /** - * Each element of the tensor [other] is added to this value. + * Each element of the tensor [arg] is added to this value. * The resulting tensor is returned. * - * @param other tensor to be added. - * @return the sum of this value and tensor [other]. + * @param arg tensor to be added. + * @return the sum of this value and tensor [arg]. */ - public operator fun T.plus(other: Tensor): Tensor + override operator fun T.plus(arg: StructureND): Tensor /** - * Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. + * Adds the scalar [arg] to each element of this tensor and returns a new resulting tensor. * - * @param value the number to be added to each element of this tensor. - * @return the sum of this tensor and [value]. + * @param arg the number to be added to each element of this tensor. + * @return the sum of this tensor and [arg]. */ - public operator fun Tensor.plus(value: T): Tensor + override operator fun StructureND.plus(arg: T): Tensor /** - * Each element of the tensor [other] is added to each element of this tensor. + * Each element of the tensor [arg] is added to each element of this tensor. * The resulting tensor is returned. * - * @param other tensor to be added. - * @return the sum of this tensor and [other]. + * @param arg tensor to be added. + * @return the sum of this tensor and [arg]. */ - public operator fun Tensor.plus(other: Tensor): Tensor + override operator fun StructureND.plus(arg: StructureND): Tensor /** * Adds the scalar [value] to each element of this tensor. @@ -65,37 +65,37 @@ public interface TensorAlgebra : Algebra> { public operator fun Tensor.plusAssign(value: T) /** - * Each element of the tensor [other] is added to each element of this tensor. + * Each element of the tensor [arg] is added to each element of this tensor. * - * @param other tensor to be added. + * @param arg tensor to be added. */ - public operator fun Tensor.plusAssign(other: Tensor) + public operator fun Tensor.plusAssign(arg: StructureND) /** - * Each element of the tensor [other] is subtracted from this value. + * Each element of the tensor [arg] is subtracted from this value. * The resulting tensor is returned. * - * @param other tensor to be subtracted. - * @return the difference between this value and tensor [other]. + * @param arg tensor to be subtracted. + * @return the difference between this value and tensor [arg]. */ - public operator fun T.minus(other: Tensor): Tensor + override operator fun T.minus(arg: StructureND): Tensor /** - * Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor. + * Subtracts the scalar [arg] from each element of this tensor and returns a new resulting tensor. * - * @param value the number to be subtracted from each element of this tensor. - * @return the difference between this tensor and [value]. + * @param arg the number to be subtracted from each element of this tensor. + * @return the difference between this tensor and [arg]. */ - public operator fun Tensor.minus(value: T): Tensor + override operator fun StructureND.minus(arg: T): Tensor /** - * Each element of the tensor [other] is subtracted from each element of this tensor. + * Each element of the tensor [arg] is subtracted from each element of this tensor. * The resulting tensor is returned. * - * @param other tensor to be subtracted. - * @return the difference between this tensor and [other]. + * @param arg tensor to be subtracted. + * @return the difference between this tensor and [arg]. */ - public operator fun Tensor.minus(other: Tensor): Tensor + override operator fun StructureND.minus(arg: StructureND): Tensor /** * Subtracts the scalar [value] from each element of this tensor. @@ -105,38 +105,38 @@ public interface TensorAlgebra : Algebra> { public operator fun Tensor.minusAssign(value: T) /** - * Each element of the tensor [other] is subtracted from each element of this tensor. + * Each element of the tensor [arg] is subtracted from each element of this tensor. * - * @param other tensor to be subtracted. + * @param arg tensor to be subtracted. */ - public operator fun Tensor.minusAssign(other: Tensor) + public operator fun Tensor.minusAssign(arg: StructureND) /** - * Each element of the tensor [other] is multiplied by this value. + * Each element of the tensor [arg] is multiplied by this value. * The resulting tensor is returned. * - * @param other tensor to be multiplied. - * @return the product of this value and tensor [other]. + * @param arg tensor to be multiplied. + * @return the product of this value and tensor [arg]. */ - public operator fun T.times(other: Tensor): Tensor + override operator fun T.times(arg: StructureND): Tensor /** - * Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor. + * Multiplies the scalar [arg] by each element of this tensor and returns a new resulting tensor. * - * @param value the number to be multiplied by each element of this tensor. - * @return the product of this tensor and [value]. + * @param arg the number to be multiplied by each element of this tensor. + * @return the product of this tensor and [arg]. */ - public operator fun Tensor.times(value: T): Tensor + override operator fun StructureND.times(arg: T): Tensor /** - * Each element of the tensor [other] is multiplied by each element of this tensor. + * Each element of the tensor [arg] is multiplied by each element of this tensor. * The resulting tensor is returned. * - * @param other tensor to be multiplied. - * @return the product of this tensor and [other]. + * @param arg tensor to be multiplied. + * @return the product of this tensor and [arg]. */ - public operator fun Tensor.times(other: Tensor): Tensor + override operator fun StructureND.times(arg: StructureND): Tensor /** * Multiplies the scalar [value] by each element of this tensor. @@ -146,18 +146,18 @@ public interface TensorAlgebra : Algebra> { public operator fun Tensor.timesAssign(value: T) /** - * Each element of the tensor [other] is multiplied by each element of this tensor. + * Each element of the tensor [arg] is multiplied by each element of this tensor. * - * @param other tensor to be multiplied. + * @param arg tensor to be multiplied. */ - public operator fun Tensor.timesAssign(other: Tensor) + public operator fun Tensor.timesAssign(arg: StructureND) /** * Numerical negative, element-wise. * * @return tensor negation of the original tensor. */ - public operator fun Tensor.unaryMinus(): Tensor + override operator fun StructureND.unaryMinus(): Tensor /** * Returns the tensor at index i @@ -190,13 +190,13 @@ public interface TensorAlgebra : Algebra> { /** * View this tensor as the same size as [other]. - * ``this.viewAs(other) is equivalent to this.view(other.shape)``. + * `this.viewAs(other)` is equivalent to `this.view(other.shape)`. * For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html * * @param other the result tensor has the same size as other. * @return the result tensor with the same size as other. */ - public fun Tensor.viewAs(other: Tensor): Tensor + public fun Tensor.viewAs(other: StructureND): Tensor /** * Matrix product of two tensors. @@ -218,16 +218,16 @@ public interface TensorAlgebra : Algebra> { * a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. * If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix * multiple and removed after. - * The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). + * The non-matrix (i.e., batch) dimensions are broadcast (and thus must be broadcastable). * For example, if `input` is a (j × 1 × n × n) tensor and `other` is a * (k × n × n) tensor, out will be a (j × k × n × n) tensor. * * For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html * - * @param other tensor to be multiplied - * @return mathematical product of two tensors + * @param other tensor to be multiplied. + * @return a mathematical product of two tensors. */ - public infix fun Tensor.dot(other: Tensor): Tensor + public infix fun StructureND.dot(other: StructureND): Tensor /** * Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2]) @@ -235,7 +235,7 @@ public interface TensorAlgebra : Algebra> { * To facilitate creating batched diagonal matrices, * the 2D planes formed by the last two dimensions of the returned tensor are chosen by default. * - * The argument [offset] controls which diagonal to consider: + * The argument [offset] controls, which diagonal to consider: * 1. If [offset] = 0, it is the main diagonal. * 1. If [offset] > 0, it is above the main diagonal. * 1. If [offset] < 0, it is below the main diagonal. @@ -262,7 +262,7 @@ public interface TensorAlgebra : Algebra> { /** * @return the sum of all elements in the input tensor. */ - public fun Tensor.sum(): T + public fun StructureND.sum(): T /** * Returns the sum of each row of the input tensor in the given dimension [dim]. @@ -275,12 +275,12 @@ public interface TensorAlgebra : Algebra> { * @param keepDim whether the output tensor has [dim] retained or not. * @return the sum of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.sum(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.sum(dim: Int, keepDim: Boolean): Tensor /** - * @return the minimum value of all elements in the input tensor. + * @return the minimum value of all elements in the input tensor or null if there are no values */ - public fun Tensor.min(): T + public fun StructureND.min(): T? /** * Returns the minimum value of each row of the input tensor in the given dimension [dim]. @@ -293,12 +293,12 @@ public interface TensorAlgebra : Algebra> { * @param keepDim whether the output tensor has [dim] retained or not. * @return the minimum value of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.min(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.min(dim: Int, keepDim: Boolean): Tensor /** - * Returns the maximum value of all elements in the input tensor. + * Returns the maximum value of all elements in the input tensor or null if there are no values */ - public fun Tensor.max(): T + public fun StructureND.max(): T? /** * Returns the maximum value of each row of the input tensor in the given dimension [dim]. @@ -311,7 +311,7 @@ public interface TensorAlgebra : Algebra> { * @param keepDim whether the output tensor has [dim] retained or not. * @return the maximum value of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.max(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.max(dim: Int, keepDim: Boolean): Tensor /** * Returns the index of maximum value of each row of the input tensor in the given dimension [dim]. @@ -322,8 +322,11 @@ public interface TensorAlgebra : Algebra> { * * @param dim the dimension to reduce. * @param keepDim whether the output tensor has [dim] retained or not. - * @return the the index of maximum value of each row of the input tensor in the given dimension [dim]. + * @return the index of maximum value of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor + override fun add(left: StructureND, right: StructureND): Tensor = left + right + + override fun multiply(left: StructureND, right: StructureND): Tensor = left * right } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt index 02bf5415d..0a1e09081 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt @@ -1,43 +1,49 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.api +import space.kscience.kmath.nd.FieldOpsND +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Field + /** * Algebra over a field with partial division on [Tensor]. * For more information: https://proofwiki.org/wiki/Definition:Division_Algebra * * @param T the type of items closed under division in the tensors. */ -public interface TensorPartialDivisionAlgebra : TensorAlgebra { +public interface TensorPartialDivisionAlgebra> : TensorAlgebra, FieldOpsND { /** - * Each element of the tensor [other] is divided by this value. + * Each element of the tensor [arg] is divided by this value. * The resulting tensor is returned. * - * @param other tensor to divide by. - * @return the division of this value by the tensor [other]. + * @param arg tensor to divide by. + * @return the division of this value by the tensor [arg]. */ - public operator fun T.div(other: Tensor): Tensor + override operator fun T.div(arg: StructureND): Tensor /** - * Divide by the scalar [value] each element of this tensor returns a new resulting tensor. + * Divide by the scalar [arg] each element of this tensor returns a new resulting tensor. * - * @param value the number to divide by each element of this tensor. - * @return the division of this tensor by the [value]. + * @param arg the number to divide by each element of this tensor. + * @return the division of this tensor by the [arg]. */ - public operator fun Tensor.div(value: T): Tensor + override operator fun StructureND.div(arg: T): Tensor /** - * Each element of the tensor [other] is divided by each element of this tensor. + * Each element of the tensor [arg] is divided by each element of this tensor. * The resulting tensor is returned. * - * @param other tensor to be divided by. - * @return the division of this tensor by [other]. + * @param arg tensor to be divided by. + * @return the division of this tensor by [arg]. */ - public operator fun Tensor.div(other: Tensor): Tensor + override operator fun StructureND.div(arg: StructureND): Tensor + + override fun divide(left: StructureND, right: StructureND): StructureND = left.div(right) /** * Divides by the scalar [value] each element of this tensor. @@ -47,9 +53,9 @@ public interface TensorPartialDivisionAlgebra : TensorAlgebra { public operator fun Tensor.divAssign(value: T) /** - * Each element of this tensor is divided by each element of the [other] tensor. + * Each element of this tensor is divided by each element of the [arg] tensor. * - * @param other tensor to be divide by. + * @param arg tensor to be divided by. */ - public operator fun Tensor.divAssign(other: Tensor) + public operator fun Tensor.divAssign(arg: StructureND) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index b8530f637..7353ecab1 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -1,10 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.broadcastTensors @@ -17,77 +19,85 @@ import space.kscience.kmath.tensors.core.internal.tensor */ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { - override fun Tensor.plus(other: Tensor): DoubleTensor { - val broadcast = broadcastTensors(tensor, other.tensor) + override fun StructureND.plus(arg: StructureND): DoubleTensor { + val broadcast = broadcastTensors(tensor, arg.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i] } return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.plusAssign(other: Tensor) { - val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + override fun Tensor.plusAssign(arg: StructureND) { + val newOther = broadcastTo(arg.tensor, tensor.shape) + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += newOther.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Tensor.minus(other: Tensor): DoubleTensor { - val broadcast = broadcastTensors(tensor, other.tensor) + override fun StructureND.minus(arg: StructureND): DoubleTensor { + val broadcast = broadcastTensors(tensor, arg.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i] } return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.minusAssign(other: Tensor) { - val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + override fun Tensor.minusAssign(arg: StructureND) { + val newOther = broadcastTo(arg.tensor, tensor.shape) + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] -= newOther.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Tensor.times(other: Tensor): DoubleTensor { - val broadcast = broadcastTensors(tensor, other.tensor) + override fun StructureND.times(arg: StructureND): DoubleTensor { + val broadcast = broadcastTensors(tensor, arg.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[newThis.bufferStart + i] * newOther.mutableBuffer.array()[newOther.bufferStart + i] } return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.timesAssign(other: Tensor) { - val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + override fun Tensor.timesAssign(arg: StructureND) { + val newOther = broadcastTo(arg.tensor, tensor.shape) + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] *= newOther.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Tensor.div(other: Tensor): DoubleTensor { - val broadcast = broadcastTensors(tensor, other.tensor) + override fun StructureND.div(arg: StructureND): DoubleTensor { + val broadcast = broadcastTensors(tensor, arg.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[newOther.bufferStart + i] / newOther.mutableBuffer.array()[newOther.bufferStart + i] } return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.divAssign(other: Tensor) { - val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + override fun Tensor.divAssign(arg: StructureND) { + val newOther = broadcastTo(arg.tensor, tensor.shape) + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= newOther.mutableBuffer.array()[tensor.bufferStart + i] } } -} \ No newline at end of file +} + + +/** + * Compute a value using broadcast double tensor algebra + */ +@UnstableKMathAPI +public fun DoubleTensorAlgebra.withBroadcast(block: BroadcastDoubleTensorAlgebra.() -> R): R = + BroadcastDoubleTensorAlgebra.block() \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt index b78df13d3..54d8f54dc 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.tensors.core import space.kscience.kmath.misc.PerformancePitfall @@ -10,30 +15,29 @@ import space.kscience.kmath.tensors.api.Tensor */ public open class BufferedTensor internal constructor( override val shape: IntArray, - internal val mutableBuffer: MutableBuffer, - internal val bufferStart: Int + @PublishedApi internal val mutableBuffer: MutableBuffer, + @PublishedApi internal val bufferStart: Int, ) : Tensor { /** * Buffer strides based on [TensorLinearStructure] implementation */ - public val linearStructure: Strides - get() = TensorLinearStructure(shape) + override val indices: Strides get() = TensorLinearStructure(shape) /** * Number of elements in tensor */ public val numElements: Int - get() = linearStructure.linearSize + get() = indices.linearSize - override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)] + override fun get(index: IntArray): T = mutableBuffer[bufferStart + indices.offset(index)] override fun set(index: IntArray, value: T) { - mutableBuffer[bufferStart + linearStructure.offset(index)] = value + mutableBuffer[bufferStart + indices.offset(index)] = value } @PerformancePitfall - override fun elements(): Sequence> = linearStructure.indices().map { + override fun elements(): Sequence> = indices.asSequence().map { it to get(it) } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt index 41df50cba..ad7831fb9 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core @@ -11,7 +11,7 @@ import space.kscience.kmath.tensors.core.internal.toPrettyString /** * Default [BufferedTensor] implementation for [Double] values */ -public class DoubleTensor internal constructor( +public class DoubleTensor @PublishedApi internal constructor( shape: IntArray, buffer: DoubleArray, offset: Int = 0 diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index e49119a6f..5e7ae262f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -1,13 +1,17 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core -import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.nd.MutableStructure2D +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.structures.indices import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.Tensor @@ -19,27 +23,71 @@ import kotlin.math.* * Implementation of basic operations over double tensors and basic algebra operations on them. */ public open class DoubleTensorAlgebra : - TensorPartialDivisionAlgebra, - AnalyticTensorAlgebra, - LinearOpsTensorAlgebra { + TensorPartialDivisionAlgebra, + AnalyticTensorAlgebra, + LinearOpsTensorAlgebra { public companion object : DoubleTensorAlgebra() + override val elementAlgebra: DoubleField + get() = DoubleField + /** - * Returns a single tensor value of unit dimension if tensor shape equals to [1]. + * Applies the [transform] function to each element of the tensor and returns the resulting modified tensor. * - * @return a nullable value of a potentially scalar tensor. + * @param transform the function to be applied to each element of the tensor. + * @return the resulting tensor after applying the function. */ - override fun Tensor.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) + @Suppress("OVERRIDE_BY_INLINE") + final override inline fun StructureND.map(transform: DoubleField.(Double) -> Double): DoubleTensor { + val tensor = this.tensor + //TODO remove additional copy + val sourceArray = tensor.copyArray() + val array = DoubleArray(tensor.numElements) { DoubleField.transform(sourceArray[it]) } + return DoubleTensor( + tensor.shape, + array, + tensor.bufferStart + ) + } + + @Suppress("OVERRIDE_BY_INLINE") + final override inline fun StructureND.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor { + val tensor = this.tensor + //TODO remove additional copy + val sourceArray = tensor.copyArray() + val array = DoubleArray(tensor.numElements) { DoubleField.transform(tensor.indices.index(it), sourceArray[it]) } + return DoubleTensor( + tensor.shape, + array, + tensor.bufferStart + ) + } + + override fun zip( + left: StructureND, + right: StructureND, + transform: DoubleField.(Double, Double) -> Double + ): DoubleTensor { + require(left.shape.contentEquals(right.shape)){ + "The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}" + } + val leftTensor = left.tensor + val leftArray = leftTensor.copyArray() + val rightTensor = right.tensor + val rightArray = rightTensor.copyArray() + val array = DoubleArray(leftTensor.numElements) { DoubleField.transform(leftArray[it], rightArray[it]) } + return DoubleTensor( + leftTensor.shape, + array + ) + } + + override fun StructureND.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) tensor.mutableBuffer.array()[tensor.bufferStart] else null - /** - * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. - * - * @return the value of a scalar tensor. - */ - override fun Tensor.value(): Double = valueOrNull() + override fun StructureND.value(): Double = valueOrNull() ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") /** @@ -63,11 +111,10 @@ public open class DoubleTensorAlgebra : * @param initializer mapping tensor indices to values. * @return tensor with the [shape] shape and data generated by the [initializer]. */ - public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor = - fromArray( - shape, - TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray() - ) + override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): DoubleTensor = fromArray( + shape, + TensorLinearStructure(shape).asSequence().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray() + ) override operator fun Tensor.get(i: Int): DoubleTensor { val lastShape = tensor.shape.drop(1).toIntArray() @@ -102,37 +149,37 @@ public open class DoubleTensorAlgebra : } /** - * Returns a tensor filled with the scalar value 0.0, with the shape defined by the variable argument [shape]. + * Returns a tensor filled with the scalar value `0.0`, with the shape defined by the variable argument [shape]. * * @param shape array of integers defining the shape of the output tensor. - * @return tensor filled with the scalar value 0.0, with the [shape] shape. + * @return tensor filled with the scalar value `0.0`, with the [shape] shape. */ public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape) /** - * Returns a tensor filled with the scalar value 0.0, with the same shape as a given array. + * Returns a tensor filled with the scalar value `0.0`, with the same shape as a given array. * - * @return tensor filled with the scalar value 0.0, with the same shape as `input` tensor. + * @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor. */ - public fun Tensor.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) + public fun StructureND.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) /** - * Returns a tensor filled with the scalar value 1.0, with the shape defined by the variable argument [shape]. + * Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape]. * * @param shape array of integers defining the shape of the output tensor. - * @return tensor filled with the scalar value 1.0, with the [shape] shape. + * @return tensor filled with the scalar value `1.0`, with the [shape] shape. */ public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape) /** - * Returns a tensor filled with the scalar value 1.0, with the same shape as a given array. + * Returns a tensor filled with the scalar value `1.0`, with the same shape as a given array. * - * @return tensor filled with the scalar value 1.0, with the same shape as `input` tensor. + * @return tensor filled with the scalar value `1.0`, with the same shape as `input` tensor. */ public fun Tensor.onesLike(): DoubleTensor = tensor.fullLike(1.0) /** - * Returns a 2-D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere. + * Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere. * * @param n the number of rows and columns * @return a 2-D tensor with ones on the diagonal and zeros elsewhere. @@ -152,23 +199,22 @@ public open class DoubleTensorAlgebra : * * @return a copy of the `input` tensor with a copied buffer. */ - public fun Tensor.copy(): DoubleTensor { - return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) - } + public fun StructureND.copy(): DoubleTensor = + DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) - override fun Double.plus(other: Tensor): DoubleTensor { - val resBuffer = DoubleArray(other.tensor.numElements) { i -> - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this + override fun Double.plus(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this } - return DoubleTensor(other.shape, resBuffer) + return DoubleTensor(arg.shape, resBuffer) } - override fun Tensor.plus(value: Double): DoubleTensor = value + tensor + override fun StructureND.plus(arg: Double): DoubleTensor = arg + tensor - override fun Tensor.plus(other: Tensor): DoubleTensor { - checkShapesCompatible(tensor, other.tensor) + override fun StructureND.plus(arg: StructureND): DoubleTensor { + checkShapesCompatible(tensor, arg.tensor) val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i] + tensor.mutableBuffer.array()[i] + arg.tensor.mutableBuffer.array()[i] } return DoubleTensor(tensor.shape, resBuffer) } @@ -179,32 +225,32 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.plusAssign(other: Tensor) { - checkShapesCompatible(tensor, other.tensor) + override fun Tensor.plusAssign(arg: StructureND) { + checkShapesCompatible(tensor, arg.tensor) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += - other.tensor.mutableBuffer.array()[tensor.bufferStart + i] + arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Double.minus(other: Tensor): DoubleTensor { - val resBuffer = DoubleArray(other.tensor.numElements) { i -> - this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + override fun Double.minus(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + this - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] } - return DoubleTensor(other.shape, resBuffer) + return DoubleTensor(arg.shape, resBuffer) } - override fun Tensor.minus(value: Double): DoubleTensor { + override fun StructureND.minus(arg: Double): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] - value + tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg } return DoubleTensor(tensor.shape, resBuffer) } - override fun Tensor.minus(other: Tensor): DoubleTensor { - checkShapesCompatible(tensor, other) + override fun StructureND.minus(arg: StructureND): DoubleTensor { + checkShapesCompatible(tensor, arg) val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i] + tensor.mutableBuffer.array()[i] - arg.tensor.mutableBuffer.array()[i] } return DoubleTensor(tensor.shape, resBuffer) } @@ -215,28 +261,28 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.minusAssign(other: Tensor) { - checkShapesCompatible(tensor, other) + override fun Tensor.minusAssign(arg: StructureND) { + checkShapesCompatible(tensor, arg) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] -= - other.tensor.mutableBuffer.array()[tensor.bufferStart + i] + arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Double.times(other: Tensor): DoubleTensor { - val resBuffer = DoubleArray(other.tensor.numElements) { i -> - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this + override fun Double.times(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this } - return DoubleTensor(other.shape, resBuffer) + return DoubleTensor(arg.shape, resBuffer) } - override fun Tensor.times(value: Double): DoubleTensor = value * tensor + override fun StructureND.times(arg: Double): DoubleTensor = arg * tensor - override fun Tensor.times(other: Tensor): DoubleTensor { - checkShapesCompatible(tensor, other) + override fun StructureND.times(arg: StructureND): DoubleTensor { + checkShapesCompatible(tensor, arg) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i] * - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] } return DoubleTensor(tensor.shape, resBuffer) } @@ -247,33 +293,33 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.timesAssign(other: Tensor) { - checkShapesCompatible(tensor, other) + override fun Tensor.timesAssign(arg: StructureND) { + checkShapesCompatible(tensor, arg) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] *= - other.tensor.mutableBuffer.array()[tensor.bufferStart + i] + arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Double.div(other: Tensor): DoubleTensor { - val resBuffer = DoubleArray(other.tensor.numElements) { i -> - this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + override fun Double.div(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] } - return DoubleTensor(other.shape, resBuffer) + return DoubleTensor(arg.shape, resBuffer) } - override fun Tensor.div(value: Double): DoubleTensor { + override fun StructureND.div(arg: Double): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] / value + tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg } return DoubleTensor(shape, resBuffer) } - override fun Tensor.div(other: Tensor): DoubleTensor { - checkShapesCompatible(tensor, other) + override fun StructureND.div(arg: StructureND): DoubleTensor { + checkShapesCompatible(tensor, arg) val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[other.tensor.bufferStart + i] / - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] / + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] } return DoubleTensor(tensor.shape, resBuffer) } @@ -284,15 +330,15 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.divAssign(other: Tensor) { - checkShapesCompatible(tensor, other) + override fun Tensor.divAssign(arg: StructureND) { + checkShapesCompatible(tensor, arg) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= - other.tensor.mutableBuffer.array()[tensor.bufferStart + i] + arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Tensor.unaryMinus(): DoubleTensor { + override fun StructureND.unaryMinus(): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() } @@ -312,27 +358,26 @@ public open class DoubleTensorAlgebra : val resTensor = DoubleTensor(resShape, resBuffer) for (offset in 0 until n) { - val oldMultiIndex = tensor.linearStructure.index(offset) + val oldMultiIndex = tensor.indices.index(offset) val newMultiIndex = oldMultiIndex.copyOf() newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } - val linearIndex = resTensor.linearStructure.offset(newMultiIndex) + val linearIndex = resTensor.indices.offset(newMultiIndex) resTensor.mutableBuffer.array()[linearIndex] = tensor.mutableBuffer.array()[tensor.bufferStart + offset] } return resTensor } - override fun Tensor.view(shape: IntArray): DoubleTensor { checkView(tensor, shape) return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) } - override fun Tensor.viewAs(other: Tensor): DoubleTensor = + override fun Tensor.viewAs(other: StructureND): DoubleTensor = tensor.view(other.shape) - override infix fun Tensor.dot(other: Tensor): DoubleTensor { + override infix fun StructureND.dot(other: StructureND): DoubleTensor { if (tensor.shape.size == 1 && other.shape.size == 1) { return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) } @@ -384,8 +429,12 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun diagonalEmbedding(diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int): - DoubleTensor { + override fun diagonalEmbedding( + diagonalEntries: Tensor, + offset: Int, + dim1: Int, + dim2: Int + ): DoubleTensor { val n = diagonalEntries.shape.size val d1 = minusIndexFrom(n + 1, dim1) val d2 = minusIndexFrom(n + 1, dim2) @@ -413,7 +462,7 @@ public open class DoubleTensorAlgebra : val resTensor = zeros(resShape) for (i in 0 until diagonalEntries.tensor.numElements) { - val multiIndex = diagonalEntries.tensor.linearStructure.index(i) + val multiIndex = diagonalEntries.tensor.indices.index(i) var offset1 = 0 var offset2 = abs(realOffset) @@ -432,20 +481,6 @@ public open class DoubleTensorAlgebra : return resTensor.tensor } - /** - * Applies the [transform] function to each element of the tensor and returns the resulting modified tensor. - * - * @param transform the function to be applied to each element of the tensor. - * @return the resulting tensor after applying the function. - */ - public fun Tensor.map(transform: (Double) -> Double): DoubleTensor { - return DoubleTensor( - tensor.shape, - tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(), - tensor.bufferStart - ) - } - /** * Compares element-wise two tensors with a specified precision. * @@ -458,7 +493,7 @@ public open class DoubleTensorAlgebra : /** * Compares element-wise two tensors. - * Comparison of two Double values occurs with 1e-5 precision. + * Comparison of two Double values occurs with `1e-5` precision. * * @param other the tensor to compare with `input` tensor. * @return true if two tensors have the same shape and elements, false otherwise. @@ -487,23 +522,24 @@ public open class DoubleTensorAlgebra : } /** - * Returns a tensor of random numbers drawn from normal distributions with 0.0 mean and 1.0 standard deviation. + * Returns a tensor of random numbers drawn from normal distributions with `0.0` mean and `1.0` standard deviation. * * @param shape the desired shape for the output tensor. * @param seed the random seed of the pseudo-random number generator. * @return tensor of a given shape filled with numbers from the normal distribution - * with 0.0 mean and 1.0 standard deviation. + * with `0.0` mean and `1.0` standard deviation. */ public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor = DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed)) /** * Returns a tensor with the same shape as `input` of random numbers drawn from normal distributions - * with 0.0 mean and 1.0 standard deviation. + * with `0.0` mean and `1.0` standard deviation. * + * @receiver the `input`. * @param seed the random seed of the pseudo-random number generator. - * @return tensor with the same shape as `input` filled with numbers from the normal distribution - * with 0.0 mean and 1.0 standard deviation. + * @return a tensor with the same shape as `input` filled with numbers from the normal distribution + * with `0.0` mean and `1.0` standard deviation. */ public fun Tensor.randomNormalLike(seed: Long = 0): DoubleTensor = DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) @@ -526,19 +562,17 @@ public open class DoubleTensorAlgebra : } /** - * Builds tensor from rows of input tensor + * Builds tensor from rows of the input tensor. * * @param indices the [IntArray] of 1-dimensional indices - * @return tensor with rows corresponding to rows by [indices] + * @return tensor with rows corresponding to row by [indices] */ - public fun Tensor.rowsByIndices(indices: IntArray): DoubleTensor { - return stack(indices.map { this[it] }) - } + public fun Tensor.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) - internal fun Tensor.fold(foldFunction: (DoubleArray) -> Double): Double = - foldFunction(tensor.toDoubleArray()) + internal inline fun StructureND.fold(foldFunction: (DoubleArray) -> Double): Double = + foldFunction(tensor.copyArray()) - internal inline fun Tensor.foldDim( + internal inline fun StructureND.foldDim( foldFunction: (DoubleArray) -> R, dim: Int, keepDim: Boolean, @@ -552,10 +586,8 @@ public open class DoubleTensorAlgebra : val resNumElements = resShape.reduce(Int::times) val init = foldFunction(DoubleArray(1){0.0}) val resTensor = BufferedTensor(resShape, - MutableBuffer.auto(resNumElements) { init }, - //MutableList(resNumElements) { init }.asMutableBuffer(), - 0) - for (index in resTensor.linearStructure.indices()) { + MutableBuffer.auto(resNumElements) { init }, 0) + for (index in resTensor.indices) { val prefix = index.take(dim).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray() resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i -> @@ -565,31 +597,31 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun Tensor.sum(): Double = tensor.fold { it.sum() } + override fun StructureND.sum(): Double = tensor.fold { it.sum() } - override fun Tensor.sum(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.sum(dim: Int, keepDim: Boolean): DoubleTensor = foldDim({ x -> x.sum() }, dim, keepDim).toDoubleTensor() - override fun Tensor.min(): Double = this.fold { it.minOrNull()!! } + override fun StructureND.min(): Double = this.fold { it.minOrNull()!! } - override fun Tensor.min(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.min(dim: Int, keepDim: Boolean): DoubleTensor = foldDim({ x -> x.minOrNull()!! }, dim, keepDim).toDoubleTensor() - override fun Tensor.max(): Double = this.fold { it.maxOrNull()!! } + override fun StructureND.max(): Double = this.fold { it.maxOrNull()!! } - override fun Tensor.max(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.max(dim: Int, keepDim: Boolean): DoubleTensor = foldDim({ x -> x.maxOrNull()!! }, dim, keepDim).toDoubleTensor() - override fun Tensor.argMax(dim: Int, keepDim: Boolean): IntTensor = + override fun StructureND.argMax(dim: Int, keepDim: Boolean): IntTensor = foldDim({ x -> x.withIndex().maxByOrNull { it.value }?.index!! }, dim, keepDim).toIntTensor() - override fun Tensor.mean(): Double = this.fold { it.sum() / tensor.numElements } + override fun StructureND.mean(): Double = this.fold { it.sum() / tensor.numElements } - override fun Tensor.mean(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.mean(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( { arr -> check(dim < dimension) { "Dimension $dim out of range $dimension" } @@ -599,12 +631,12 @@ public open class DoubleTensorAlgebra : keepDim ).toDoubleTensor() - override fun Tensor.std(): Double = this.fold { arr -> + override fun StructureND.std(): Double = this.fold { arr -> val mean = arr.sum() / tensor.numElements sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)) } - override fun Tensor.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( + override fun StructureND.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( { arr -> check(dim < dimension) { "Dimension $dim out of range $dimension" } val mean = arr.sum() / shape[dim] @@ -614,12 +646,12 @@ public open class DoubleTensorAlgebra : keepDim ).toDoubleTensor() - override fun Tensor.variance(): Double = this.fold { arr -> + override fun StructureND.variance(): Double = this.fold { arr -> val mean = arr.sum() / tensor.numElements arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1) } - override fun Tensor.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( + override fun StructureND.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( { arr -> check(dim < dimension) { "Dimension $dim out of range $dimension" } val mean = arr.sum() / shape[dim] @@ -635,14 +667,14 @@ public open class DoubleTensorAlgebra : } /** - * Returns the covariance matrix M of given vectors. + * Returns the covariance matrix `M` of given vectors. * - * M[i, j] contains covariance of i-th and j-th given vectors + * `M[i, j]` contains covariance of `i`-th and `j`-th given vectors * * @param tensors the [List] of 1-dimensional tensors with same shape - * @return the covariance matrix + * @return `M`. */ - public fun cov(tensors: List>): DoubleTensor { + public fun cov(tensors: List>): DoubleTensor { check(tensors.isNotEmpty()) { "List must have at least 1 element" } val n = tensors.size val m = tensors[0].shape[0] @@ -659,43 +691,43 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun Tensor.exp(): DoubleTensor = tensor.map(::exp) + override fun StructureND.exp(): DoubleTensor = tensor.map { exp(it) } - override fun Tensor.ln(): DoubleTensor = tensor.map(::ln) + override fun StructureND.ln(): DoubleTensor = tensor.map { ln(it) } - override fun Tensor.sqrt(): DoubleTensor = tensor.map(::sqrt) + override fun StructureND.sqrt(): DoubleTensor = tensor.map { sqrt(it) } - override fun Tensor.cos(): DoubleTensor = tensor.map(::cos) + override fun StructureND.cos(): DoubleTensor = tensor.map { cos(it) } - override fun Tensor.acos(): DoubleTensor = tensor.map(::acos) + override fun StructureND.acos(): DoubleTensor = tensor.map { acos(it) } - override fun Tensor.cosh(): DoubleTensor = tensor.map(::cosh) + override fun StructureND.cosh(): DoubleTensor = tensor.map { cosh(it) } - override fun Tensor.acosh(): DoubleTensor = tensor.map(::acosh) + override fun StructureND.acosh(): DoubleTensor = tensor.map { acosh(it) } - override fun Tensor.sin(): DoubleTensor = tensor.map(::sin) + override fun StructureND.sin(): DoubleTensor = tensor.map { sin(it) } - override fun Tensor.asin(): DoubleTensor = tensor.map(::asin) + override fun StructureND.asin(): DoubleTensor = tensor.map { asin(it) } - override fun Tensor.sinh(): DoubleTensor = tensor.map(::sinh) + override fun StructureND.sinh(): DoubleTensor = tensor.map { sinh(it) } - override fun Tensor.asinh(): DoubleTensor = tensor.map(::asinh) + override fun StructureND.asinh(): DoubleTensor = tensor.map { asinh(it) } - override fun Tensor.tan(): DoubleTensor = tensor.map(::tan) + override fun StructureND.tan(): DoubleTensor = tensor.map { tan(it) } - override fun Tensor.atan(): DoubleTensor = tensor.map(::atan) + override fun StructureND.atan(): DoubleTensor = tensor.map { atan(it) } - override fun Tensor.tanh(): DoubleTensor = tensor.map(::tanh) + override fun StructureND.tanh(): DoubleTensor = tensor.map { tanh(it) } - override fun Tensor.atanh(): DoubleTensor = tensor.map(::atanh) + override fun StructureND.atanh(): DoubleTensor = tensor.map { atanh(it) } - override fun Tensor.ceil(): DoubleTensor = tensor.map(::ceil) + override fun StructureND.ceil(): DoubleTensor = tensor.map { ceil(it) } - override fun Tensor.floor(): DoubleTensor = tensor.map(::floor) + override fun StructureND.floor(): DoubleTensor = tensor.map { floor(it) } - override fun Tensor.inv(): DoubleTensor = invLU(1e-9) + override fun StructureND.inv(): DoubleTensor = invLU(1e-9) - override fun Tensor.det(): DoubleTensor = detLU(1e-9) + override fun StructureND.det(): DoubleTensor = detLU(1e-9) /** * Computes the LU factorization of a matrix or batches of matrices `input`. @@ -706,7 +738,7 @@ public open class DoubleTensorAlgebra : * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. */ - public fun Tensor.luFactor(epsilon: Double): Pair = + public fun StructureND.luFactor(epsilon: Double): Pair = computeLU(tensor, epsilon) ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") @@ -719,21 +751,21 @@ public open class DoubleTensorAlgebra : * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. */ - public fun Tensor.luFactor(): Pair = luFactor(1e-9) + public fun StructureND.luFactor(): Pair = luFactor(1e-9) /** * Unpacks the data and pivots from a LU factorization of a tensor. - * Given a tensor [luTensor], return tensors (P, L, U) satisfying ``P * luTensor = L * U``, + * Given a tensor [luTensor], return tensors `Triple(P, L, U)` satisfying `P dot luTensor = L dot U`, * with `P` being a permutation matrix or batch of matrices, * `L` being a lower triangular matrix or batch of matrices, * `U` being an upper triangular matrix or batch of matrices. * * @param luTensor the packed LU factorization data * @param pivotsTensor the packed LU factorization pivots - * @return triple of P, L and U tensors + * @return triple of `P`, `L` and `U` tensors */ public fun luPivot( - luTensor: Tensor, + luTensor: StructureND, pivotsTensor: Tensor, ): Triple { checkSquareMatrix(luTensor.shape) @@ -766,16 +798,17 @@ public open class DoubleTensorAlgebra : /** * QR decomposition. * - * Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors. - * Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``, + * Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `Q to R` of tensors. + * Given a tensor `input`, return tensors `Q to R` satisfying `input == Q dot R`, * with `Q` being an orthogonal matrix or batch of orthogonal matrices * and `R` being an upper triangular matrix or batch of upper triangular matrices. * - * @param epsilon permissible error when comparing tensors for equality. + * @receiver the `input`. + * @param epsilon the permissible error when comparing tensors for equality. * Used when checking the positive definiteness of the input matrix or matrices. - * @return pair of Q and R tensors. + * @return a pair of `Q` and `R` tensors. */ - public fun Tensor.cholesky(epsilon: Double): DoubleTensor { + public fun StructureND.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) checkPositiveDefinite(tensor, epsilon) @@ -788,9 +821,9 @@ public open class DoubleTensorAlgebra : return lTensor } - override fun Tensor.cholesky(): DoubleTensor = cholesky(1e-6) + override fun StructureND.cholesky(): DoubleTensor = cholesky(1e-6) - override fun Tensor.qr(): Pair { + override fun StructureND.qr(): Pair { checkSquareMatrix(shape) val qTensor = zeroesLike() val rTensor = zeroesLike() @@ -806,22 +839,23 @@ public open class DoubleTensorAlgebra : return qTensor to rTensor } - override fun Tensor.svd(): Triple = + override fun StructureND.svd(): Triple = svd(epsilon = 1e-10) /** * Singular Value Decomposition. * * Computes the singular value decomposition of either a matrix or batch of matrices `input`. - * The singular value decomposition is represented as a triple `(U, S, V)`, - * such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``. - * If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input. + * The singular value decomposition is represented as a triple `Triple(U, S, V)`, + * such that `input == U dot diagonalEmbedding(S) dot V.transpose()`. + * If `input` is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as `input. * - * @param epsilon permissible error when calculating the dot product of vectors, - * i.e. the precision with which the cosine approaches 1 in an iterative algorithm. - * @return triple `(U, S, V)`. + * @receiver the `input`. + * @param epsilon permissible error when calculating the dot product of vectors + * i.e., the precision with which the cosine approaches 1 in an iterative algorithm. + * @return a triple `Triple(U, S, V)`. */ - public fun Tensor.svd(epsilon: Double): Triple { + public fun StructureND.svd(epsilon: Double): Triple { val size = tensor.dimension val commonShape = tensor.shape.sliceArray(0 until size - 2) val (n, m) = tensor.shape.sliceArray(size - 2 until size) @@ -829,45 +863,63 @@ public open class DoubleTensorAlgebra : val sTensor = zeros(commonShape + intArrayOf(min(n, m))) val vTensor = zeros(commonShape + intArrayOf(min(n, m), m)) - tensor.matrixSequence() - .zip( - uTensor.matrixSequence() - .zip( - sTensor.vectorSequence() - .zip(vTensor.matrixSequence()) - ) - ).forEach { (matrix, USV) -> - val matrixSize = matrix.shape.reduce { acc, i -> acc * i } - val curMatrix = DoubleTensor( - matrix.shape, - matrix.mutableBuffer.array().slice(matrix.bufferStart until matrix.bufferStart + matrixSize) - .toDoubleArray() - ) - svdHelper(curMatrix, USV, m, n, epsilon) - } + val matrices = tensor.matrices + val uTensors = uTensor.matrices + val sTensorVectors = sTensor.vectors + val vTensors = vTensor.matrices + + for (index in matrices.indices) { + val matrix = matrices[index] + val usv = Triple( + uTensors[index], + sTensorVectors[index], + vTensors[index] + ) + val matrixSize = matrix.shape.reduce { acc, i -> acc * i } + val curMatrix = DoubleTensor( + matrix.shape, + matrix.mutableBuffer.array() + .slice(matrix.bufferStart until matrix.bufferStart + matrixSize) + .toDoubleArray() + ) + svdHelper(curMatrix, usv, m, n, epsilon) + } return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) } - override fun Tensor.symEig(): Pair = - symEig(epsilon = 1e-15) + override fun StructureND.symEig(): Pair = symEig(epsilon = 1e-15) /** * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, - * represented by a pair (eigenvalues, eigenvectors). + * represented by a pair `eigenvalues to eigenvectors`. * - * @param epsilon permissible error when comparing tensors for equality + * @param epsilon the permissible error when comparing tensors for equality * and when the cosine approaches 1 in the SVD algorithm. - * @return a pair (eigenvalues, eigenvectors) + * @return a pair `eigenvalues to eigenvectors`. */ - public fun Tensor.symEig(epsilon: Double): Pair { + public fun StructureND.symEig(epsilon: Double): Pair { checkSymmetric(tensor, epsilon) + + fun MutableStructure2D.cleanSym(n: Int) { + for (i in 0 until n) { + for (j in 0 until n) { + if (i == j) { + this[i, j] = sign(this[i, j]) + } else { + this[i, j] = 0.0 + } + } + } + } + val (u, s, v) = tensor.svd(epsilon) val shp = s.shape + intArrayOf(1) val utv = u.transpose() dot v val n = s.shape.last() - for (matrix in utv.matrixSequence()) - cleanSymHelper(matrix.as2D(), n) + for (matrix in utv.matrixSequence()) { + matrix.as2D().cleanSym(n) + } val eig = (utv dot s.view(shp)).view(s.shape) return eig to v @@ -877,11 +929,11 @@ public open class DoubleTensorAlgebra : * Computes the determinant of a square matrix input, or of each square matrix in a batched input * using LU factorization algorithm. * - * @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero + * @param epsilon the error in the LU algorithm—permissible error when comparing the determinant of a matrix + * with zero. * @return the determinant. */ - public fun Tensor.detLU(epsilon: Double = 1e-9): DoubleTensor { - + public fun StructureND.detLU(epsilon: Double = 1e-9): DoubleTensor { checkSquareMatrix(tensor.shape) val luTensor = tensor.copy() val pivotsTensor = tensor.setUpPivots() @@ -909,12 +961,12 @@ public open class DoubleTensorAlgebra : * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input * using LU factorization algorithm. * Given a square matrix `a`, return the matrix `aInv` satisfying - * ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``. + * `a dot aInv == aInv dot a == eye(a.shape[0])`. * - * @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero + * @param epsilon error in the LU algorithm—permissible error when comparing the determinant of a matrix with zero * @return the multiplicative inverse of a matrix. */ - public fun Tensor.invLU(epsilon: Double = 1e-9): DoubleTensor { + public fun StructureND.invLU(epsilon: Double = 1e-9): DoubleTensor { val (luTensor, pivotsTensor) = luFactor(epsilon) val invTensor = luTensor.zeroesLike() @@ -928,23 +980,25 @@ public open class DoubleTensorAlgebra : } /** - * LUP decomposition + * LUP decomposition. * * Computes the LUP decomposition of a matrix or a batch of matrices. - * Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``, + * Given a tensor `input`, return tensors `Triple(P, L, U)` satisfying `P dot input == L dot U`, * with `P` being a permutation matrix or batch of matrices, * `L` being a lower triangular matrix or batch of matrices, * `U` being an upper triangular matrix or batch of matrices. * - * @param epsilon permissible error when comparing the determinant of a matrix with zero - * @return triple of P, L and U tensors + * @param epsilon permissible error when comparing the determinant of a matrix with zero. + * @return triple of `P`, `L` and `U` tensors. */ - public fun Tensor.lu(epsilon: Double = 1e-9): Triple { + public fun StructureND.lu(epsilon: Double = 1e-9): Triple { val (lu, pivots) = tensor.luFactor(epsilon) return luPivot(lu, pivots) } - override fun Tensor.lu(): Triple = lu(1e-9) + override fun StructureND.lu(): Triple = lu(1e-9) } +public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra + diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt index e3d7c3d35..715d9035f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt index 7bfa9d3f8..57cdfee2f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt @@ -14,7 +14,7 @@ import kotlin.math.max * * @param shape the shape of the tensor. */ -public class TensorLinearStructure(override val shape: IntArray) : Strides { +public class TensorLinearStructure(override val shape: IntArray) : Strides() { override val strides: IntArray get() = stridesFromShape(shape) @@ -24,6 +24,10 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides { override val linearSize: Int get() = shape.reduce(Int::times) + override fun equals(other: Any?): Boolean = false + + override fun hashCode(): Int = 0 + public companion object { public fun stridesFromShape(shape: IntArray): IntArray { diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/TensorLinearStructure.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/TensorLinearStructure.kt new file mode 100644 index 000000000..57668722a --- /dev/null +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/TensorLinearStructure.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.tensors.core.internal + +import space.kscience.kmath.nd.Strides +import kotlin.math.max + + +internal fun stridesFromShape(shape: IntArray): IntArray { + val nDim = shape.size + val res = IntArray(nDim) + if (nDim == 0) + return res + + var current = nDim - 1 + res[current] = 1 + + while (current > 0) { + res[current - 1] = max(1, shape[current]) * res[current] + current-- + } + return res +} + +internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray { + val res = IntArray(nDim) + var current = offset + var strideIndex = 0 + + while (strideIndex < nDim) { + res[strideIndex] = (current / strides[strideIndex]) + current %= strides[strideIndex] + strideIndex++ + } + return res +} + +/** + * This [Strides] implementation follows the last dimension first convention + * For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html + * + * @param shape the shape of the tensor. + */ +internal class TensorLinearStructure(override val shape: IntArray) : Strides() { + override val strides: IntArray + get() = stridesFromShape(shape) + + override fun index(offset: Int): IntArray = + indexFromOffset(offset, strides, shape.size) + + override val linearSize: Int + get() = shape.reduce(Int::times) + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other == null || this::class != other::class) return false + + other as TensorLinearStructure + + if (!shape.contentEquals(other.shape)) return false + + return true + } + + override fun hashCode(): Int { + return shape.contentHashCode() + } +} diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt index 6324dc242..3787c0972 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core.internal @@ -10,7 +10,7 @@ import kotlin.math.max internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) { for (linearIndex in 0 until linearSize) { - val totalMultiIndex = resTensor.linearStructure.index(linearIndex) + val totalMultiIndex = resTensor.indices.index(linearIndex) val curMultiIndex = tensor.shape.copyOf() val offset = totalMultiIndex.size - curMultiIndex.size @@ -23,7 +23,7 @@ internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTenso } } - val curLinearIndex = tensor.linearStructure.offset(curMultiIndex) + val curLinearIndex = tensor.indices.offset(curMultiIndex) resTensor.mutableBuffer.array()[linearIndex] = tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex] } @@ -112,7 +112,7 @@ internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List checkShapesCompatible(a: Tensor, b: Tensor) = +internal fun checkShapesCompatible(a: StructureND, b: StructureND) = check(a.shape contentEquals b.shape) { "Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} " } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt index 7d3617547..d31e02677 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core.internal @@ -9,42 +9,55 @@ import space.kscience.kmath.nd.MutableStructure1D import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D +import space.kscience.kmath.operations.asSequence import space.kscience.kmath.operations.invoke -import space.kscience.kmath.tensors.core.* +import space.kscience.kmath.structures.VirtualBuffer +import space.kscience.kmath.tensors.core.BufferedTensor +import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensorAlgebra -import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.valueOrNull +import space.kscience.kmath.tensors.core.IntTensor import kotlin.math.abs import kotlin.math.min -import kotlin.math.sign import kotlin.math.sqrt +internal val BufferedTensor.vectors: VirtualBuffer> + get() { + val n = shape.size + val vectorOffset = shape[n - 1] + val vectorShape = intArrayOf(shape.last()) -internal fun BufferedTensor.vectorSequence(): Sequence> = sequence { - val n = shape.size - val vectorOffset = shape[n - 1] - val vectorShape = intArrayOf(shape.last()) - for (offset in 0 until numElements step vectorOffset) { - val vector = BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset) - yield(vector) + return VirtualBuffer(numElements / vectorOffset) { index -> + val offset = index * vectorOffset + BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset) + } } -} -internal fun BufferedTensor.matrixSequence(): Sequence> = sequence { - val n = shape.size - check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" } - val matrixOffset = shape[n - 1] * shape[n - 2] - val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) - for (offset in 0 until numElements step matrixOffset) { - val matrix = BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset) - yield(matrix) + +internal fun BufferedTensor.vectorSequence(): Sequence> = vectors.asSequence() + +/** + * A random access alternative to [matrixSequence] + */ +internal val BufferedTensor.matrices: VirtualBuffer> + get() { + val n = shape.size + check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n" } + val matrixOffset = shape[n - 1] * shape[n - 2] + val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) + + return VirtualBuffer(numElements / matrixOffset) { index -> + val offset = index * matrixOffset + BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset) + } } -} + +internal fun BufferedTensor.matrixSequence(): Sequence> = matrices.asSequence() internal fun dotHelper( a: MutableStructure2D, b: MutableStructure2D, res: MutableStructure2D, - l: Int, m: Int, n: Int + l: Int, m: Int, n: Int, ) { for (i in 0 until l) { for (j in 0 until n) { @@ -60,7 +73,7 @@ internal fun dotHelper( internal fun luHelper( lu: MutableStructure2D, pivots: MutableStructure1D, - epsilon: Double + epsilon: Double, ): Boolean { val m = lu.rowNum @@ -122,7 +135,7 @@ internal fun BufferedTensor.setUpPivots(): IntTensor { internal fun DoubleTensorAlgebra.computeLU( tensor: DoubleTensor, - epsilon: Double + epsilon: Double, ): Pair? { checkSquareMatrix(tensor.shape) @@ -139,7 +152,7 @@ internal fun DoubleTensorAlgebra.computeLU( internal fun pivInit( p: MutableStructure2D, pivot: MutableStructure1D, - n: Int + n: Int, ) { for (i in 0 until n) { p[i, pivot[i]] = 1.0 @@ -150,7 +163,7 @@ internal fun luPivotHelper( l: MutableStructure2D, u: MutableStructure2D, lu: MutableStructure2D, - n: Int + n: Int, ) { for (i in 0 until n) { for (j in 0 until n) { @@ -170,7 +183,7 @@ internal fun luPivotHelper( internal fun choleskyHelper( a: MutableStructure2D, l: MutableStructure2D, - n: Int + n: Int, ) { for (i in 0 until n) { for (j in 0 until i) { @@ -200,7 +213,7 @@ internal fun luMatrixDet(lu: MutableStructure2D, pivots: MutableStructur internal fun luMatrixInv( lu: MutableStructure2D, pivots: MutableStructure1D, - invMatrix: MutableStructure2D + invMatrix: MutableStructure2D, ) { val m = lu.shape[0] @@ -227,7 +240,7 @@ internal fun luMatrixInv( internal fun DoubleTensorAlgebra.qrHelper( matrix: DoubleTensor, q: DoubleTensor, - r: MutableStructure2D + r: MutableStructure2D, ) { checkSquareMatrix(matrix.shape) val n = matrix.shape[0] @@ -280,12 +293,11 @@ internal fun DoubleTensorAlgebra.svd1d(a: DoubleTensor, epsilon: Double = 1e-10) internal fun DoubleTensorAlgebra.svdHelper( matrix: DoubleTensor, - USV: Pair, Pair, BufferedTensor>>, - m: Int, n: Int, epsilon: Double + USV: Triple, BufferedTensor, BufferedTensor>, + m: Int, n: Int, epsilon: Double, ) { val res = ArrayList>(0) - val (matrixU, SV) = USV - val (matrixS, matrixV) = SV + val (matrixU, matrixS, matrixV) = USV for (k in 0 until min(n, m)) { var a = matrix.copy() @@ -329,14 +341,3 @@ internal fun DoubleTensorAlgebra.svdHelper( matrixV.mutableBuffer.array()[matrixV.bufferStart + i] = vBuffer[i] } } - -internal fun cleanSymHelper(matrix: MutableStructure2D, n: Int) { - for (i in 0 until n) - for (j in 0 until n) { - if (i == j) { - matrix[i, j] = sign(matrix[i, j]) - } else { - matrix[i, j] = 0.0 - } - } -} diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt index 0da036735..602430b03 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt @@ -1,11 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core.internal import space.kscience.kmath.nd.MutableBufferND +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.structures.asMutableBuffer import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.BufferedTensor @@ -19,20 +20,24 @@ internal fun BufferedTensor.asTensor(): IntTensor = internal fun BufferedTensor.asTensor(): DoubleTensor = DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) -internal fun Tensor.copyToBufferedTensor(): BufferedTensor = +internal fun StructureND.copyToBufferedTensor(): BufferedTensor = BufferedTensor( this.shape, - TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0 + TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().asMutableBuffer(), 0 ) -internal fun Tensor.toBufferedTensor(): BufferedTensor = when (this) { +internal fun StructureND.toBufferedTensor(): BufferedTensor = when (this) { is BufferedTensor -> this - is MutableBufferND -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides) - BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() + is MutableBufferND -> if (this.indices == TensorLinearStructure(this.shape)) { + BufferedTensor(this.shape, this.buffer, 0) + } else { + this.copyToBufferedTensor() + } else -> this.copyToBufferedTensor() } -internal val Tensor.tensor: DoubleTensor +@PublishedApi +internal val StructureND.tensor: DoubleTensor get() = when (this) { is DoubleTensor -> this else -> this.toBufferedTensor().asTensor() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt index 0ffaf39e7..553ed6add 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt @@ -1,11 +1,12 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core.internal import space.kscience.kmath.nd.as1D +import space.kscience.kmath.operations.toMutableList import space.kscience.kmath.samplers.GaussianSampler import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.structures.* @@ -14,7 +15,7 @@ import space.kscience.kmath.tensors.core.DoubleTensor import kotlin.math.* /** - * Returns a reference to [IntArray] containing all of the elements of this [Buffer] or copy the data. + * Returns a reference to [IntArray] containing all the elements of this [Buffer] or copy the data. */ internal fun Buffer.array(): IntArray = when (this) { is IntBuffer -> array @@ -22,8 +23,9 @@ internal fun Buffer.array(): IntArray = when (this) { } /** - * Returns a reference to [DoubleArray] containing all of the elements of this [Buffer] or copy the data. + * Returns a reference to [DoubleArray] containing all the elements of this [Buffer] or copy the data. */ +@PublishedApi internal fun Buffer.array(): DoubleArray = when (this) { is DoubleBuffer -> array else -> this.toDoubleArray() @@ -83,7 +85,7 @@ internal fun format(value: Double, digits: Int = 4): String = buildString { internal fun DoubleTensor.toPrettyString(): String = buildString { var offset = 0 val shape = this@toPrettyString.shape - val linearStructure = this@toPrettyString.linearStructure + val linearStructure = this@toPrettyString.indices val vectorSize = shape.last() append("DoubleTensor(\n") var charOffset = 3 diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorAlgebraExtensions.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorAlgebraExtensions.kt new file mode 100644 index 000000000..916388ba9 --- /dev/null +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorAlgebraExtensions.kt @@ -0,0 +1,16 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.tensors.core + +import space.kscience.kmath.nd.Shape +import kotlin.jvm.JvmName + +@JvmName("varArgOne") +public fun DoubleTensorAlgebra.one(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape)) +public fun DoubleTensorAlgebra.one(shape: Shape): DoubleTensor = ones(shape) +@JvmName("varArgZero") +public fun DoubleTensorAlgebra.zero(vararg shape: Int): DoubleTensor = zeros(intArrayOf(*shape)) +public fun DoubleTensorAlgebra.zero(shape: Shape): DoubleTensor = zeros(shape) \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt index 142cb2156..feade56de 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.tensors.core @@ -19,18 +19,19 @@ public fun Tensor.toDoubleTensor(): DoubleTensor = this.tensor public fun Tensor.toIntTensor(): IntTensor = this.tensor /** - * Returns [DoubleArray] of tensor elements + * Returns a copy-protected [DoubleArray] of tensor elements */ -public fun DoubleTensor.toDoubleArray(): DoubleArray { +public fun DoubleTensor.copyArray(): DoubleArray { + //TODO use ArrayCopy return DoubleArray(numElements) { i -> mutableBuffer[bufferStart + i] } } /** - * Returns [IntArray] of tensor elements + * Returns a copy-protected [IntArray] of tensor elements */ -public fun IntTensor.toIntArray(): IntArray { +public fun IntTensor.copyArray(): IntArray { return IntArray(numElements) { i -> mutableBuffer[bufferStart + i] } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt index 35e605fd9..1171b5217 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.tensors.core import space.kscience.kmath.operations.invoke diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt index 3b4c615b4..ba8182da2 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleAnalyticTensorAlgebra.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.tensors.core import space.kscience.kmath.operations.invoke diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index 347bb683f..c50c99b54 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.tensors.core import space.kscience.kmath.operations.invoke @@ -183,7 +188,7 @@ internal class TestDoubleLinearOpsTensorAlgebra { } -private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit { +private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10) { val svd = tensor.svd() val tensorSVD = svd.first diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt index a176abdd4..2686df19e 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.tensors.core import space.kscience.kmath.misc.PerformancePitfall diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index e7e898008..2aee03b82 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -1,3 +1,8 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + package space.kscience.kmath.tensors.core @@ -17,7 +22,7 @@ internal class TestDoubleTensorAlgebra { } @Test - fun TestDoubleDiv() = DoubleTensorAlgebra { + fun testDoubleDiv() = DoubleTensorAlgebra { val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0)) val res = 2.0/tensor assertTrue(res.mutableBuffer.array() contentEquals doubleArrayOf(1.0, 0.5)) diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt index bbf502faf..32fb65b8a 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorBuffer.kt @@ -1,24 +1,28 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.viktor import org.jetbrains.bio.viktor.F64FlatArray +import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableBuffer @Suppress("NOTHING_TO_INLINE", "OVERRIDE_BY_INLINE") -public class ViktorBuffer(public val flatArray: F64FlatArray) : MutableBuffer { - public override val size: Int +@JvmInline +public value class ViktorBuffer(public val flatArray: F64FlatArray) : MutableBuffer { + override val size: Int get() = flatArray.size - public override inline fun get(index: Int): Double = flatArray[index] + override inline fun get(index: Int): Double = flatArray[index] - public override inline fun set(index: Int, value: Double) { + override inline fun set(index: Int, value: Double) { flatArray[index] = value } - public override fun copy(): MutableBuffer = ViktorBuffer(flatArray.copy().flatten()) - public override operator fun iterator(): Iterator = flatArray.data.iterator() + override fun copy(): MutableBuffer = ViktorBuffer(flatArray.copy().flatten()) + override operator fun iterator(): Iterator = flatArray.data.iterator() + + override fun toString(): String = Buffer.toString(this) } diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt new file mode 100644 index 000000000..c50404c9c --- /dev/null +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorFieldOpsND.kt @@ -0,0 +1,124 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.viktor + +import org.jetbrains.bio.viktor.F64Array +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.* +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.ExtendedFieldOps +import space.kscience.kmath.operations.NumbersAddOps + +@OptIn(UnstableKMathAPI::class) +@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +public open class ViktorFieldOpsND : + FieldOpsND, + ExtendedFieldOps> { + + public val StructureND.f64Buffer: F64Array + get() = when (this) { + is ViktorStructureND -> this.f64Buffer + else -> structureND(shape) { this@f64Buffer[it] }.f64Buffer + } + + override val elementAlgebra: DoubleField get() = DoubleField + + override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND = + F64Array(*shape).apply { + DefaultStrides(shape).asSequence().forEach { index -> + set(value = DoubleField.initializer(index), indices = index) + } + }.asStructure() + + override fun StructureND.unaryMinus(): StructureND = -1 * this + + override fun StructureND.map(transform: DoubleField.(Double) -> Double): ViktorStructureND = + F64Array(*shape).apply { + DefaultStrides(shape).asSequence().forEach { index -> + set(value = DoubleField.transform(this@map[index]), indices = index) + } + }.asStructure() + + override fun StructureND.mapIndexed( + transform: DoubleField.(index: IntArray, Double) -> Double, + ): ViktorStructureND = F64Array(*shape).apply { + DefaultStrides(shape).asSequence().forEach { index -> + set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index) + } + }.asStructure() + + override fun zip( + left: StructureND, + right: StructureND, + transform: DoubleField.(Double, Double) -> Double, + ): ViktorStructureND { + require(left.shape.contentEquals(right.shape)) + return F64Array(*left.shape).apply { + DefaultStrides(left.shape).asSequence().forEach { index -> + set(value = DoubleField.transform(left[index], right[index]), indices = index) + } + }.asStructure() + } + + override fun add(left: StructureND, right: StructureND): ViktorStructureND = + (left.f64Buffer + right.f64Buffer).asStructure() + + override fun scale(a: StructureND, value: Double): ViktorStructureND = + (a.f64Buffer * value).asStructure() + + override fun StructureND.plus(arg: StructureND): ViktorStructureND = + (f64Buffer + arg.f64Buffer).asStructure() + + override fun StructureND.minus(arg: StructureND): ViktorStructureND = + (f64Buffer - arg.f64Buffer).asStructure() + + override fun StructureND.times(k: Number): ViktorStructureND = + (f64Buffer * k.toDouble()).asStructure() + + override fun StructureND.plus(arg: Double): ViktorStructureND = + (f64Buffer.plus(arg)).asStructure() + + override fun sin(arg: StructureND): ViktorStructureND = arg.map { sin(it) } + override fun cos(arg: StructureND): ViktorStructureND = arg.map { cos(it) } + override fun tan(arg: StructureND): ViktorStructureND = arg.map { tan(it) } + override fun asin(arg: StructureND): ViktorStructureND = arg.map { asin(it) } + override fun acos(arg: StructureND): ViktorStructureND = arg.map { acos(it) } + override fun atan(arg: StructureND): ViktorStructureND = arg.map { atan(it) } + + override fun power(arg: StructureND, pow: Number): ViktorStructureND = arg.map { it.pow(pow) } + + override fun exp(arg: StructureND): ViktorStructureND = arg.f64Buffer.exp().asStructure() + + override fun ln(arg: StructureND): ViktorStructureND = arg.f64Buffer.log().asStructure() + + override fun sinh(arg: StructureND): ViktorStructureND = arg.map { sinh(it) } + + override fun cosh(arg: StructureND): ViktorStructureND = arg.map { cosh(it) } + + override fun asinh(arg: StructureND): ViktorStructureND = arg.map { asinh(it) } + + override fun acosh(arg: StructureND): ViktorStructureND = arg.map { acosh(it) } + + override fun atanh(arg: StructureND): ViktorStructureND = arg.map { atanh(it) } + + public companion object : ViktorFieldOpsND() +} + +public val DoubleField.viktorAlgebra: ViktorFieldOpsND get() = ViktorFieldOpsND + +public open class ViktorFieldND( + override val shape: Shape +) : ViktorFieldOpsND(), FieldND, NumbersAddOps> { + override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() } + override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() } + + override fun number(value: Number): ViktorStructureND = + F64Array.full(init = value.toDouble(), shape = shape).asStructure() +} + +public fun DoubleField.viktorAlgebra(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) + +public fun ViktorFieldND(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) \ No newline at end of file diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt index b7abf4304..4926652ed 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt @@ -1,126 +1,30 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. */ package space.kscience.kmath.viktor import org.jetbrains.bio.viktor.F64Array import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.* -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.NumbersAddOperations -import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.nd.DefaultStrides +import space.kscience.kmath.nd.MutableStructureND @Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructureND { - public override val shape: IntArray get() = f64Buffer.shape + override val shape: IntArray get() = f64Buffer.shape - public override inline fun get(index: IntArray): Double = f64Buffer.get(*index) + override inline fun get(index: IntArray): Double = f64Buffer.get(*index) - public override inline fun set(index: IntArray, value: Double) { + override inline fun set(index: IntArray, value: Double) { f64Buffer.set(*index, value = value) } @PerformancePitfall - public override fun elements(): Sequence> = - DefaultStrides(shape).indices().map { it to get(it) } + override fun elements(): Sequence> = + DefaultStrides(shape).asSequence().map { it to get(it) } } public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this) -@OptIn(UnstableKMathAPI::class) -@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public class ViktorFieldND(public override val shape: IntArray) : FieldND, - NumbersAddOperations>, ExtendedField>, - ScaleOperations> { - public val StructureND.f64Buffer: F64Array - get() = when { - !shape.contentEquals(this@ViktorFieldND.shape) -> throw ShapeMismatchException( - this@ViktorFieldND.shape, - shape - ) - this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer - else -> produce { this@f64Buffer[it] }.f64Buffer - } - - public override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() } - public override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() } - - private val strides: Strides = DefaultStrides(shape) - - public override val elementContext: DoubleField get() = DoubleField - - public override fun produce(initializer: DoubleField.(IntArray) -> Double): ViktorStructureND = - F64Array(*shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.initializer(index), indices = index) - } - }.asStructure() - - public override fun StructureND.unaryMinus(): StructureND = -1 * this - - public override fun StructureND.map(transform: DoubleField.(Double) -> Double): ViktorStructureND = - F64Array(*this@ViktorFieldND.shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.transform(this@map[index]), indices = index) - } - }.asStructure() - - public override fun StructureND.mapIndexed( - transform: DoubleField.(index: IntArray, Double) -> Double, - ): ViktorStructureND = F64Array(*this@ViktorFieldND.shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index) - } - }.asStructure() - - public override fun combine( - a: StructureND, - b: StructureND, - transform: DoubleField.(Double, Double) -> Double, - ): ViktorStructureND = F64Array(*shape).apply { - this@ViktorFieldND.strides.indices().forEach { index -> - set(value = DoubleField.transform(a[index], b[index]), indices = index) - } - }.asStructure() - - public override fun add(a: StructureND, b: StructureND): ViktorStructureND = - (a.f64Buffer + b.f64Buffer).asStructure() - - public override fun scale(a: StructureND, value: Double): ViktorStructureND = - (a.f64Buffer * value.toDouble()).asStructure() - - public override inline fun StructureND.plus(b: StructureND): ViktorStructureND = - (f64Buffer + b.f64Buffer).asStructure() - - public override inline fun StructureND.minus(b: StructureND): ViktorStructureND = - (f64Buffer - b.f64Buffer).asStructure() - - public override inline fun StructureND.times(k: Number): ViktorStructureND = - (f64Buffer * k.toDouble()).asStructure() - - public override inline fun StructureND.plus(arg: Double): ViktorStructureND = - (f64Buffer.plus(arg)).asStructure() - - public override fun number(value: Number): ViktorStructureND = - F64Array.full(init = value.toDouble(), shape = shape).asStructure() - - public override fun sin(arg: StructureND): ViktorStructureND = arg.map { sin(it) } - public override fun cos(arg: StructureND): ViktorStructureND = arg.map { cos(it) } - public override fun tan(arg: StructureND): ViktorStructureND = arg.map { tan(it) } - public override fun asin(arg: StructureND): ViktorStructureND = arg.map { asin(it) } - public override fun acos(arg: StructureND): ViktorStructureND = arg.map { acos(it) } - public override fun atan(arg: StructureND): ViktorStructureND = arg.map { atan(it) } - - public override fun power(arg: StructureND, pow: Number): ViktorStructureND = arg.map { it.pow(pow) } - - public override fun exp(arg: StructureND): ViktorStructureND = arg.f64Buffer.exp().asStructure() - - public override fun ln(arg: StructureND): ViktorStructureND = arg.f64Buffer.log().asStructure() -} - -public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) diff --git a/settings.gradle.kts b/settings.gradle.kts index 241e4f38b..4bfd71032 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -5,24 +5,24 @@ pluginManagement { gradlePluginPortal() } - val toolsVersion = "0.10.0" - val kotlinVersion = "1.5.0" + val kotlinVersion = "1.6.0-RC" + val toolsVersion = "0.10.5" plugins { - id("ru.mipt.npm.gradle.project") version toolsVersion - id("ru.mipt.npm.gradle.mpp") version toolsVersion - id("ru.mipt.npm.gradle.jvm") version toolsVersion - kotlin("multiplatform") version kotlinVersion - kotlin("jvm") version kotlinVersion - kotlin("plugin.allopen") version kotlinVersion id("org.jetbrains.kotlinx.benchmark") version "0.3.1" - kotlin("jupyter.api") version "0.10.0-25" - + id("ru.mipt.npm.gradle.project") version toolsVersion + id("ru.mipt.npm.gradle.jvm") version toolsVersion + id("ru.mipt.npm.gradle.mpp") version toolsVersion + kotlin("multiplatform") version kotlinVersion + kotlin("plugin.allopen") version kotlinVersion } } rootProject.name = "kmath" +enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") +enableFeaturePreview("VERSION_CATALOGS") + include( ":kmath-memory", ":kmath-complex", @@ -32,6 +32,8 @@ include( ":kmath-histograms", ":kmath-commons", ":kmath-viktor", + ":kmath-multik", + ":kmath-optimization", ":kmath-stat", ":kmath-nd4j", ":kmath-dimensions", @@ -45,7 +47,7 @@ include( ":kmath-symja", ":kmath-jafama", ":examples", - ":benchmarks" + ":benchmarks", )