Sampler space
This commit is contained in:
parent
88ee3e5ab7
commit
e077da39c8
@ -37,6 +37,8 @@ interface Chain<out R> {
|
||||
*/
|
||||
fun fork(): Chain<R>
|
||||
|
||||
companion object
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
@ -138,3 +140,12 @@ fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspe
|
||||
override suspend fun next(): R = state.mapper(this@mapWithState)
|
||||
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
|
||||
}
|
||||
|
||||
/**
|
||||
* Zip two chains together using given transformation
|
||||
*/
|
||||
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
|
||||
override suspend fun next(): R = block(this@zip.next(), other.next())
|
||||
|
||||
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
||||
}
|
@ -5,7 +5,7 @@ import scientifik.kmath.chains.map
|
||||
import kotlin.jvm.JvmName
|
||||
|
||||
interface Sampler<T : Any> {
|
||||
fun sample(generator: RandomGenerator): RandomChain<T>
|
||||
fun sample(generator: RandomGenerator): Chain<T>
|
||||
}
|
||||
|
||||
/**
|
||||
@ -22,7 +22,7 @@ interface Distribution<T : Any> : Sampler<T> {
|
||||
* Create a chain of samples from this distribution.
|
||||
* The chain is not guaranteed to be stateless.
|
||||
*/
|
||||
override fun sample(generator: RandomGenerator): RandomChain<T>
|
||||
override fun sample(generator: RandomGenerator): Chain<T>
|
||||
|
||||
/**
|
||||
* An empty companion. Distribution factories should be written as its extensions
|
||||
|
@ -0,0 +1,31 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.ConstantChain
|
||||
import scientifik.kmath.chains.pipe
|
||||
import scientifik.kmath.chains.zip
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
|
||||
override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
|
||||
}
|
||||
|
||||
class ConstantSampler<T : Any>(val value: T) : Sampler<T> {
|
||||
override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value)
|
||||
}
|
||||
|
||||
/**
|
||||
* A space for samplers. Allows to perform simple operations on distributions
|
||||
*/
|
||||
class SamplerSpace<T : Any>(val space: Space<T>) : Space<Sampler<T>> {
|
||||
|
||||
override val zero: Sampler<T> = ConstantSampler(space.zero)
|
||||
|
||||
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
||||
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } }
|
||||
}
|
||||
|
||||
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
|
||||
a.sample(generator).pipe { space.run { it * k.toDouble() } }
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user