forked from kscience/kmath
Initial work on distributions
This commit is contained in:
parent
f706122266
commit
2e76073712
@ -18,6 +18,7 @@ package scientifik.kmath.chains
|
|||||||
|
|
||||||
import kotlinx.atomicfu.atomic
|
import kotlinx.atomicfu.atomic
|
||||||
import kotlinx.coroutines.FlowPreview
|
import kotlinx.coroutines.FlowPreview
|
||||||
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -28,7 +29,7 @@ interface Chain<out R> {
|
|||||||
/**
|
/**
|
||||||
* Last cached value of the chain. Returns null if [next] was not called
|
* Last cached value of the chain. Returns null if [next] was not called
|
||||||
*/
|
*/
|
||||||
val value: R?
|
fun peek(): R?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate next value, changing state if needed
|
* Generate next value, changing state if needed
|
||||||
@ -46,13 +47,12 @@ interface Chain<out R> {
|
|||||||
* Chain as a coroutine flow. The flow emit affects chain state and vice versa
|
* Chain as a coroutine flow. The flow emit affects chain state and vice versa
|
||||||
*/
|
*/
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
val <R> Chain<R>.flow
|
val <R> Chain<R>.flow: Flow<R>
|
||||||
get() = kotlinx.coroutines.flow.flow { while (true) emit(next()) }
|
get() = kotlinx.coroutines.flow.flow { while (true) emit(next()) }
|
||||||
|
|
||||||
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
||||||
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||||
* since mapped chain consumes tokens. Accepts regular transformation function
|
* since mapped chain consumes tokens. Accepts regular transformation function
|
||||||
@ -60,7 +60,7 @@ fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
|||||||
fun <T, R> Chain<T>.map(func: (T) -> R): Chain<R> {
|
fun <T, R> Chain<T>.map(func: (T) -> R): Chain<R> {
|
||||||
val parent = this;
|
val parent = this;
|
||||||
return object : Chain<R> {
|
return object : Chain<R> {
|
||||||
override val value: R? get() = parent.value?.let(func)
|
override fun peek(): R? = parent.peek()?.let(func)
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
return func(parent.next())
|
return func(parent.next())
|
||||||
@ -77,7 +77,7 @@ fun <T, R> Chain<T>.map(func: (T) -> R): Chain<R> {
|
|||||||
*/
|
*/
|
||||||
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val atomicValue = atomic<R?>(null)
|
||||||
override val value: R? get() = atomicValue.value
|
override fun peek(): R? = atomicValue.value
|
||||||
|
|
||||||
override suspend fun next(): R = gen().also { atomicValue.lazySet(it) }
|
override suspend fun next(): R = gen().also { atomicValue.lazySet(it) }
|
||||||
|
|
||||||
@ -95,16 +95,16 @@ class MarkovChain<out R : Any>(private val seed: () -> R, private val gen: suspe
|
|||||||
constructor(seed: R, gen: suspend (R) -> R) : this({ seed }, gen)
|
constructor(seed: R, gen: suspend (R) -> R) : this({ seed }, gen)
|
||||||
|
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val atomicValue = atomic<R?>(null)
|
||||||
override val value: R get() = atomicValue.value ?: seed()
|
override fun peek(): R = atomicValue.value ?: seed()
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
val newValue = gen(value)
|
val newValue = gen(peek())
|
||||||
atomicValue.lazySet(newValue)
|
atomicValue.lazySet(newValue)
|
||||||
return value
|
return peek()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> {
|
||||||
return MarkovChain(value, gen)
|
return MarkovChain(peek(), gen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,12 +121,12 @@ class StatefulChain<S, out R>(
|
|||||||
constructor(state: S, seed: R, gen: suspend S.(R) -> R) : this(state, { seed }, gen)
|
constructor(state: S, seed: R, gen: suspend S.(R) -> R) : this(state, { seed }, gen)
|
||||||
|
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val atomicValue = atomic<R?>(null)
|
||||||
override val value: R get() = atomicValue.value ?: seed(state)
|
override fun peek(): R = atomicValue.value ?: seed(state)
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
val newValue = gen(state, value)
|
val newValue = gen(state, peek())
|
||||||
atomicValue.lazySet(newValue)
|
atomicValue.lazySet(newValue)
|
||||||
return value
|
return peek()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> {
|
||||||
@ -137,10 +137,10 @@ class StatefulChain<S, out R>(
|
|||||||
/**
|
/**
|
||||||
* A chain that repeats the same value
|
* A chain that repeats the same value
|
||||||
*/
|
*/
|
||||||
class ConstantChain<out T>(override val value: T) : Chain<T> {
|
class ConstantChain<out T>(val value: T) : Chain<T> {
|
||||||
override suspend fun next(): T {
|
override fun peek(): T? = value
|
||||||
return value
|
|
||||||
}
|
override suspend fun next(): T = value
|
||||||
|
|
||||||
override fun fork(): Chain<T> {
|
override fun fork(): Chain<T> {
|
||||||
return this
|
return this
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath
|
package scientifik.kmath.coroutines
|
||||||
|
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.channels.produce
|
import kotlinx.coroutines.channels.produce
|
||||||
@ -42,7 +42,8 @@ fun <T, R> Flow<T>.async(
|
|||||||
}
|
}
|
||||||
|
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = AsyncFlow(deferredFlow.map { input ->
|
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) =
|
||||||
|
AsyncFlow(deferredFlow.map { input ->
|
||||||
//TODO add function composition
|
//TODO add function composition
|
||||||
LazyDeferred(input.dispatcher) {
|
LazyDeferred(input.dispatcher) {
|
||||||
input.start(this)
|
input.start(this)
|
@ -22,7 +22,7 @@ fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() }
|
|||||||
* Collect incoming flow into fixed size chunks
|
* Collect incoming flow into fixed size chunks
|
||||||
*/
|
*/
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>) = flow {
|
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<Buffer<T>> = flow {
|
||||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
val list = ArrayList<T>(bufferSize)
|
val list = ArrayList<T>(bufferSize)
|
||||||
var counter = 0
|
var counter = 0
|
||||||
@ -46,7 +46,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>) = flow
|
|||||||
* Specialized flow chunker for real buffer
|
* Specialized flow chunker for real buffer
|
||||||
*/
|
*/
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun Flow<Double>.chunked(bufferSize: Int) = flow {
|
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
||||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
val array = DoubleArray(bufferSize)
|
val array = DoubleArray(bufferSize)
|
||||||
var counter = 0
|
var counter = 0
|
||||||
|
@ -27,7 +27,7 @@ fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
|
|||||||
fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> {
|
fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> {
|
||||||
val parent = this;
|
val parent = this;
|
||||||
return object : Chain<R> {
|
return object : Chain<R> {
|
||||||
override val value: R? get() = runBlocking { parent.value?.let { func(it) } }
|
override fun peek(): R? = runBlocking { parent.peek()?.let { func(it) } }
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
return func(parent.next())
|
return func(parent.next())
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
import scientifik.kmath.Math
|
import scientifik.kmath.coroutines.Math
|
||||||
|
|
||||||
class LazyNDStructure<T>(
|
class LazyNDStructure<T>(
|
||||||
val scope: CoroutineScope,
|
val scope: CoroutineScope,
|
||||||
|
@ -4,9 +4,9 @@ import kotlinx.coroutines.*
|
|||||||
import kotlinx.coroutines.flow.asFlow
|
import kotlinx.coroutines.flow.asFlow
|
||||||
import kotlinx.coroutines.flow.collect
|
import kotlinx.coroutines.flow.collect
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import scientifik.kmath.async
|
import scientifik.kmath.coroutines.async
|
||||||
import scientifik.kmath.collect
|
import scientifik.kmath.coroutines.collect
|
||||||
import scientifik.kmath.map
|
import scientifik.kmath.coroutines.map
|
||||||
import java.util.concurrent.Executors
|
import java.util.concurrent.Executors
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,13 +6,14 @@ plugins {
|
|||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jvmMain {
|
jvmMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
|
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
|
||||||
|
api("org.apache.commons:commons-rng-sampling:1.2")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
|
||||||
|
/**
 * A probability distribution over objects of type [T].
 */
interface Distribution<T : Any> {

    /**
     * Build a chain of samples drawn from this distribution with the help of [generator].
     * The resulting chain is not guaranteed to be stateless.
     */
    fun sample(generator: RandomGenerator): Chain<T>
    //TODO add sample bunch generator

    /**
     * The probability value for the given argument [arg].
     * For continuous distributions the returned value is the PDF.
     */
    fun probability(arg: T): Double
}
|
||||||
|
|
||||||
|
/**
 * A [Distribution] over a single ordered ([Comparable]) parameter.
 */
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {

    /**
     * Cumulative distribution value for the ordered parameter [arg].
     */
    fun cumulative(arg: T): Double
}
|
@ -0,0 +1,11 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
/**
 * A basic generator: the minimal source of random values that distributions are sampled with.
 */
interface RandomGenerator {

    /**
     * Produce the next [Int] value.
     */
    fun nextInt(): Int

    /**
     * Produce the next [Long] value.
     */
    fun nextLong(): Long

    /**
     * Produce the next [Double] value.
     */
    fun nextDouble(): Double

    /**
     * Produce a block of bytes; [size] is the requested block size.
     */
    fun nextBlock(size: Int): ByteArray
}
|
@ -0,0 +1,53 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import org.apache.commons.rng.sampling.distribution.*
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.SimpleChain
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.math.exp
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
/**
 * Normal (Gaussian) distribution.
 *
 * @property mean the center of the distribution.
 * @property sigma the standard deviation of the distribution.
 */
class NormalDistribution(val mean: Double, val sigma: Double) : UnivariateDistribution<Double> {

    /**
     * The algorithm used to produce normalized gaussian values which are then
     * shifted by [mean] and scaled by [sigma].
     */
    enum class Sampler {
        BoxMuller,
        Marsaglia,
        Ziggurat
    }

    /**
     * Probability density at [arg]:
     * `exp(-d * d / 2) / (sigma * sqrt(2 * PI))` where `d = (arg - mean) / sigma`.
     */
    override fun probability(arg: Double): Double {
        val d = (arg - mean) / sigma
        // sigma scales the density linearly, so it must stay outside of the square root;
        // with sigma under the root the density integrates to one only for sigma == 1.
        return exp(-d * d / 2) / (sigma * sqrt(2.0 * PI))
    }

    /**
     * Cumulative distribution at [arg]. Not implemented yet: every call ends in [TODO].
     */
    override fun cumulative(arg: Double): Double {
        TODO("not implemented")
    }

    /**
     * Create a chain of samples using the given [generator] and the selected [sampler] algorithm.
     */
    fun sample(generator: RandomGenerator, sampler: Sampler): Chain<Double> {
        // All samplers share the same underlying provider, so resolve it once
        val provider = generator.asProvider()
        val normalized = when (sampler) {
            Sampler.BoxMuller -> BoxMullerNormalizedGaussianSampler(provider)
            Sampler.Marsaglia -> MarsagliaNormalizedGaussianSampler(provider)
            Sampler.Ziggurat -> ZigguratNormalizedGaussianSampler(provider)
        }
        val gauss = GaussianSampler(normalized, mean, sigma)
        //TODO add generator to chain state to allow stateful forks
        return SimpleChain { gauss.sample() }
    }

    /**
     * Create a chain of samples using the [Sampler.BoxMuller] algorithm.
     */
    override fun sample(generator: RandomGenerator): Chain<Double> = sample(generator, Sampler.BoxMuller)
}
|
||||||
|
|
||||||
|
/**
 * Poisson distribution with the given [mean].
 * Only sampling is available so far; [probability] and [cumulative] are not implemented yet.
 */
class PoissonDistribution(val mean: Double) : UnivariateDistribution<Int> {

    /**
     * Create a chain of samples backed by a [PoissonSampler] that reads from [generator].
     */
    override fun sample(generator: RandomGenerator): Chain<Int> =
        PoissonSampler(generator.asProvider(), mean).let { poisson ->
            SimpleChain { poisson.sample() }
        }

    /**
     * Not implemented yet: every call ends in [TODO].
     */
    override fun probability(arg: Int): Double = TODO("not implemented")

    /**
     * Not implemented yet: every call ends in [TODO].
     */
    override fun cumulative(arg: Int): Double = TODO("not implemented")
}
|
@ -0,0 +1,18 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import org.apache.commons.rng.UniformRandomProvider
|
||||||
|
|
||||||
|
/**
 * A [RandomGenerator] backed by a [UniformRandomProvider]: every call is forwarded to [provider].
 */
inline class CommonsRandomProviderWrapper(val provider: UniformRandomProvider) : RandomGenerator {

    /**
     * Allocate an array of [size] bytes and let [provider] fill it.
     */
    override fun nextBlock(size: Int): ByteArray {
        val block = ByteArray(size)
        provider.nextBytes(block)
        return block
    }

    override fun nextLong(): Long {
        return provider.nextLong()
    }

    override fun nextInt(): Int {
        return provider.nextInt()
    }

    override fun nextDouble(): Double {
        return provider.nextDouble()
    }
}
|
||||||
|
|
||||||
|
/**
 * Represent this [UniformRandomProvider] as a [RandomGenerator] by wrapping it
 * into a [CommonsRandomProviderWrapper].
 */
fun UniformRandomProvider.asGenerator(): RandomGenerator {
    return CommonsRandomProviderWrapper(this)
}
|
||||||
|
|
||||||
|
/**
 * Get a [UniformRandomProvider] for this generator.
 * Only a [CommonsRandomProviderWrapper] can be converted so far (its wrapped provider is returned);
 * for any other generator the call ends in [TODO].
 */
fun RandomGenerator.asProvider(): UniformRandomProvider {
    return if (this is CommonsRandomProviderWrapper) {
        provider
    } else {
        TODO("implement reverse wrapper")
    }
}
|
Loading…
Reference in New Issue
Block a user