Reimplement random-forking chain

This commit is contained in:
Alexander Nozik 2024-08-09 22:14:24 +03:00
parent 6619db3f45
commit 2becee7f59

View File

@ -8,35 +8,51 @@
package space.kscience.kmath.samplers package space.kscience.kmath.samplers
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.Chain
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.stat.Sampler import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.Float64
import kotlin.coroutines.coroutineContext import kotlin.coroutines.coroutineContext
@UnstableKMathAPI
public data class RandomForkingSample<T>(
val value: Deferred<T>,
val generation: Int,
val energy: Float64,
val stepChain: Chain<T>
)
/** /**
* A sampler that creates a chain that could be split at each computation * A sampler that creates a chain that could be split at each computation
*/ */
@UnstableKMathAPI
public class RandomForkingSampler<T : Any>( public class RandomForkingSampler<T : Any>(
private val scope: CoroutineScope, private val scope: CoroutineScope,
private val initialValue: T, private val initialValue: suspend (RandomGenerator) -> T,
private val makeStep: suspend RandomGenerator.(T) -> List<T> private val makeStep: suspend RandomGenerator.(T) -> List<T>
) : Sampler<T?> { ) : Sampler<T?> {
override fun sample(generator: RandomGenerator): Chain<T?> = buildChain(scope, initialValue) { generator.makeStep(it) } override fun sample(generator: RandomGenerator): Chain<T?> =
buildChain(scope, initial = { initialValue(generator) }) { generator.makeStep(it) }
public companion object { public companion object {
private suspend fun <T> Channel<T>.receiveEvents( private suspend fun <T> Channel<T>.receiveEvents(
initial: T, initial: T,
buffer: Int = 50,
makeStep: suspend (T) -> List<T> makeStep: suspend (T) -> List<T>
) { ) {
send(initial) send(initial)
//inner dispatch queue //inner dispatch queue
val innerChannel = Channel<T>(50) val innerChannel = Channel<T>(buffer)
innerChannel.send(initial) innerChannel.send(initial)
while (coroutineContext.isActive && !innerChannel.isEmpty) { while (coroutineContext.isActive && !innerChannel.isEmpty) {
val current = innerChannel.receive() val current = innerChannel.receive()
@ -51,21 +67,71 @@ public class RandomForkingSampler<T: Any>(
} }
public fun <T: Any> buildChain( internal fun <T : Any> buildChain(
scope: CoroutineScope, scope: CoroutineScope,
initial: T, initial: suspend () -> T,
makeStep: suspend (T) -> List<T> makeStep: suspend (T) -> List<T>
): Chain<T?> { ): Chain<T?> {
val channel = Channel<T>(Channel.RENDEZVOUS) val channel = Channel<T>(Channel.RENDEZVOUS)
scope.launch { scope.launch {
channel.receiveEvents(initial, makeStep) channel.receiveEvents(initial(), makeStep = makeStep)
} }
return object : Chain<T?> { return object : Chain<T?> {
override suspend fun next(): T? = channel.receiveCatching().getOrNull() override suspend fun next(): T? = channel.receiveCatching().getOrNull()
override suspend fun fork(): Chain<T?> = buildChain(scope, channel.receive(), makeStep) override suspend fun fork(): Chain<T?> = buildChain(scope, { channel.receive() }, makeStep)
} }
} }
public fun <T : Any, A> metropolisHastings(
scope: CoroutineScope,
algebra: A,
startPoint: suspend (RandomGenerator) -> T,
stepSampler: Sampler<T>,
initialEnergy: Float64,
energySplitRule: suspend RandomForkingSample<T>.() -> List<Float64>,
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
targetPdf: suspend (T) -> Float64,
): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> =
RandomForkingSampler<RandomForkingSample<T>>(
scope = scope,
initialValue = { generator ->
RandomForkingSample<T>(
value = scope.async { startPoint(generator) },
generation = 0,
energy = initialEnergy,
stepChain = stepSampler.sample(generator)
)
}
) { previousSample: RandomForkingSample<T> ->
val value = previousSample.value.await()
previousSample.energySplitRule().map { energy ->
RandomForkingSample<T>(
value = scope.async<T> {
val proposalPoint = with(algebra) {
value + previousSample.stepChain.next() * stepScaleRule(previousSample.energy)
}
val ratio = targetPdf(proposalPoint) / targetPdf(value)
if (ratio >= 1.0) {
proposalPoint
} else {
val acceptanceProbability = nextDouble()
if (acceptanceProbability <= ratio) {
proposalPoint
} else {
value
}
}
},
generation = previousSample.generation + 1,
energy = 0.0,
stepChain = previousSample.stepChain
)
}
}
} }
} }