diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 58841e7b8..317cdcea0 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -37,6 +37,8 @@ interface Chain { */ fun fork(): Chain + companion object + } /** @@ -137,4 +139,13 @@ fun Chain.mapWithState(state: S, stateFork: (S) -> S, mapper: suspe object : Chain { override suspend fun next(): R = state.mapper(this@mapWithState) override fun fork(): Chain = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper) - } \ No newline at end of file + } + +/** + * Zip two chains together using given transformation + */ +fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = object : Chain { + override suspend fun next(): R = block(this@zip.next(), other.next()) + + override fun fork(): Chain = this@zip.fork().zip(other.fork(), block) +} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt index d3cac9163..487e13876 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt @@ -5,7 +5,7 @@ import scientifik.kmath.chains.map import kotlin.jvm.JvmName interface Sampler { - fun sample(generator: RandomGenerator): RandomChain + fun sample(generator: RandomGenerator): Chain } /** @@ -22,7 +22,7 @@ interface Distribution : Sampler { * Create a chain of samples from this distribution. * The chain is not guaranteed to be stateless. */ - override fun sample(generator: RandomGenerator): RandomChain + override fun sample(generator: RandomGenerator): Chain /** * An empty companion. Distribution factories should be written as its extensions diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt new file mode 100644 index 000000000..ca9d74aab --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/SamplerAlgebra.kt @@ -0,0 +1,31 @@ +package scientifik.kmath.prob + +import scientifik.kmath.chains.Chain +import scientifik.kmath.chains.ConstantChain +import scientifik.kmath.chains.pipe +import scientifik.kmath.chains.zip +import scientifik.kmath.operations.Space + +class BasicSampler(val chainBuilder: (RandomGenerator) -> Chain) : Sampler { + override fun sample(generator: RandomGenerator): Chain = chainBuilder(generator) +} + +class ConstantSampler(val value: T) : Sampler { + override fun sample(generator: RandomGenerator): Chain = ConstantChain(value) +} + +/** + * A space for samplers. Allows to perform simple operations on distributions + */ +class SamplerSpace(val space: Space) : Space> { + + override val zero: Sampler = ConstantSampler(space.zero) + + override fun add(a: Sampler, b: Sampler): Sampler = BasicSampler { generator -> + a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space.run { aValue + bValue } } + } + + override fun multiply(a: Sampler, k: Number): Sampler = BasicSampler { generator -> + a.sample(generator).pipe { space.run { it * k.toDouble() } } + } +} \ No newline at end of file