Dev #72

Merged
altavir merged 16 commits from dev into master 2019-06-26 18:54:12 +03:00
3 changed files with 45 additions and 3 deletions
Showing only changes of commit e077da39c8 - Show all commits

View File

@ -37,6 +37,8 @@ interface Chain<out R> {
*/ */
fun fork(): Chain<R> fun fork(): Chain<R>
companion object
} }
/** /**
@ -137,4 +139,13 @@ fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspe
object : Chain<R> { object : Chain<R> {
override suspend fun next(): R = state.mapper(this@mapWithState) override suspend fun next(): R = state.mapper(this@mapWithState)
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper) override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
} }
/**
* Zip two chains together using given transformation
*/
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
override suspend fun next(): R = block(this@zip.next(), other.next())
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
}

View File

@ -5,7 +5,7 @@ import scientifik.kmath.chains.map
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
interface Sampler<T : Any> { interface Sampler<T : Any> {
fun sample(generator: RandomGenerator): RandomChain<T> fun sample(generator: RandomGenerator): Chain<T>
} }
/** /**
@ -22,7 +22,7 @@ interface Distribution<T : Any> : Sampler<T> {
* Create a chain of samples from this distribution. * Create a chain of samples from this distribution.
* The chain is not guaranteed to be stateless. * The chain is not guaranteed to be stateless.
*/ */
override fun sample(generator: RandomGenerator): RandomChain<T> override fun sample(generator: RandomGenerator): Chain<T>
/** /**
* An empty companion. Distribution factories should be written as its extensions * An empty companion. Distribution factories should be written as its extensions

View File

@ -0,0 +1,31 @@
package scientifik.kmath.prob
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.ConstantChain
import scientifik.kmath.chains.pipe
import scientifik.kmath.chains.zip
import scientifik.kmath.operations.Space
class BasicSampler<T : Any>(val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> {
override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator)
}
class ConstantSampler<T : Any>(val value: T) : Sampler<T> {
override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value)
}
/**
* A space for samplers. Allows to perform simple operations on distributions
*/
class SamplerSpace<T : Any>(val space: Space<T>) : Space<Sampler<T>> {
override val zero: Sampler<T> = ConstantSampler(space.zero)
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } }
}
override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator ->
a.sample(generator).pipe { space.run { it * k.toDouble() } }
}
}