diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt index b4d56cf73..9260b67fe 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt @@ -8,35 +8,51 @@ package space.kscience.kmath.samplers import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Deferred import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.async import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.isActive import kotlinx.coroutines.launch +import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.chains.Chain +import space.kscience.kmath.operations.Group +import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.stat.Sampler +import space.kscience.kmath.structures.Float64 import kotlin.coroutines.coroutineContext +@UnstableKMathAPI +public data class RandomForkingSample( + val value: Deferred, + val generation: Int, + val energy: Float64, + val stepChain: Chain +) /** * A sampler that creates a chain that could be split at each computation */ -public class RandomForkingSampler( +@UnstableKMathAPI +public class RandomForkingSampler( private val scope: CoroutineScope, - private val initialValue: T, + private val initialValue: suspend (RandomGenerator) -> T, private val makeStep: suspend RandomGenerator.(T) -> List ) : Sampler { - override fun sample(generator: RandomGenerator): Chain = buildChain(scope, initialValue) { generator.makeStep(it) } + override fun sample(generator: RandomGenerator): Chain = + buildChain(scope, initial = { initialValue(generator) }) { generator.makeStep(it) } public companion object { private suspend fun Channel.receiveEvents( initial: T, + buffer: Int = 50, makeStep: suspend (T) -> List ) { send(initial) //inner dispatch queue - val innerChannel = Channel(50) + val innerChannel = Channel(buffer) innerChannel.send(initial) while (coroutineContext.isActive && !innerChannel.isEmpty) { val current = innerChannel.receive() @@ -51,21 +67,71 @@ public class RandomForkingSampler( } - public fun buildChain( + internal fun buildChain( scope: CoroutineScope, - initial: T, + initial: suspend () -> T, makeStep: suspend (T) -> List ): Chain { val channel = Channel(Channel.RENDEZVOUS) scope.launch { - channel.receiveEvents(initial, makeStep) + channel.receiveEvents(initial(), makeStep = makeStep) } return object : Chain { override suspend fun next(): T? = channel.receiveCatching().getOrNull() - override suspend fun fork(): Chain = buildChain(scope, channel.receive(), makeStep) + override suspend fun fork(): Chain = buildChain(scope, { channel.receive() }, makeStep) } } + + + public fun metropolisHastings( + scope: CoroutineScope, + algebra: A, + startPoint: suspend (RandomGenerator) -> T, + stepSampler: Sampler, + initialEnergy: Float64, + energySplitRule: suspend RandomForkingSample.() -> List, + stepScaleRule: suspend (Float64) -> Float64 = { 1.0 }, + targetPdf: suspend (T) -> Float64, + ): RandomForkingSampler> where A : Group, A : ScaleOperations = + RandomForkingSampler>( + scope = scope, + initialValue = { generator -> + RandomForkingSample( + value = scope.async { startPoint(generator) }, + generation = 0, + energy = initialEnergy, + stepChain = stepSampler.sample(generator) + ) + } + ) { previousSample: RandomForkingSample -> + val value = previousSample.value.await() + previousSample.energySplitRule().map { energy -> + RandomForkingSample( + value = scope.async { + val proposalPoint = with(algebra) { + value + previousSample.stepChain.next() * stepScaleRule(previousSample.energy) + } + val ratio = targetPdf(proposalPoint) / targetPdf(value) + + if (ratio >= 1.0) { + proposalPoint + } else { + val acceptanceProbability = nextDouble() + if (acceptanceProbability <= ratio) { + proposalPoint + } else { + value + } + } + }, + generation = previousSample.generation + 1, + energy = 0.0, + stepChain = previousSample.stepChain + ) + } + } + } }