Reimplement random-forking chain
This commit is contained in:
parent
6619db3f45
commit
2becee7f59
@ -8,35 +8,51 @@
|
|||||||
package space.kscience.kmath.samplers
|
package space.kscience.kmath.samplers
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineScope
|
import kotlinx.coroutines.CoroutineScope
|
||||||
|
import kotlinx.coroutines.Deferred
|
||||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||||
|
import kotlinx.coroutines.async
|
||||||
import kotlinx.coroutines.channels.Channel
|
import kotlinx.coroutines.channels.Channel
|
||||||
import kotlinx.coroutines.isActive
|
import kotlinx.coroutines.isActive
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.chains.Chain
|
import space.kscience.kmath.chains.Chain
|
||||||
|
import space.kscience.kmath.operations.Group
|
||||||
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
import space.kscience.kmath.random.RandomGenerator
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
import space.kscience.kmath.stat.Sampler
|
import space.kscience.kmath.stat.Sampler
|
||||||
|
import space.kscience.kmath.structures.Float64
|
||||||
import kotlin.coroutines.coroutineContext
|
import kotlin.coroutines.coroutineContext
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public data class RandomForkingSample<T>(
|
||||||
|
val value: Deferred<T>,
|
||||||
|
val generation: Int,
|
||||||
|
val energy: Float64,
|
||||||
|
val stepChain: Chain<T>
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A sampler that creates a chain that could be split at each computation
|
* A sampler that creates a chain that could be split at each computation
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public class RandomForkingSampler<T : Any>(
|
public class RandomForkingSampler<T : Any>(
|
||||||
private val scope: CoroutineScope,
|
private val scope: CoroutineScope,
|
||||||
private val initialValue: T,
|
private val initialValue: suspend (RandomGenerator) -> T,
|
||||||
private val makeStep: suspend RandomGenerator.(T) -> List<T>
|
private val makeStep: suspend RandomGenerator.(T) -> List<T>
|
||||||
) : Sampler<T?> {
|
) : Sampler<T?> {
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<T?> = buildChain(scope, initialValue) { generator.makeStep(it) }
|
override fun sample(generator: RandomGenerator): Chain<T?> =
|
||||||
|
buildChain(scope, initial = { initialValue(generator) }) { generator.makeStep(it) }
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
private suspend fun <T> Channel<T>.receiveEvents(
|
private suspend fun <T> Channel<T>.receiveEvents(
|
||||||
initial: T,
|
initial: T,
|
||||||
|
buffer: Int = 50,
|
||||||
makeStep: suspend (T) -> List<T>
|
makeStep: suspend (T) -> List<T>
|
||||||
) {
|
) {
|
||||||
send(initial)
|
send(initial)
|
||||||
//inner dispatch queue
|
//inner dispatch queue
|
||||||
val innerChannel = Channel<T>(50)
|
val innerChannel = Channel<T>(buffer)
|
||||||
innerChannel.send(initial)
|
innerChannel.send(initial)
|
||||||
while (coroutineContext.isActive && !innerChannel.isEmpty) {
|
while (coroutineContext.isActive && !innerChannel.isEmpty) {
|
||||||
val current = innerChannel.receive()
|
val current = innerChannel.receive()
|
||||||
@ -51,21 +67,71 @@ public class RandomForkingSampler<T: Any>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public fun <T: Any> buildChain(
|
internal fun <T : Any> buildChain(
|
||||||
scope: CoroutineScope,
|
scope: CoroutineScope,
|
||||||
initial: T,
|
initial: suspend () -> T,
|
||||||
makeStep: suspend (T) -> List<T>
|
makeStep: suspend (T) -> List<T>
|
||||||
): Chain<T?> {
|
): Chain<T?> {
|
||||||
val channel = Channel<T>(Channel.RENDEZVOUS)
|
val channel = Channel<T>(Channel.RENDEZVOUS)
|
||||||
scope.launch {
|
scope.launch {
|
||||||
channel.receiveEvents(initial, makeStep)
|
channel.receiveEvents(initial(), makeStep = makeStep)
|
||||||
}
|
}
|
||||||
|
|
||||||
return object : Chain<T?> {
|
return object : Chain<T?> {
|
||||||
override suspend fun next(): T? = channel.receiveCatching().getOrNull()
|
override suspend fun next(): T? = channel.receiveCatching().getOrNull()
|
||||||
|
|
||||||
override suspend fun fork(): Chain<T?> = buildChain(scope, channel.receive(), makeStep)
|
override suspend fun fork(): Chain<T?> = buildChain(scope, { channel.receive() }, makeStep)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T : Any, A> metropolisHastings(
|
||||||
|
scope: CoroutineScope,
|
||||||
|
algebra: A,
|
||||||
|
startPoint: suspend (RandomGenerator) -> T,
|
||||||
|
stepSampler: Sampler<T>,
|
||||||
|
initialEnergy: Float64,
|
||||||
|
energySplitRule: suspend RandomForkingSample<T>.() -> List<Float64>,
|
||||||
|
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
|
||||||
|
targetPdf: suspend (T) -> Float64,
|
||||||
|
): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> =
|
||||||
|
RandomForkingSampler<RandomForkingSample<T>>(
|
||||||
|
scope = scope,
|
||||||
|
initialValue = { generator ->
|
||||||
|
RandomForkingSample<T>(
|
||||||
|
value = scope.async { startPoint(generator) },
|
||||||
|
generation = 0,
|
||||||
|
energy = initialEnergy,
|
||||||
|
stepChain = stepSampler.sample(generator)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
) { previousSample: RandomForkingSample<T> ->
|
||||||
|
val value = previousSample.value.await()
|
||||||
|
previousSample.energySplitRule().map { energy ->
|
||||||
|
RandomForkingSample<T>(
|
||||||
|
value = scope.async<T> {
|
||||||
|
val proposalPoint = with(algebra) {
|
||||||
|
value + previousSample.stepChain.next() * stepScaleRule(previousSample.energy)
|
||||||
|
}
|
||||||
|
val ratio = targetPdf(proposalPoint) / targetPdf(value)
|
||||||
|
|
||||||
|
if (ratio >= 1.0) {
|
||||||
|
proposalPoint
|
||||||
|
} else {
|
||||||
|
val acceptanceProbability = nextDouble()
|
||||||
|
if (acceptanceProbability <= ratio) {
|
||||||
|
proposalPoint
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
generation = previousSample.generation + 1,
|
||||||
|
energy = 0.0,
|
||||||
|
stepChain = previousSample.stepChain
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user