forked from kscience/kmath
commit
618dd07bcb
@ -1,6 +1,7 @@
|
|||||||
Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion)
|
Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion)
|
||||||
|
|
||||||
# KMath
|
# KMath
|
||||||
|
Could be pronounced as `key-math`.
|
||||||
The Kotlin MATHematics library is intended as a Kotlin-based analog to Python's `numpy` library. In contrast to `numpy` and `scipy` it is modular and has a lightweight core.
|
The Kotlin MATHematics library is intended as a Kotlin-based analog to Python's `numpy` library. In contrast to `numpy` and `scipy` it is modular and has a lightweight core.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
@ -13,7 +14,7 @@ Actual feature list is [here](doc/features.md)
|
|||||||
* Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays).
|
* Complex numbers backed by the `Field` API (meaning that they will be usable in any structure like vectors and N-dimensional arrays).
|
||||||
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
* Advanced linear algebra operations like matrix inversion and LU decomposition.
|
||||||
|
|
||||||
* **Array-like structures** Full support of many-dimenstional array-like structures
|
* **Array-like structures** Full support of many-dimensional array-like structures
|
||||||
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
including mixed arithmetic operations and function operations over arrays and numbers (with the added benefit of static type checking).
|
||||||
|
|
||||||
* **Expressions** By writing a single mathematical expression
|
* **Expressions** By writing a single mathematical expression
|
||||||
@ -22,13 +23,13 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
|
|
||||||
* **Histograms** Fast multi-dimensional histograms.
|
* **Histograms** Fast multi-dimensional histograms.
|
||||||
|
|
||||||
* **Streaming** Streaming operations on mathematica objects and objects buffers.
|
* **Streaming** Streaming operations on mathematical objects and objects buffers.
|
||||||
|
|
||||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||||
to submit a feature request if you want something to be done first.
|
to submit a feature request if you want something to be done first.
|
||||||
|
|
||||||
* **Koma wrapper** [Koma](https://github.com/kyonifer/koma) is a well established numerics library in kotlin, specifically linear algebra.
|
* **Koma wrapper** [Koma](https://github.com/kyonifer/koma) is a well established numerics library in Kotlin, specifically linear algebra.
|
||||||
The plan is to have wrappers for koma implementations for compatibility with kmath API.
|
The plan is to have wrappers for koma implementations for compatibility with kmath API.
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
@ -110,4 +111,4 @@ dependencies{
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
The project requires a lot of additional work. Please fill free to contribute in any way and propose new features.
|
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
|
||||||
|
@ -1,39 +0,0 @@
|
|||||||
plugins {
|
|
||||||
id "java"
|
|
||||||
id "me.champeau.gradle.jmh" version "0.4.8"
|
|
||||||
id 'org.jetbrains.kotlin.jvm'
|
|
||||||
}
|
|
||||||
|
|
||||||
repositories {
|
|
||||||
maven { url 'https://dl.bintray.com/kotlin/kotlin-eap' }
|
|
||||||
maven{ url "http://dl.bintray.com/kyonifer/maven"}
|
|
||||||
mavenCentral()
|
|
||||||
}
|
|
||||||
|
|
||||||
dependencies {
|
|
||||||
implementation project(":kmath-core")
|
|
||||||
implementation project(":kmath-coroutines")
|
|
||||||
implementation project(":kmath-commons")
|
|
||||||
implementation project(":kmath-koma")
|
|
||||||
implementation group: "com.kyonifer", name:"koma-core-ejml", version: "0.12"
|
|
||||||
implementation "org.jetbrains.kotlinx:kotlinx-io-jvm:0.1.5"
|
|
||||||
//compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
|
|
||||||
//jmh project(':kmath-core')
|
|
||||||
}
|
|
||||||
|
|
||||||
jmh {
|
|
||||||
warmupIterations = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
jmhClasses.dependsOn(compileKotlin)
|
|
||||||
|
|
||||||
compileKotlin {
|
|
||||||
kotlinOptions {
|
|
||||||
jvmTarget = "1.8"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
compileTestKotlin {
|
|
||||||
kotlinOptions {
|
|
||||||
jvmTarget = "1.8"
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,4 +1,4 @@
|
|||||||
val kmathVersion by extra("0.1.2")
|
val kmathVersion by extra("0.1.3")
|
||||||
|
|
||||||
allprojects {
|
allprojects {
|
||||||
repositories {
|
repositories {
|
||||||
|
@ -32,12 +32,12 @@ Typical case of `Field` is the `RealField` which works on doubles. And typical c
|
|||||||
|
|
||||||
In some cases algebra context could hold additional operation like `exp` or `sin`, in this case it inherits appropriate
|
In some cases algebra context could hold additional operation like `exp` or `sin`, in this case it inherits appropriate
|
||||||
interface. Also a context could have an operation which produces an element outside of its context. For example
|
interface. Also a context could have an operation which produces an element outside of its context. For example
|
||||||
`Matrix` `dot` operation produces a matrix with new dimensions which could not be compatible with initial matrix in
|
`Matrix` `dot` operation produces a matrix with new dimensions which can be incompatible with initial matrix in
|
||||||
terms of linear operations.
|
terms of linear operations.
|
||||||
|
|
||||||
## Algebra element
|
## Algebra element
|
||||||
|
|
||||||
In order to achieve more familiar behavior (where you apply operations directly to mathematica objects), without involving contexts
|
In order to achieve more familiar behavior (where you apply operations directly to mathematical objects), without involving contexts
|
||||||
`kmath` introduces special type objects called `MathElement`. A `MathElement` is basically some object coupled to
|
`kmath` introduces special type objects called `MathElement`. A `MathElement` is basically some object coupled to
|
||||||
a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts,
|
a mathematical context. For example `Complex` is the pair of real numbers representing real and imaginary parts,
|
||||||
but it also holds reference to the `ComplexField` singleton which allows to perform direct operations on `Complex`
|
but it also holds reference to the `ComplexField` singleton which allows to perform direct operations on `Complex`
|
||||||
|
@ -9,17 +9,12 @@ structures. In `kmath` performance depends on which particular context was used
|
|||||||
|
|
||||||
Let us consider following contexts:
|
Let us consider following contexts:
|
||||||
```kotlin
|
```kotlin
|
||||||
// specialized nd-field for Double. It works as generic Double field as well
|
|
||||||
val specializedField = NDField.real(intArrayOf(dim, dim))
|
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
val autoField = NDField.auto(intArrayOf(dim, dim), RealField)
|
val autoField = NDField.auto(RealField, dim, dim)
|
||||||
|
// specialized nd-field for Double. It works as generic Double field as well
|
||||||
//A field implementing lazy computations. All elements are computed on-demand
|
val specializedField = NDField.real(dim, dim)
|
||||||
val lazyField = NDField.lazy(intArrayOf(dim, dim), RealField)
|
|
||||||
|
|
||||||
//A generic boxing field. It should be used for objects, not primitives.
|
//A generic boxing field. It should be used for objects, not primitives.
|
||||||
val genericField = NDField.buffered(intArrayOf(dim, dim), RealField)
|
val genericField = NDField.buffered(RealField, dim, dim)
|
||||||
```
|
```
|
||||||
Now let us perform several tests and see which implementation is best suited for each case:
|
Now let us perform several tests and see which implementation is best suited for each case:
|
||||||
|
|
||||||
@ -32,7 +27,7 @@ to it `n = 1000` times.
|
|||||||
The code to run this looks like:
|
The code to run this looks like:
|
||||||
```kotlin
|
```kotlin
|
||||||
specializedField.run {
|
specializedField.run {
|
||||||
var res = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += 1.0
|
||||||
}
|
}
|
||||||
@ -93,7 +88,7 @@ In this case it completes in about `4x-5x` time due to boxing.
|
|||||||
The boxing field produced by
|
The boxing field produced by
|
||||||
```kotlin
|
```kotlin
|
||||||
genericField.run {
|
genericField.run {
|
||||||
var res = one
|
var res: NDBuffer<Double> = one
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
res += 1.0
|
res += 1.0
|
||||||
}
|
}
|
||||||
|
67
examples/build.gradle.kts
Normal file
67
examples/build.gradle.kts
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import org.jetbrains.gradle.benchmarks.JvmBenchmarkTarget
|
||||||
|
import org.jetbrains.kotlin.allopen.gradle.AllOpenExtension
|
||||||
|
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
java
|
||||||
|
kotlin("jvm")
|
||||||
|
kotlin("plugin.allopen") version "1.3.31"
|
||||||
|
id("org.jetbrains.gradle.benchmarks.plugin") version "0.1.7-dev-24"
|
||||||
|
}
|
||||||
|
|
||||||
|
configure<AllOpenExtension> {
|
||||||
|
annotation("org.openjdk.jmh.annotations.State")
|
||||||
|
}
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("http://dl.bintray.com/kyonifer/maven")
|
||||||
|
maven("https://dl.bintray.com/orangy/maven")
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceSets {
|
||||||
|
register("benchmarks")
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation(project(":kmath-core"))
|
||||||
|
implementation(project(":kmath-coroutines"))
|
||||||
|
implementation(project(":kmath-commons"))
|
||||||
|
implementation(project(":kmath-koma"))
|
||||||
|
implementation("com.kyonifer:koma-core-ejml:0.12")
|
||||||
|
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.1.5")
|
||||||
|
|
||||||
|
implementation("org.jetbrains.gradle.benchmarks:runtime:0.1.7-dev-24")
|
||||||
|
|
||||||
|
|
||||||
|
"benchmarksCompile"(sourceSets.main.get().compileClasspath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure benchmark
|
||||||
|
benchmark {
|
||||||
|
// Setup configurations
|
||||||
|
targets {
|
||||||
|
// This one matches sourceSet name above
|
||||||
|
register("benchmarks") {
|
||||||
|
this as JvmBenchmarkTarget
|
||||||
|
jmhVersion = "1.21"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
configurations {
|
||||||
|
register("fast") {
|
||||||
|
warmups = 5 // number of warmup iterations
|
||||||
|
iterations = 3 // number of iterations
|
||||||
|
iterationTime = 500 // time in seconds per iteration
|
||||||
|
iterationTimeUnit = "ms" // time unity for iterationTime, default is seconds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
tasks.withType<KotlinCompile> {
|
||||||
|
kotlinOptions {
|
||||||
|
jvmTarget = "1.8"
|
||||||
|
}
|
||||||
|
}
|
@ -7,7 +7,7 @@ import java.nio.IntBuffer
|
|||||||
|
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
open class ArrayBenchmark {
|
class ArrayBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun benchmarkArrayRead() {
|
fun benchmarkArrayRead() {
|
@ -7,7 +7,7 @@ import scientifik.kmath.operations.Complex
|
|||||||
import scientifik.kmath.operations.complex
|
import scientifik.kmath.operations.complex
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
open class BufferBenchmark {
|
class BufferBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun genericDoubleBufferReadWrite() {
|
fun genericDoubleBufferReadWrite() {
|
@ -1,9 +1,12 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import org.openjdk.jmh.annotations.Benchmark
|
import org.openjdk.jmh.annotations.Benchmark
|
||||||
|
import org.openjdk.jmh.annotations.Scope
|
||||||
|
import org.openjdk.jmh.annotations.State
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
|
|
||||||
open class NDFieldBenchmark {
|
@State(Scope.Benchmark)
|
||||||
|
class NDFieldBenchmark {
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun autoFieldAdd() {
|
fun autoFieldAdd() {
|
||||||
@ -50,6 +53,6 @@ open class NDFieldBenchmark {
|
|||||||
|
|
||||||
val bufferedField = NDField.auto(RealField, dim, dim)
|
val bufferedField = NDField.auto(RealField, dim, dim)
|
||||||
val specializedField = NDField.real(dim, dim)
|
val specializedField = NDField.real(dim, dim)
|
||||||
val genericField = NDField.buffered(intArrayOf(dim, dim), RealField)
|
val genericField = NDField.boxing(RealField, dim, dim)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
package scientifik.kmath.commons.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.mapWithState
|
||||||
|
import scientifik.kmath.prob.Distribution
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
|
||||||
|
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||||
|
|
||||||
|
fun Chain<Double>.mean(): Chain<Double> = mapWithState(AveragingChainState(),{it.copy()}){chain->
|
||||||
|
val next = chain.next()
|
||||||
|
num++
|
||||||
|
value += next
|
||||||
|
return@mapWithState value / num
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val normal = Distribution.normal()
|
||||||
|
val chain = normal.sample(RandomGenerator.default).mean()
|
||||||
|
|
||||||
|
runBlocking {
|
||||||
|
repeat(10001) { counter ->
|
||||||
|
val mean = chain.next()
|
||||||
|
if (counter % 1000 == 0) {
|
||||||
|
println("[$counter] Average value is $mean")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,9 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import koma.matrix.ejml.EJMLMatrixFactory
|
import koma.matrix.ejml.EJMLMatrixFactory
|
||||||
|
import scientifik.kmath.commons.linear.CMMatrixContext
|
||||||
|
import scientifik.kmath.commons.linear.inverse
|
||||||
|
import scientifik.kmath.commons.linear.toCM
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.contracts.ExperimentalContracts
|
import kotlin.contracts.ExperimentalContracts
|
@ -1,6 +1,8 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import koma.matrix.ejml.EJMLMatrixFactory
|
import koma.matrix.ejml.EJMLMatrixFactory
|
||||||
|
import scientifik.kmath.commons.linear.CMMatrixContext
|
||||||
|
import scientifik.kmath.commons.linear.toCM
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
@ -0,0 +1,10 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.NDElement
|
||||||
|
import scientifik.kmath.structures.complex
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val element = NDElement.complex(2, 2) { index: IntArray ->
|
||||||
|
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,8 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.transpose
|
||||||
|
import scientifik.kmath.operations.Complex
|
||||||
|
import scientifik.kmath.operations.toComplex
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
@ -32,3 +35,23 @@ fun main() {
|
|||||||
|
|
||||||
println("Complex addition completed in $complexTime millis")
|
println("Complex addition completed in $complexTime millis")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun complexExample() {
|
||||||
|
//Create a context for 2-d structure with complex values
|
||||||
|
NDField.complex(4, 8).run {
|
||||||
|
//a constant real-valued structure
|
||||||
|
val x = one * 2.5
|
||||||
|
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
|
||||||
|
//a structure generator specific to this context
|
||||||
|
val matrix = produce { (k, l) ->
|
||||||
|
k + l*i
|
||||||
|
}
|
||||||
|
|
||||||
|
//Perform sum
|
||||||
|
val sum = matrix + x + 1.0
|
||||||
|
|
||||||
|
//Represent the sum as 2d-structure and transpose
|
||||||
|
sum.as2D().transpose()
|
||||||
|
}
|
||||||
|
}
|
@ -13,7 +13,7 @@ fun main(args: Array<String>) {
|
|||||||
// specialized nd-field for Double. It works as generic Double field as well
|
// specialized nd-field for Double. It works as generic Double field as well
|
||||||
val specializedField = NDField.real(dim, dim)
|
val specializedField = NDField.real(dim, dim)
|
||||||
//A generic boxing field. It should be used for objects, not primitives.
|
//A generic boxing field. It should be used for objects, not primitives.
|
||||||
val genericField = NDField.buffered(intArrayOf(dim, dim), RealField)
|
val genericField = NDField.boxing(RealField, dim, dim)
|
||||||
|
|
||||||
|
|
||||||
val autoTime = measureTimeMillis {
|
val autoTime = measureTimeMillis {
|
@ -8,6 +8,7 @@ description = "Commons math binding for kmath"
|
|||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
|
api(project(":kmath-prob"))
|
||||||
api("org.apache.commons:commons-math3:3.6.1")
|
api("org.apache.commons:commons-math3:3.6.1")
|
||||||
testImplementation("org.jetbrains.kotlin:kotlin-test")
|
testImplementation("org.jetbrains.kotlin:kotlin-test")
|
||||||
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
|
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.commons.expressions
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.expressions.ExpressionContext
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedField
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
@ -82,8 +84,11 @@ class DerivativeStructureField(
|
|||||||
* A constructs that creates a derivative structure with required order on-demand
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
*/
|
*/
|
||||||
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
|
class DiffExpression(val function: DerivativeStructureField.() -> DerivativeStructure) : Expression<Double> {
|
||||||
override fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(0, arguments)
|
|
||||||
.run(function).value
|
override fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
|
||||||
|
0,
|
||||||
|
arguments
|
||||||
|
).run(function).value
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
@ -109,21 +114,27 @@ fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
|||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||||
*/
|
*/
|
||||||
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
||||||
override fun variable(name: String, default: Double?) = DiffExpression { variable(name, default?.const()) }
|
override fun variable(name: String, default: Double?) =
|
||||||
|
DiffExpression { variable(name, default?.const()) }
|
||||||
|
|
||||||
override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
|
override fun const(value: Double): DiffExpression =
|
||||||
|
DiffExpression { value.const() }
|
||||||
|
|
||||||
override fun add(a: DiffExpression, b: DiffExpression) = DiffExpression { a.function(this) + b.function(this) }
|
override fun add(a: DiffExpression, b: DiffExpression) =
|
||||||
|
DiffExpression { a.function(this) + b.function(this) }
|
||||||
|
|
||||||
override val zero = DiffExpression { 0.0.const() }
|
override val zero = DiffExpression { 0.0.const() }
|
||||||
|
|
||||||
override fun multiply(a: DiffExpression, k: Number) = DiffExpression { a.function(this) * k }
|
override fun multiply(a: DiffExpression, k: Number) =
|
||||||
|
DiffExpression { a.function(this) * k }
|
||||||
|
|
||||||
override val one = DiffExpression { 1.0.const() }
|
override val one = DiffExpression { 1.0.const() }
|
||||||
|
|
||||||
override fun multiply(a: DiffExpression, b: DiffExpression) = DiffExpression { a.function(this) * b.function(this) }
|
override fun multiply(a: DiffExpression, b: DiffExpression) =
|
||||||
|
DiffExpression { a.function(this) * b.function(this) }
|
||||||
|
|
||||||
override fun divide(a: DiffExpression, b: DiffExpression) = DiffExpression { a.function(this) / b.function(this) }
|
override fun divide(a: DiffExpression, b: DiffExpression) =
|
||||||
|
DiffExpression { a.function(this) / b.function(this) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.commons.linear
|
||||||
|
|
||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
import org.apache.commons.math3.linear.RealMatrix
|
import org.apache.commons.math3.linear.RealMatrix
|
||||||
import org.apache.commons.math3.linear.RealVector
|
import org.apache.commons.math3.linear.RealVector
|
||||||
|
import scientifik.kmath.linear.*
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
|
|
||||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) :
|
||||||
|
FeaturedMatrix<Double> {
|
||||||
override val rowNum: Int get() = origin.rowDimension
|
override val rowNum: Int get() = origin.rowDimension
|
||||||
override val colNum: Int get() = origin.columnDimension
|
override val colNum: Int get() = origin.columnDimension
|
||||||
|
|
||||||
@ -70,10 +72,14 @@ object CMMatrixContext : MatrixContext<Double> {
|
|||||||
override fun multiply(a: Matrix<Double>, k: Number) =
|
override fun multiply(a: Matrix<Double>, k: Number) =
|
||||||
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
|
CMMatrix(a.toCM().origin.scalarMultiply(k.toDouble()))
|
||||||
|
|
||||||
override fun Matrix<Double>.times(value: Double): Matrix<Double> = produce(rowNum,colNum){i,j-> get(i,j)*value}
|
override fun Matrix<Double>.times(value: Double): Matrix<Double> =
|
||||||
|
produce(rowNum, colNum) { i, j -> get(i, j) * value }
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =
|
||||||
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin))
|
CMMatrix(this.origin.add(other.origin))
|
||||||
|
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix =
|
||||||
|
CMMatrix(this.origin.subtract(other.origin))
|
||||||
|
|
||||||
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = CMMatrix(this.origin.multiply(other.origin))
|
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix =
|
||||||
|
CMMatrix(this.origin.multiply(other.origin))
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.commons.linear
|
||||||
|
|
||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
|
|
||||||
enum class CMDecomposition {
|
enum class CMDecomposition {
|
@ -0,0 +1,32 @@
|
|||||||
|
package scientifik.kmath.commons.prob
|
||||||
|
|
||||||
|
import org.apache.commons.math3.random.JDKRandomGenerator
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
||||||
|
|
||||||
|
inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator {
|
||||||
|
override fun nextDouble(): Double = generator.nextDouble()
|
||||||
|
|
||||||
|
override fun nextInt(): Int = generator.nextInt()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = generator.nextLong()
|
||||||
|
|
||||||
|
override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) }
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator {
|
||||||
|
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun CMRandom.asKmathGenerator(): RandomGenerator = CMRandomGeneratorWrapper(this)
|
||||||
|
|
||||||
|
fun RandomGenerator.asCMGenerator(): CMRandom =
|
||||||
|
(this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper")
|
||||||
|
|
||||||
|
val RandomGenerator.Companion.default: RandomGenerator by lazy { JDKRandomGenerator().asKmathGenerator() }
|
||||||
|
|
||||||
|
fun RandomGenerator.Companion.jdk(seed: Int? = null): RandomGenerator = if (seed == null) {
|
||||||
|
JDKRandomGenerator()
|
||||||
|
} else {
|
||||||
|
JDKRandomGenerator(seed)
|
||||||
|
}.asKmathGenerator()
|
@ -0,0 +1,82 @@
|
|||||||
|
package scientifik.kmath.commons.prob
|
||||||
|
|
||||||
|
import org.apache.commons.math3.distribution.*
|
||||||
|
import scientifik.kmath.prob.Distribution
|
||||||
|
import scientifik.kmath.prob.RandomChain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.UnivariateDistribution
|
||||||
|
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
||||||
|
|
||||||
|
class CMRealDistributionWrapper(val builder: (CMRandom?) -> RealDistribution) : UnivariateDistribution<Double> {
|
||||||
|
|
||||||
|
private val defaultDistribution by lazy { builder(null) }
|
||||||
|
|
||||||
|
override fun probability(arg: Double): Double = defaultDistribution.probability(arg)
|
||||||
|
|
||||||
|
override fun cumulative(arg: Double): Double = defaultDistribution.cumulativeProbability(arg)
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): RandomChain<Double> {
|
||||||
|
val distribution = builder(generator.asCMGenerator())
|
||||||
|
return RandomChain(generator) { distribution.sample() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class CMIntDistributionWrapper(val builder: (CMRandom?) -> IntegerDistribution) : UnivariateDistribution<Int> {
|
||||||
|
|
||||||
|
private val defaultDistribution by lazy { builder(null) }
|
||||||
|
|
||||||
|
override fun probability(arg: Int): Double = defaultDistribution.probability(arg)
|
||||||
|
|
||||||
|
override fun cumulative(arg: Int): Double = defaultDistribution.cumulativeProbability(arg)
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): RandomChain<Int> {
|
||||||
|
val distribution = builder(generator.asCMGenerator())
|
||||||
|
return RandomChain(generator) { distribution.sample() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun Distribution.Companion.normal(mean: Double = 0.0, sigma: Double = 1.0): UnivariateDistribution<Double> =
|
||||||
|
CMRealDistributionWrapper { generator -> NormalDistribution(generator, mean, sigma) }
|
||||||
|
|
||||||
|
fun Distribution.Companion.poisson(mean: Double): UnivariateDistribution<Int> = CMIntDistributionWrapper { generator ->
|
||||||
|
PoissonDistribution(
|
||||||
|
generator,
|
||||||
|
mean,
|
||||||
|
PoissonDistribution.DEFAULT_EPSILON,
|
||||||
|
PoissonDistribution.DEFAULT_MAX_ITERATIONS
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.binomial(trials: Int, p: Double): UnivariateDistribution<Int> =
|
||||||
|
CMIntDistributionWrapper { generator ->
|
||||||
|
BinomialDistribution(generator, trials, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.student(degreesOfFreedom: Double): UnivariateDistribution<Double> =
|
||||||
|
CMRealDistributionWrapper { generator ->
|
||||||
|
TDistribution(generator, degreesOfFreedom, TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.chi2(degreesOfFreedom: Double): UnivariateDistribution<Double> =
|
||||||
|
CMRealDistributionWrapper { generator ->
|
||||||
|
ChiSquaredDistribution(generator, degreesOfFreedom)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.fisher(
|
||||||
|
numeratorDegreesOfFreedom: Double,
|
||||||
|
denominatorDegreesOfFreedom: Double
|
||||||
|
): UnivariateDistribution<Double> =
|
||||||
|
CMRealDistributionWrapper { generator ->
|
||||||
|
FDistribution(generator, numeratorDegreesOfFreedom, denominatorDegreesOfFreedom)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.exponential(mean: Double): UnivariateDistribution<Double> =
|
||||||
|
CMRealDistributionWrapper { generator ->
|
||||||
|
ExponentialDistribution(generator, mean)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.uniform(a: Double, b: Double): UnivariateDistribution<Double> =
|
||||||
|
CMRealDistributionWrapper { generator ->
|
||||||
|
UniformRealDistribution(generator, a, b)
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.transform
|
package scientifik.kmath.commons.transform
|
||||||
|
|
||||||
import kotlinx.coroutines.FlowPreview
|
import kotlinx.coroutines.FlowPreview
|
||||||
import kotlinx.coroutines.flow.Flow
|
import kotlinx.coroutines.flow.Flow
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.commons.expressions
|
||||||
|
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R) =
|
inline fun <R> diff(order: Int, vararg parameters: Pair<String, Double>, block: DerivativeStructureField.() -> R) =
|
@ -58,32 +58,35 @@ internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, v
|
|||||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionContext<T> {
|
open class ExpressionSpace<T>(val space: Space<T>) : Space<Expression<T>>, ExpressionContext<T> {
|
||||||
|
override val zero: Expression<T> = ConstantExpression(space.zero)
|
||||||
override val zero: Expression<T> = ConstantExpression(field.zero)
|
|
||||||
|
|
||||||
override val one: Expression<T> = ConstantExpression(field.one)
|
|
||||||
|
|
||||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
||||||
|
|
||||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
||||||
|
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(field, a, b)
|
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(field, a, k)
|
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
|
||||||
|
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
|
||||||
|
|
||||||
|
|
||||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
||||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
||||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
|
||||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
|
||||||
|
|
||||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
operator fun T.plus(arg: Expression<T>) = arg + this
|
||||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
operator fun T.minus(arg: Expression<T>) = arg - this
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ExpressionField<T>(val field: Field<T>) : Field<Expression<T>>, ExpressionSpace<T>(field) {
|
||||||
|
override val one: Expression<T> = ConstantExpression(field.one)
|
||||||
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
||||||
|
|
||||||
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
||||||
|
|
||||||
|
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
||||||
|
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
||||||
|
|
||||||
operator fun T.times(arg: Expression<T>) = arg * this
|
operator fun T.times(arg: Expression<T>) = arg * this
|
||||||
operator fun T.div(arg: Expression<T>) = arg / this
|
operator fun T.div(arg: Expression<T>) = arg / this
|
||||||
}
|
}
|
@ -24,6 +24,7 @@ class BufferMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
object RealMatrixContext : GenericMatrixContext<Double, RealField> {
|
||||||
|
|
||||||
override val elementContext = RealField
|
override val elementContext = RealField
|
||||||
|
@ -0,0 +1,72 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.operations.RingElement
|
||||||
|
|
||||||
|
class BoxingNDRing<T, R : Ring<T>>(
|
||||||
|
override val shape: IntArray,
|
||||||
|
override val elementContext: R,
|
||||||
|
val bufferFactory: BufferFactory<T>
|
||||||
|
) : BufferedNDRing<T, R> {
|
||||||
|
|
||||||
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
|
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
|
bufferFactory(size, initializer)
|
||||||
|
|
||||||
|
override fun check(vararg elements: NDBuffer<T>) {
|
||||||
|
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero by lazy { produce { zero } }
|
||||||
|
override val one by lazy { produce { one } }
|
||||||
|
|
||||||
|
override fun produce(initializer: R.(IntArray) -> T) =
|
||||||
|
BufferedNDRingElement(
|
||||||
|
this,
|
||||||
|
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
||||||
|
|
||||||
|
override fun map(arg: NDBuffer<T>, transform: R.(T) -> T): BufferedNDRingElement<T, R> {
|
||||||
|
check(arg)
|
||||||
|
return BufferedNDRingElement(
|
||||||
|
this,
|
||||||
|
buildBuffer(arg.strides.linearSize) { offset -> elementContext.transform(arg.buffer[offset]) })
|
||||||
|
|
||||||
|
// val buffer = arg.buffer.transform { _, value -> elementContext.transform(value) }
|
||||||
|
// return BufferedNDFieldElement(this, buffer)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun mapIndexed(
|
||||||
|
arg: NDBuffer<T>,
|
||||||
|
transform: R.(index: IntArray, T) -> T
|
||||||
|
): BufferedNDRingElement<T, R> {
|
||||||
|
check(arg)
|
||||||
|
return BufferedNDRingElement(
|
||||||
|
this,
|
||||||
|
buildBuffer(arg.strides.linearSize) { offset ->
|
||||||
|
elementContext.transform(
|
||||||
|
arg.strides.index(offset),
|
||||||
|
arg.buffer[offset]
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// val buffer =
|
||||||
|
// arg.buffer.transform { offset, value -> elementContext.transform(arg.strides.index(offset), value) }
|
||||||
|
// return BufferedNDFieldElement(this, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun combine(
|
||||||
|
a: NDBuffer<T>,
|
||||||
|
b: NDBuffer<T>,
|
||||||
|
transform: R.(T, T) -> T
|
||||||
|
): BufferedNDRingElement<T, R> {
|
||||||
|
check(a, b)
|
||||||
|
return BufferedNDRingElement(
|
||||||
|
this,
|
||||||
|
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun NDBuffer<T>.toElement(): RingElement<NDBuffer<T>, *, out BufferedNDRing<T, R>> =
|
||||||
|
BufferedNDRingElement(this@BoxingNDRing, buffer)
|
||||||
|
}
|
@ -131,4 +131,7 @@ operator fun ComplexNDElement.plus(arg: Double) =
|
|||||||
operator fun ComplexNDElement.minus(arg: Double) =
|
operator fun ComplexNDElement.minus(arg: Double) =
|
||||||
map { it - arg }
|
map { it - arg }
|
||||||
|
|
||||||
fun NDField.Companion.complex(vararg shape: Int) = ComplexNDField(shape)
|
fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = ComplexNDField(shape)
|
||||||
|
|
||||||
|
fun NDElement.Companion.complex(vararg shape: Int, initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement =
|
||||||
|
NDField.complex(*shape).produce(initializer)
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
import kotlin.jvm.JvmName
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -57,6 +57,8 @@ interface NDAlgebra<T, C, N : NDStructure<T>> {
|
|||||||
* element-by-element invoke a function working on [T] on a [NDStructure]
|
* element-by-element invoke a function working on [T] on a [NDStructure]
|
||||||
*/
|
*/
|
||||||
operator fun Function1<T, T>.invoke(structure: N) = map(structure) { value -> this@invoke(value) }
|
operator fun Function1<T, T>.invoke(structure: N) = map(structure) { value -> this@invoke(value) }
|
||||||
|
|
||||||
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -75,10 +77,13 @@ interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T,
|
|||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
operator fun N.plus(arg: T) = map(this) { value -> add(arg, value) }
|
||||||
|
|
||||||
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
operator fun N.minus(arg: T) = map(this) { value -> add(arg, -value) }
|
||||||
|
|
||||||
operator fun T.plus(arg: N) = map(arg) { value -> add(this@plus, value) }
|
operator fun T.plus(arg: N) = map(arg) { value -> add(this@plus, value) }
|
||||||
operator fun T.minus(arg: N) = map(arg) { value -> add(-this@minus, value) }
|
operator fun T.minus(arg: N) = map(arg) { value -> add(-this@minus, value) }
|
||||||
|
|
||||||
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -93,7 +98,10 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
|
|||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) }
|
operator fun N.times(arg: T) = map(this) { value -> multiply(arg, value) }
|
||||||
|
|
||||||
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) }
|
operator fun T.times(arg: N) = map(arg) { value -> multiply(this@times, value) }
|
||||||
|
|
||||||
|
companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -113,6 +121,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
|||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) }
|
operator fun N.div(arg: T) = map(this) { value -> divide(arg, value) }
|
||||||
|
|
||||||
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
@ -127,12 +136,11 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
|||||||
/**
|
/**
|
||||||
* Create a nd-field with boxing generic buffer
|
* Create a nd-field with boxing generic buffer
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Field<T>> buffered(
|
fun <T : Any, F : Field<T>> boxing(
|
||||||
shape: IntArray,
|
|
||||||
field: F,
|
field: F,
|
||||||
|
vararg shape: Int,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||||
) =
|
) = BoxingNDField(shape, field, bufferFactory)
|
||||||
BoxingNDField(shape, field, bufferFactory)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a most suitable implementation for nd-field using reified class.
|
* Create a most suitable implementation for nd-field using reified class.
|
||||||
@ -141,6 +149,7 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
|||||||
inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): BufferedNDField<T, F> =
|
inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): BufferedNDField<T, F> =
|
||||||
when {
|
when {
|
||||||
T::class == Double::class -> real(*shape) as BufferedNDField<T, F>
|
T::class == Double::class -> real(*shape) as BufferedNDField<T, F>
|
||||||
|
T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
|
||||||
else -> BoxingNDField(shape, field, Buffer.Companion::auto)
|
else -> BoxingNDField(shape, field, Buffer.Companion::auto)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Simple boxing NDArray
|
* Simple boxing NDArray
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Field<T>> buffered(
|
fun <T : Any, F : Field<T>> boxing(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
field: F,
|
field: F,
|
||||||
initializer: F.(IntArray) -> T
|
initializer: F.(IntArray) -> T
|
||||||
|
@ -0,0 +1,45 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.*
|
||||||
|
import java.math.BigDecimal
|
||||||
|
import java.math.BigInteger
|
||||||
|
import java.math.MathContext
|
||||||
|
|
||||||
|
object BigIntegerRing : Ring<BigInteger> {
|
||||||
|
override val zero: BigInteger = BigInteger.ZERO
|
||||||
|
override val one: BigInteger = BigInteger.ONE
|
||||||
|
|
||||||
|
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
||||||
|
|
||||||
|
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
||||||
|
|
||||||
|
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
class BigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> {
|
||||||
|
override val zero: BigDecimal = BigDecimal.ZERO
|
||||||
|
override val one: BigDecimal = BigDecimal.ONE
|
||||||
|
|
||||||
|
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
||||||
|
|
||||||
|
override fun multiply(a: BigDecimal, k: Number): BigDecimal =
|
||||||
|
a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext)
|
||||||
|
|
||||||
|
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
||||||
|
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): Buffer<BigInteger> =
|
||||||
|
boxing(size, initializer)
|
||||||
|
|
||||||
|
inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInteger): MutableBuffer<BigInteger> =
|
||||||
|
boxing(size, initializer)
|
||||||
|
|
||||||
|
fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInteger, BigIntegerRing> =
|
||||||
|
BoxingNDRing(shape, BigIntegerRing, Buffer.Companion::bigInt)
|
||||||
|
|
||||||
|
fun NDElement.Companion.bigInt(
|
||||||
|
vararg shape: Int,
|
||||||
|
initializer: BigIntegerRing.(IntArray) -> BigInteger
|
||||||
|
): BufferedNDRingElement<BigInteger, BigIntegerRing> =
|
||||||
|
NDAlgebra.bigInt(*shape).produce(initializer)
|
@ -17,7 +17,9 @@
|
|||||||
package scientifik.kmath.chains
|
package scientifik.kmath.chains
|
||||||
|
|
||||||
import kotlinx.atomicfu.atomic
|
import kotlinx.atomicfu.atomic
|
||||||
|
import kotlinx.atomicfu.updateAndGet
|
||||||
import kotlinx.coroutines.FlowPreview
|
import kotlinx.coroutines.FlowPreview
|
||||||
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -25,11 +27,6 @@ import kotlinx.coroutines.FlowPreview
|
|||||||
* @param R - the chain element type
|
* @param R - the chain element type
|
||||||
*/
|
*/
|
||||||
interface Chain<out R> {
|
interface Chain<out R> {
|
||||||
/**
|
|
||||||
* Last cached value of the chain. Returns null if [next] was not called
|
|
||||||
*/
|
|
||||||
val value: R?
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate next value, changing state if needed
|
* Generate next value, changing state if needed
|
||||||
*/
|
*/
|
||||||
@ -40,109 +37,115 @@ interface Chain<out R> {
|
|||||||
*/
|
*/
|
||||||
fun fork(): Chain<R>
|
fun fork(): Chain<R>
|
||||||
|
|
||||||
|
companion object
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Chain as a coroutine flow. The flow emit affects chain state and vice versa
|
* Chain as a coroutine flow. The flow emit affects chain state and vice versa
|
||||||
*/
|
*/
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
val <R> Chain<R>.flow
|
val <R> Chain<R>.flow: Flow<R>
|
||||||
get() = kotlinx.coroutines.flow.flow { while (true) emit(next()) }
|
get() = kotlinx.coroutines.flow.flow { while (true) emit(next()) }
|
||||||
|
|
||||||
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
||||||
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
|
||||||
* since mapped chain consumes tokens. Accepts regular transformation function
|
|
||||||
*/
|
|
||||||
fun <T, R> Chain<T>.map(func: (T) -> R): Chain<R> {
|
|
||||||
val parent = this;
|
|
||||||
return object : Chain<R> {
|
|
||||||
override val value: R? get() = parent.value?.let(func)
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
|
||||||
return func(parent.next())
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
|
||||||
return parent.fork().map(func)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple chain of independent tokens
|
* A simple chain of independent tokens
|
||||||
*/
|
*/
|
||||||
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
||||||
private val atomicValue = atomic<R?>(null)
|
override suspend fun next(): R = gen()
|
||||||
override val value: R? get() = atomicValue.value
|
|
||||||
|
|
||||||
override suspend fun next(): R = gen().also { atomicValue.lazySet(it) }
|
|
||||||
|
|
||||||
override fun fork(): Chain<R> = this
|
override fun fork(): Chain<R> = this
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO force forks on mapping operations?
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A stateless Markov chain
|
* A stateless Markov chain
|
||||||
*/
|
*/
|
||||||
class MarkovChain<out R : Any>(private val seed: () -> R, private val gen: suspend (R) -> R) :
|
class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||||
Chain<R> {
|
|
||||||
|
|
||||||
constructor(seed: R, gen: suspend (R) -> R) : this({ seed }, gen)
|
constructor(seedValue: R, gen: suspend (R) -> R) : this({ seedValue }, gen)
|
||||||
|
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val value = atomic<R?>(null)
|
||||||
override val value: R get() = atomicValue.value ?: seed()
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
val newValue = gen(value)
|
return value.updateAndGet { prev -> gen(prev ?: seed()) }!!
|
||||||
atomicValue.lazySet(newValue)
|
|
||||||
return value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> {
|
||||||
return MarkovChain(value, gen)
|
return MarkovChain(seed = { value.value ?: seed() }, gen = gen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state
|
* A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state
|
||||||
* @param S - the state of the chain
|
* @param S - the state of the chain
|
||||||
|
* @param forkState - the function to copy current state without modifying it
|
||||||
*/
|
*/
|
||||||
class StatefulChain<S, out R>(
|
class StatefulChain<S, out R>(
|
||||||
private val state: S,
|
private val state: S,
|
||||||
private val seed: S.() -> R,
|
private val seed: S.() -> R,
|
||||||
|
private val forkState: ((S) -> S),
|
||||||
private val gen: suspend S.(R) -> R
|
private val gen: suspend S.(R) -> R
|
||||||
) : Chain<R> {
|
) : Chain<R> {
|
||||||
|
|
||||||
constructor(state: S, seed: R, gen: suspend S.(R) -> R) : this(state, { seed }, gen)
|
constructor(state: S, seedValue: R, forkState: ((S) -> S), gen: suspend S.(R) -> R) : this(
|
||||||
|
state,
|
||||||
|
{ seedValue },
|
||||||
|
forkState,
|
||||||
|
gen
|
||||||
|
)
|
||||||
|
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val atomicValue = atomic<R?>(null)
|
||||||
override val value: R get() = atomicValue.value ?: seed(state)
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
val newValue = gen(state, value)
|
return atomicValue.updateAndGet { prev -> state.gen(prev ?: state.seed()) }!!
|
||||||
atomicValue.lazySet(newValue)
|
|
||||||
return value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> {
|
||||||
throw RuntimeException("Fork not supported for stateful chain")
|
return StatefulChain(forkState(state), seed, forkState, gen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A chain that repeats the same value
|
* A chain that repeats the same value
|
||||||
*/
|
*/
|
||||||
class ConstantChain<out T>(override val value: T) : Chain<T> {
|
class ConstantChain<out T>(val value: T) : Chain<T> {
|
||||||
override suspend fun next(): T {
|
override suspend fun next(): T = value
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun fork(): Chain<T> {
|
override fun fork(): Chain<T> {
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||||
|
* since mapped chain consumes tokens. Accepts regular transformation function
|
||||||
|
*/
|
||||||
|
fun <T, R> Chain<T>.pipe(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
||||||
|
override suspend fun next(): R = func(this@pipe.next())
|
||||||
|
override fun fork(): Chain<R> = this@pipe.fork().pipe(func)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map the whole chain
|
||||||
|
*/
|
||||||
|
fun <T, R> Chain<T>.map(mapper: suspend (Chain<T>) -> R): Chain<R> = object : Chain<R> {
|
||||||
|
override suspend fun next(): R = mapper(this@map)
|
||||||
|
override fun fork(): Chain<R> = this@map.fork().map(mapper)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
||||||
|
object : Chain<R> {
|
||||||
|
override suspend fun next(): R = state.mapper(this@mapWithState)
|
||||||
|
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zip two chains together using given transformation
|
||||||
|
*/
|
||||||
|
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
|
||||||
|
override suspend fun next(): R = block(this@zip.next(), other.next())
|
||||||
|
|
||||||
|
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath
|
package scientifik.kmath.coroutines
|
||||||
|
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.channels.produce
|
import kotlinx.coroutines.channels.produce
|
||||||
@ -42,13 +42,14 @@ fun <T, R> Flow<T>.async(
|
|||||||
}
|
}
|
||||||
|
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = AsyncFlow(deferredFlow.map { input ->
|
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) =
|
||||||
//TODO add function composition
|
AsyncFlow(deferredFlow.map { input ->
|
||||||
LazyDeferred(input.dispatcher) {
|
//TODO add function composition
|
||||||
input.start(this)
|
LazyDeferred(input.dispatcher) {
|
||||||
action(input.await())
|
input.start(this)
|
||||||
}
|
action(input.await())
|
||||||
})
|
}
|
||||||
|
})
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
@FlowPreview
|
@FlowPreview
|
@ -22,7 +22,7 @@ fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() }
|
|||||||
* Collect incoming flow into fixed size chunks
|
* Collect incoming flow into fixed size chunks
|
||||||
*/
|
*/
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>) = flow {
|
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<Buffer<T>> = flow {
|
||||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
val list = ArrayList<T>(bufferSize)
|
val list = ArrayList<T>(bufferSize)
|
||||||
var counter = 0
|
var counter = 0
|
||||||
@ -46,7 +46,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>) = flow
|
|||||||
* Specialized flow chunker for real buffer
|
* Specialized flow chunker for real buffer
|
||||||
*/
|
*/
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun Flow<Double>.chunked(bufferSize: Int) = flow {
|
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
||||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
val array = DoubleArray(bufferSize)
|
val array = DoubleArray(bufferSize)
|
||||||
var counter = 0
|
var counter = 0
|
||||||
|
@ -18,23 +18,3 @@ operator fun <R> Chain<R>.iterator() = object : Iterator<R> {
|
|||||||
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
|
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
|
||||||
override fun iterator(): Iterator<R> = this@asSequence.iterator()
|
override fun iterator(): Iterator<R> = this@asSequence.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
|
||||||
* since mapped chain consumes tokens. Accepts suspending transformation function.
|
|
||||||
*/
|
|
||||||
fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> {
|
|
||||||
val parent = this;
|
|
||||||
return object : Chain<R> {
|
|
||||||
override val value: R? get() = runBlocking { parent.value?.let { func(it) } }
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
|
||||||
return func(parent.next())
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
|
||||||
return parent.fork().map(func)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
import scientifik.kmath.Math
|
import scientifik.kmath.coroutines.Math
|
||||||
|
|
||||||
class LazyNDStructure<T>(
|
class LazyNDStructure<T>(
|
||||||
val scope: CoroutineScope,
|
val scope: CoroutineScope,
|
||||||
|
@ -4,9 +4,9 @@ import kotlinx.coroutines.*
|
|||||||
import kotlinx.coroutines.flow.asFlow
|
import kotlinx.coroutines.flow.asFlow
|
||||||
import kotlinx.coroutines.flow.collect
|
import kotlinx.coroutines.flow.collect
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import scientifik.kmath.async
|
import scientifik.kmath.coroutines.async
|
||||||
import scientifik.kmath.collect
|
import scientifik.kmath.coroutines.collect
|
||||||
import scientifik.kmath.map
|
import scientifik.kmath.coroutines.map
|
||||||
import java.util.concurrent.Executors
|
import java.util.concurrent.Executors
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ plugins {
|
|||||||
`npm-multiplatform`
|
`npm-multiplatform`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Just an example how we can collapse nested DSL for simple declarations
|
|
||||||
kotlin.sourceSets.commonMain {
|
kotlin.sourceSets.commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
|
@ -6,13 +6,14 @@ plugins {
|
|||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jvmMain {
|
jvmMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
|
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
|
||||||
|
//api("org.apache.commons:commons-rng-sampling:1.2")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,68 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.map
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
|
||||||
|
interface Sampler<T : Any> {
|
||||||
|
fun sample(generator: RandomGenerator): Chain<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A distribution of typed objects
|
||||||
|
*/
|
||||||
|
interface Distribution<T : Any> : Sampler<T> {
|
||||||
|
/**
|
||||||
|
* A probability value for given argument [arg].
|
||||||
|
* For continuous distributions returns PDF
|
||||||
|
*/
|
||||||
|
fun probability(arg: T): Double
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a chain of samples from this distribution.
|
||||||
|
* The chain is not guaranteed to be stateless.
|
||||||
|
*/
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An empty companion. Distribution factories should be written as its extensions
|
||||||
|
*/
|
||||||
|
companion object
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
||||||
|
/**
|
||||||
|
* Cumulative distribution for ordered parameter
|
||||||
|
*/
|
||||||
|
fun cumulative(arg: T): Double
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute probability integral in an interval
|
||||||
|
*/
|
||||||
|
fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double {
|
||||||
|
require(to > from)
|
||||||
|
return cumulative(to) - cumulative(from)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sample a bunch of values
|
||||||
|
*/
|
||||||
|
fun <T : Any> Sampler<T>.sampleBunch(generator: RandomGenerator, size: Int): Chain<List<T>> {
|
||||||
|
require(size > 1)
|
||||||
|
return sample(generator).map{chain ->
|
||||||
|
List(size){chain.next()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a bunch of samples from real distributions
|
||||||
|
*/
|
||||||
|
@JvmName("realSampleBunch")
|
||||||
|
fun Sampler<Double>.sampleBunch(generator: RandomGenerator, size: Int): Chain<DoubleArray> {
|
||||||
|
require(size > 1)
|
||||||
|
return sample(generator).map{chain ->
|
||||||
|
DoubleArray(size){chain.next()}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,13 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.atomicfu.atomic
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A possibly stateful chain producing random values.
|
||||||
|
*/
|
||||||
|
class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain<R> {
|
||||||
|
override suspend fun next(): R = generator.gen()
|
||||||
|
|
||||||
|
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||||
|
}
|
@ -0,0 +1,41 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlin.random.Random
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A basic generator
|
||||||
|
*/
|
||||||
|
interface RandomGenerator {
|
||||||
|
fun nextDouble(): Double
|
||||||
|
fun nextInt(): Int
|
||||||
|
fun nextLong(): Long
|
||||||
|
fun nextBlock(size: Int): ByteArray
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
|
||||||
|
* and vise versa). The statistical properties of new generator should be the same as for this one.
|
||||||
|
* For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run.
|
||||||
|
*
|
||||||
|
* The thread safety of this operation is not guaranteed since it could affect the state of the generator.
|
||||||
|
*/
|
||||||
|
fun fork(): RandomGenerator
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
val default by lazy { DefaultGenerator(Random.nextLong()) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class DefaultGenerator(seed: Long?) : RandomGenerator {
|
||||||
|
private val random = seed?.let { Random(it) } ?: Random
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
|
override fun nextInt(): Int = random.nextInt()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
|
||||||
|
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator = DefaultGenerator(nextLong())
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.ConstantChain
|
||||||
|
import scientifik.kmath.chains.pipe
|
||||||
|
import scientifik.kmath.chains.zip
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
|
class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConstantSampler<T : Any>(val value: T) : Sampler<T> {
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A space for samplers. Allows to perform simple operations on distributions
|
||||||
|
*/
|
||||||
|
class SamplerSpace<T : Any>(val space: Space<T>) : Space<Sampler<T>> {
|
||||||
|
|
||||||
|
override val zero: Sampler<T> = ConstantSampler(space.zero)
|
||||||
|
|
||||||
|
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
||||||
|
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
|
||||||
|
a.sample(generator).pipe { space.run { it * k.toDouble() } }
|
||||||
|
}
|
||||||
|
}
|
@ -3,6 +3,7 @@ pluginManagement {
|
|||||||
jcenter()
|
jcenter()
|
||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("https://dl.bintray.com/orangy/maven")
|
||||||
}
|
}
|
||||||
resolutionStrategy {
|
resolutionStrategy {
|
||||||
eachPlugin {
|
eachPlugin {
|
||||||
@ -28,5 +29,5 @@ include(
|
|||||||
":kmath-commons",
|
":kmath-commons",
|
||||||
":kmath-koma",
|
":kmath-koma",
|
||||||
":kmath-prob",
|
":kmath-prob",
|
||||||
":benchmarks"
|
":examples"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user