Sampler space
This commit is contained in:
parent
88ee3e5ab7
commit
e077da39c8
@ -37,6 +37,8 @@ interface Chain<out R> {
|
|||||||
*/
|
*/
|
||||||
fun fork(): Chain<R>
|
fun fork(): Chain<R>
|
||||||
|
|
||||||
|
companion object
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -137,4 +139,13 @@ fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspe
|
|||||||
object : Chain<R> {
|
object : Chain<R> {
|
||||||
override suspend fun next(): R = state.mapper(this@mapWithState)
|
override suspend fun next(): R = state.mapper(this@mapWithState)
|
||||||
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
|
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zip two chains together using given transformation
|
||||||
|
*/
|
||||||
|
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
|
||||||
|
override suspend fun next(): R = block(this@zip.next(), other.next())
|
||||||
|
|
||||||
|
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
||||||
|
}
|
@ -5,7 +5,7 @@ import scientifik.kmath.chains.map
|
|||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
|
|
||||||
interface Sampler<T : Any> {
|
interface Sampler<T : Any> {
|
||||||
fun sample(generator: RandomGenerator): RandomChain<T>
|
fun sample(generator: RandomGenerator): Chain<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -22,7 +22,7 @@ interface Distribution<T : Any> : Sampler<T> {
|
|||||||
* Create a chain of samples from this distribution.
|
* Create a chain of samples from this distribution.
|
||||||
* The chain is not guaranteed to be stateless.
|
* The chain is not guaranteed to be stateless.
|
||||||
*/
|
*/
|
||||||
override fun sample(generator: RandomGenerator): RandomChain<T>
|
override fun sample(generator: RandomGenerator): Chain<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An empty companion. Distribution factories should be written as its extensions
|
* An empty companion. Distribution factories should be written as its extensions
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.ConstantChain
|
||||||
|
import scientifik.kmath.chains.pipe
|
||||||
|
import scientifik.kmath.chains.zip
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
|
class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConstantSampler<T : Any>(val value: T) : Sampler<T> {
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A space for samplers. Allows to perform simple operations on distributions
|
||||||
|
*/
|
||||||
|
class SamplerSpace<T : Any>(val space: Space<T>) : Space<Sampler<T>> {
|
||||||
|
|
||||||
|
override val zero: Sampler<T> = ConstantSampler(space.zero)
|
||||||
|
|
||||||
|
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
||||||
|
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
|
||||||
|
a.sample(generator).pipe { space.run { it * k.toDouble() } }
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user