Reimplement random-forking chain
This commit is contained in:
parent
6619db3f45
commit
2becee7f59
@ -8,35 +8,51 @@
|
||||
package space.kscience.kmath.samplers
|
||||
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Deferred
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.chains.Chain
|
||||
import space.kscience.kmath.operations.Group
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.stat.Sampler
|
||||
import space.kscience.kmath.structures.Float64
|
||||
import kotlin.coroutines.coroutineContext
|
||||
|
||||
@UnstableKMathAPI
|
||||
public data class RandomForkingSample<T>(
|
||||
val value: Deferred<T>,
|
||||
val generation: Int,
|
||||
val energy: Float64,
|
||||
val stepChain: Chain<T>
|
||||
)
|
||||
|
||||
/**
|
||||
* A sampler that creates a chain that could be split at each computation
|
||||
*/
|
||||
public class RandomForkingSampler<T: Any>(
|
||||
@UnstableKMathAPI
|
||||
public class RandomForkingSampler<T : Any>(
|
||||
private val scope: CoroutineScope,
|
||||
private val initialValue: T,
|
||||
private val initialValue: suspend (RandomGenerator) -> T,
|
||||
private val makeStep: suspend RandomGenerator.(T) -> List<T>
|
||||
) : Sampler<T?> {
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<T?> = buildChain(scope, initialValue) { generator.makeStep(it) }
|
||||
override fun sample(generator: RandomGenerator): Chain<T?> =
|
||||
buildChain(scope, initial = { initialValue(generator) }) { generator.makeStep(it) }
|
||||
|
||||
public companion object {
|
||||
private suspend fun <T> Channel<T>.receiveEvents(
|
||||
initial: T,
|
||||
buffer: Int = 50,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
) {
|
||||
send(initial)
|
||||
//inner dispatch queue
|
||||
val innerChannel = Channel<T>(50)
|
||||
val innerChannel = Channel<T>(buffer)
|
||||
innerChannel.send(initial)
|
||||
while (coroutineContext.isActive && !innerChannel.isEmpty) {
|
||||
val current = innerChannel.receive()
|
||||
@ -51,21 +67,71 @@ public class RandomForkingSampler<T: Any>(
|
||||
}
|
||||
|
||||
|
||||
public fun <T: Any> buildChain(
|
||||
internal fun <T : Any> buildChain(
|
||||
scope: CoroutineScope,
|
||||
initial: T,
|
||||
initial: suspend () -> T,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
): Chain<T?> {
|
||||
val channel = Channel<T>(Channel.RENDEZVOUS)
|
||||
scope.launch {
|
||||
channel.receiveEvents(initial, makeStep)
|
||||
channel.receiveEvents(initial(), makeStep = makeStep)
|
||||
}
|
||||
|
||||
return object : Chain<T?> {
|
||||
override suspend fun next(): T? = channel.receiveCatching().getOrNull()
|
||||
|
||||
override suspend fun fork(): Chain<T?> = buildChain(scope, channel.receive(), makeStep)
|
||||
override suspend fun fork(): Chain<T?> = buildChain(scope, { channel.receive() }, makeStep)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public fun <T : Any, A> metropolisHastings(
|
||||
scope: CoroutineScope,
|
||||
algebra: A,
|
||||
startPoint: suspend (RandomGenerator) -> T,
|
||||
stepSampler: Sampler<T>,
|
||||
initialEnergy: Float64,
|
||||
energySplitRule: suspend RandomForkingSample<T>.() -> List<Float64>,
|
||||
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
|
||||
targetPdf: suspend (T) -> Float64,
|
||||
): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> =
|
||||
RandomForkingSampler<RandomForkingSample<T>>(
|
||||
scope = scope,
|
||||
initialValue = { generator ->
|
||||
RandomForkingSample<T>(
|
||||
value = scope.async { startPoint(generator) },
|
||||
generation = 0,
|
||||
energy = initialEnergy,
|
||||
stepChain = stepSampler.sample(generator)
|
||||
)
|
||||
}
|
||||
) { previousSample: RandomForkingSample<T> ->
|
||||
val value = previousSample.value.await()
|
||||
previousSample.energySplitRule().map { energy ->
|
||||
RandomForkingSample<T>(
|
||||
value = scope.async<T> {
|
||||
val proposalPoint = with(algebra) {
|
||||
value + previousSample.stepChain.next() * stepScaleRule(previousSample.energy)
|
||||
}
|
||||
val ratio = targetPdf(proposalPoint) / targetPdf(value)
|
||||
|
||||
if (ratio >= 1.0) {
|
||||
proposalPoint
|
||||
} else {
|
||||
val acceptanceProbability = nextDouble()
|
||||
if (acceptanceProbability <= ratio) {
|
||||
proposalPoint
|
||||
} else {
|
||||
value
|
||||
}
|
||||
}
|
||||
},
|
||||
generation = previousSample.generation + 1,
|
||||
energy = 0.0,
|
||||
stepChain = previousSample.stepChain
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user