Compare commits
1 Commits
dev
...
feature/fo
Author | SHA1 | Date | |
---|---|---|---|
c63eac3ba0 |
kmath-stat/src
commonMain/kotlin/space/kscience/kmath/samplers
commonTest/kotlin/space/kscience/kmath/samplers
@ -11,8 +11,10 @@ import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.chains.Chain
|
||||
import space.kscience.kmath.chains.map
|
||||
import space.kscience.kmath.operations.Group
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
import space.kscience.kmath.operations.algebra
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.stat.Sampler
|
||||
import space.kscience.kmath.structures.Float64
|
||||
@ -24,9 +26,14 @@ public data class RandomForkingSample<T>(
|
||||
val generation: Int,
|
||||
val energy: Float64,
|
||||
val generator: RandomGenerator,
|
||||
val stepChain: Chain<T>
|
||||
val stepChain: Chain<T>,
|
||||
)
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T> Sampler<RandomForkingSample<T>?>.values(): Sampler<T?> = Sampler { generator ->
|
||||
this@values.sample(generator).map { it?.value?.await() }
|
||||
}
|
||||
|
||||
/**
|
||||
* A sampler that creates a chain that could be split at each computation
|
||||
*/
|
||||
@ -34,17 +41,18 @@ public data class RandomForkingSample<T>(
|
||||
public class RandomForkingSampler<T : Any>(
|
||||
private val scope: CoroutineScope,
|
||||
private val initialValue: suspend (RandomGenerator) -> T,
|
||||
private val makeStep: suspend (T) -> List<T>
|
||||
private val makeStep: suspend (T) -> List<T>,
|
||||
) : Sampler<T?> {
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<T?> =
|
||||
buildChain(scope, initial = { initialValue(generator) }) { makeStep(it) }
|
||||
|
||||
public companion object {
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
private suspend fun <T> Channel<T>.receiveEvents(
|
||||
initial: T,
|
||||
buffer: Int = 50,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
makeStep: suspend (T) -> List<T>,
|
||||
) {
|
||||
send(initial)
|
||||
//inner dispatch queue
|
||||
@ -66,7 +74,7 @@ public class RandomForkingSampler<T : Any>(
|
||||
internal fun <T : Any> buildChain(
|
||||
scope: CoroutineScope,
|
||||
initial: suspend () -> T,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
makeStep: suspend (T) -> List<T>,
|
||||
): Chain<T?> {
|
||||
val channel = Channel<T>(Channel.RENDEZVOUS)
|
||||
scope.launch {
|
||||
@ -80,56 +88,105 @@ public class RandomForkingSampler<T : Any>(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @param scope [CoroutineScope] for forking procedure
|
||||
* @param initialEnergy initial energy parameter for the fork generation. By default equals 1.0
|
||||
* @param energySplitRule the main rule to generate new samples. If resulting list is empty, the chain dies.
|
||||
* If list contains a single element, the chain continues without forking. If it contains multiple elements,
|
||||
* the chain forks.
|
||||
* @param stepScaleRule the rule to multiply the step generated by [stepSampler] by a value depending on energy.
|
||||
*/
|
||||
public fun <T : Any, A> metropolisHastings(
|
||||
scope: CoroutineScope,
|
||||
algebra: A,
|
||||
startPoint: suspend (RandomGenerator) -> T,
|
||||
stepSampler: Sampler<T>,
|
||||
initialEnergy: Float64,
|
||||
initialEnergy: Float64 = 1.0,
|
||||
energySplitRule: suspend RandomForkingSample<T>.() -> List<Float64>,
|
||||
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
|
||||
targetPdf: suspend (T) -> Float64,
|
||||
): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> =
|
||||
RandomForkingSampler(
|
||||
scope = scope,
|
||||
initialValue = { generator ->
|
||||
RandomForkingSample<T>(
|
||||
value = scope.async { startPoint(generator) },
|
||||
generation = 0,
|
||||
energy = initialEnergy,
|
||||
generator = generator,
|
||||
stepChain = stepSampler.sample(generator)
|
||||
)
|
||||
}
|
||||
) { previousSample: RandomForkingSample<T> ->
|
||||
val value = previousSample.value.await()
|
||||
previousSample.energySplitRule().map { energy ->
|
||||
RandomForkingSample<T>(
|
||||
value = scope.async<T> {
|
||||
val proposalPoint = with(algebra) {
|
||||
value + previousSample.stepChain.next() * stepScaleRule(energy)
|
||||
}
|
||||
val ratio = targetPdf(proposalPoint) / targetPdf(value)
|
||||
): RandomForkingSampler<RandomForkingSample<T>>
|
||||
where A : Group<T>, A : ScaleOperations<T> = RandomForkingSampler(
|
||||
scope = scope,
|
||||
initialValue = { generator ->
|
||||
RandomForkingSample<T>(
|
||||
value = scope.async { startPoint(generator) },
|
||||
generation = 0,
|
||||
energy = initialEnergy,
|
||||
generator = generator,
|
||||
stepChain = stepSampler.sample(generator)
|
||||
)
|
||||
}
|
||||
) { previousSample: RandomForkingSample<T> ->
|
||||
val value = previousSample.value.await()
|
||||
previousSample.energySplitRule().map { energy ->
|
||||
RandomForkingSample<T>(
|
||||
value = scope.async<T> {
|
||||
val proposalPoint = with(algebra) {
|
||||
value + previousSample.stepChain.next() * stepScaleRule(energy)
|
||||
}
|
||||
val ratio = targetPdf(proposalPoint) / targetPdf(value)
|
||||
|
||||
if (ratio >= 1.0) {
|
||||
if (ratio >= 1.0) {
|
||||
proposalPoint
|
||||
} else {
|
||||
val acceptanceProbability = previousSample.generator.nextDouble()
|
||||
if (acceptanceProbability <= ratio) {
|
||||
proposalPoint
|
||||
} else {
|
||||
val acceptanceProbability = previousSample.generator.nextDouble()
|
||||
if (acceptanceProbability <= ratio) {
|
||||
proposalPoint
|
||||
} else {
|
||||
value
|
||||
}
|
||||
value
|
||||
}
|
||||
},
|
||||
generation = previousSample.generation + 1,
|
||||
energy = energy,
|
||||
generator = previousSample.generator,
|
||||
stepChain = previousSample.stepChain
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
generation = previousSample.generation + 1,
|
||||
energy = energy,
|
||||
generator = previousSample.generator,
|
||||
stepChain = previousSample.stepChain
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A Metropolis-Hastings sampler for univariate [Float64] values
|
||||
*/
|
||||
public fun univariate(
|
||||
scope: CoroutineScope,
|
||||
startPoint: Float64,
|
||||
stepSampler: Sampler<Float64>,
|
||||
initialEnergy: Float64 = 1.0,
|
||||
energySplitRule: suspend RandomForkingSample<Float64>.() -> List<Float64>,
|
||||
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
|
||||
targetPdf: suspend (Float64) -> Float64,
|
||||
): RandomForkingSampler<RandomForkingSample<Float64>> = metropolisHastings(
|
||||
scope = scope,
|
||||
algebra = Float64.algebra,
|
||||
startPoint = { startPoint },
|
||||
stepSampler = stepSampler,
|
||||
initialEnergy = initialEnergy,
|
||||
energySplitRule = energySplitRule,
|
||||
stepScaleRule = stepScaleRule,
|
||||
targetPdf = targetPdf
|
||||
)
|
||||
|
||||
/**
|
||||
* A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution
|
||||
*/
|
||||
public fun univariateNormal(
|
||||
scope: CoroutineScope,
|
||||
startPoint: Float64,
|
||||
stepSigma: Float64,
|
||||
initialEnergy: Float64 = 1.0,
|
||||
energySplitRule: suspend RandomForkingSample<Float64>.() -> List<Float64>,
|
||||
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
|
||||
targetPdf: suspend (Float64) -> Float64,
|
||||
): RandomForkingSampler<RandomForkingSample<Float64>> = univariate(
|
||||
scope = scope,
|
||||
startPoint = startPoint,
|
||||
stepSampler = GaussianSampler(0.0, stepSigma),
|
||||
initialEnergy = initialEnergy,
|
||||
energySplitRule = energySplitRule,
|
||||
stepScaleRule = stepScaleRule,
|
||||
targetPdf = targetPdf
|
||||
)
|
||||
}
|
||||
}
|
||||
|
97
kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsForkingSampler.kt
Normal file
97
kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsForkingSampler.kt
Normal file
@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.samplers
|
||||
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import space.kscience.kmath.chains.discard
|
||||
import space.kscience.kmath.chains.nextBuffer
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.operations.Float64Field
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.stat.invoke
|
||||
import space.kscience.kmath.stat.mean
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.exp
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class TestMetropolisHastingsForkingSampler {
|
||||
|
||||
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
||||
|
||||
private val sample = 1e5.toInt()
|
||||
private val burnIn = sample / 5
|
||||
|
||||
|
||||
@Test
|
||||
fun samplingNormalTest() = runTest {
|
||||
val generator = RandomGenerator.default(1)
|
||||
|
||||
listOf(
|
||||
TestSetup(0.5, 0.0),
|
||||
TestSetup(68.13, 60.0),
|
||||
).forEach {
|
||||
val distribution = NormalDistribution(it.mean, 1.0)
|
||||
val sampler = RandomForkingSampler.univariateNormal(
|
||||
scope = this,
|
||||
startPoint = it.startPoint,
|
||||
stepSigma = it.sigma,
|
||||
energySplitRule = TODO(),
|
||||
targetPdf = distribution::probability,
|
||||
).values()
|
||||
|
||||
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||
|
||||
assertEquals(it.mean, Float64Field.mean(sampledValues), 0.05)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun samplingExponentialTest() = runTest {
|
||||
val generator = RandomGenerator.default(1)
|
||||
|
||||
fun expDist(lambda: Double, arg: Double): Double = if (arg <= 0.0) 0.0 else lambda * exp(-arg * lambda)
|
||||
|
||||
listOf(
|
||||
TestSetup(0.5, 2.0),
|
||||
TestSetup(2.0, 1.0)
|
||||
).forEach { setup ->
|
||||
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||
expDist(setup.mean, it)
|
||||
}
|
||||
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||
|
||||
assertEquals(1.0 / setup.mean, Float64Field.mean(sampledValues), 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun samplingRayleighTest() = runTest {
|
||||
val generator = RandomGenerator.default(1)
|
||||
|
||||
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2)
|
||||
}
|
||||
|
||||
listOf(
|
||||
TestSetup(0.5, 1.0),
|
||||
TestSetup(2.0, 1.0)
|
||||
).forEach { setup ->
|
||||
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||
rayleighDist(setup.mean, it)
|
||||
}
|
||||
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||
|
||||
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 0.05)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -36,8 +36,11 @@ class TestMetropolisHastingsSampler {
|
||||
TestSetup(68.13, 60.0),
|
||||
).forEach {
|
||||
val distribution = NormalDistribution(it.mean, 1.0)
|
||||
val sampler = MetropolisHastingsSampler
|
||||
.univariateNormal(it.startPoint, it.sigma, distribution::probability)
|
||||
val sampler = MetropolisHastingsSampler.univariateNormal(
|
||||
startPoint = it.startPoint,
|
||||
stepSigma = it.sigma,
|
||||
targetPdf = distribution::probability
|
||||
)
|
||||
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||
|
||||
assertEquals(it.mean, Float64Field.mean(sampledValues), 0.05)
|
||||
@ -87,5 +90,4 @@ class TestMetropolisHastingsSampler {
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user