Compare commits
5 Commits
3e8f44166c
...
57d1cd8c87
Author | SHA1 | Date | |
---|---|---|---|
57d1cd8c87 | |||
|
e0997ccf9c | ||
|
3417d8cdc1 | ||
dbc5488eb2 | |||
d0d250c67f |
@ -3,6 +3,7 @@
|
||||
## Unreleased
|
||||
|
||||
### Added
|
||||
- Metropolis-Hastings sampler
|
||||
|
||||
### Changed
|
||||
|
||||
|
9
gradle/libs.versions.toml
Normal file
9
gradle/libs.versions.toml
Normal file
@ -0,0 +1,9 @@
|
||||
[versions]
|
||||
|
||||
commons-rng = "1.6"
|
||||
|
||||
|
||||
[libraries]
|
||||
|
||||
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
|
||||
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}
|
@ -40,13 +40,13 @@ public interface BlockingBufferChain<out T> : BlockingChain<T>, BufferChain<T> {
|
||||
}
|
||||
|
||||
|
||||
public suspend inline fun <reified T : Any> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
|
||||
public suspend inline fun <reified T> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
|
||||
nextBuffer(size)
|
||||
} else {
|
||||
Buffer(size) { next() }
|
||||
}
|
||||
|
||||
public inline fun <reified T : Any> BlockingChain<T>.nextBufferBlocking(
|
||||
public inline fun <reified T> BlockingChain<T>.nextBufferBlocking(
|
||||
size: Int,
|
||||
): Buffer<T> = if (this is BlockingBufferChain) {
|
||||
nextBufferBlocking(size)
|
||||
|
@ -23,7 +23,7 @@ public interface Chain<out T> : Flow<T> {
|
||||
public suspend fun next(): T
|
||||
|
||||
/**
|
||||
* Create a copy of current chain state. Consuming resulting chain does not affect initial chain.
|
||||
* Create a copy of the current chain state. Consuming the resulting chain does not affect the initial chain.
|
||||
*/
|
||||
public suspend fun fork(): Chain<T>
|
||||
|
||||
@ -47,7 +47,7 @@ public class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
||||
/**
|
||||
* A stateless Markov chain
|
||||
*/
|
||||
public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||
public class MarkovChain<out R>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||
private val mutex: Mutex = Mutex()
|
||||
private var value: R? = null
|
||||
|
||||
@ -71,8 +71,8 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
|
||||
*/
|
||||
public class StatefulChain<S, out R>(
|
||||
private val state: S,
|
||||
private val seed: S.() -> R,
|
||||
private val forkState: ((S) -> S),
|
||||
private val seed: suspend S.() -> R,
|
||||
private val forkState: suspend ((S) -> S),
|
||||
private val gen: suspend S.(R) -> R,
|
||||
) : Chain<R> {
|
||||
private val mutex: Mutex = Mutex()
|
||||
@ -97,6 +97,11 @@ public class ConstantChain<out T>(public val value: T) : Chain<T> {
|
||||
override suspend fun fork(): Chain<T> = this
|
||||
}
|
||||
|
||||
/**
|
||||
* Discard a fixed number of samples
|
||||
*/
|
||||
public suspend inline fun <reified T> Chain<T>.discard(number: Int): Chain<T> = apply { nextBuffer<T>(number) }
|
||||
|
||||
/**
|
||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||
* since mapped chain consumes tokens. Accepts regular transformation function.
|
||||
|
@ -7,21 +7,17 @@ kscience {
|
||||
js()
|
||||
native()
|
||||
wasm()
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
useCoroutines()
|
||||
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(projects.kmathCoroutines)
|
||||
//implementation(spclibs.atomicfu)
|
||||
}
|
||||
}
|
||||
|
||||
getByName("jvmMain") {
|
||||
dependencies {
|
||||
api("org.apache.commons:commons-rng-sampling:1.3")
|
||||
api("org.apache.commons:commons-rng-simple:1.3")
|
||||
}
|
||||
jvmMain {
|
||||
api(libs.commons.rng.simple)
|
||||
api(libs.commons.rng.sampling)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.samplers
|
||||
|
||||
import space.kscience.kmath.chains.Chain
|
||||
import space.kscience.kmath.chains.StatefulChain
|
||||
import space.kscience.kmath.operations.Group
|
||||
import space.kscience.kmath.operations.algebra
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.stat.Sampler
|
||||
import space.kscience.kmath.structures.Float64
|
||||
|
||||
/**
|
||||
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
|
||||
* target distribution [targetPdf].
|
||||
*
|
||||
* @param stepSampler a sampler for proposal point (offset to previous point)
|
||||
*/
|
||||
public class MetropolisHastingsSampler<T>(
|
||||
public val algebra: Group<T>,
|
||||
public val startPoint: T,
|
||||
public val stepSampler: Sampler<T>,
|
||||
public val targetPdf: suspend (T) -> Float64,
|
||||
) : Sampler<T> {
|
||||
|
||||
//TODO consider API for conditional step probability
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
|
||||
state = stepSampler.sample(generator),
|
||||
seed = { startPoint },
|
||||
forkState = Chain<T>::fork
|
||||
) { previousPoint: T ->
|
||||
val proposalPoint = with(algebra) { previousPoint + next() }
|
||||
val ratio = targetPdf(proposalPoint) / targetPdf(previousPoint)
|
||||
if (ratio >= 1.0) {
|
||||
proposalPoint
|
||||
} else {
|
||||
val acceptanceProbability = generator.nextDouble()
|
||||
if (acceptanceProbability <= ratio) {
|
||||
proposalPoint
|
||||
} else {
|
||||
previousPoint
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public companion object {
|
||||
|
||||
/**
|
||||
* A Metropolis-Hastings sampler for univariate [Float64] values
|
||||
*/
|
||||
public fun univariate(
|
||||
startPoint: Float64,
|
||||
stepSampler: Sampler<Float64>,
|
||||
targetPdf: suspend (Float64) -> Float64,
|
||||
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
|
||||
algebra = Float64.algebra,
|
||||
startPoint = startPoint,
|
||||
stepSampler = stepSampler,
|
||||
targetPdf = targetPdf
|
||||
)
|
||||
|
||||
/**
|
||||
* A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution
|
||||
*/
|
||||
public fun univariateNormal(
|
||||
startPoint: Float64,
|
||||
stepSigma: Float64,
|
||||
targetPdf: suspend (Float64) -> Float64,
|
||||
): MetropolisHastingsSampler<Double> = univariate(startPoint, GaussianSampler(0.0, stepSigma), targetPdf)
|
||||
}
|
||||
}
|
@ -17,7 +17,7 @@ import kotlin.jvm.JvmName
|
||||
/**
|
||||
* Sampler that generates chains of values of type [T].
|
||||
*/
|
||||
public fun interface Sampler<out T : Any> {
|
||||
public fun interface Sampler<out T> {
|
||||
/**
|
||||
* Generates a chain of samples.
|
||||
*
|
||||
|
@ -0,0 +1,99 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.samplers
|
||||
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import space.kscience.kmath.chains.discard
|
||||
import space.kscience.kmath.chains.nextBuffer
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.operations.Float64Field
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.stat.invoke
|
||||
import space.kscience.kmath.stat.mean
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.exp
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class TestMetropolisHastingsSampler {
|
||||
|
||||
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
||||
|
||||
val sample = 1e6.toInt()
|
||||
val burnIn = sample / 5
|
||||
|
||||
@Test
|
||||
fun samplingNormalTest() = runTest {
|
||||
val generator = RandomGenerator.default(1)
|
||||
|
||||
listOf(
|
||||
TestSetup(0.5, 0.0),
|
||||
TestSetup(68.13, 60.0),
|
||||
).forEach {
|
||||
val distribution = NormalDistribution(it.mean, 1.0)
|
||||
val sampler = MetropolisHastingsSampler
|
||||
.univariateNormal(it.startPoint, it.sigma, distribution::probability)
|
||||
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||
|
||||
assertEquals(it.mean, Float64Field.mean(sampledValues), 1e-2)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun samplingExponentialTest() = runTest {
|
||||
val generator = RandomGenerator.default(1)
|
||||
|
||||
fun expDist(lambda: Double, arg: Double): Double = if (arg <= 0.0) 0.0 else lambda * exp(-arg * lambda)
|
||||
|
||||
listOf(
|
||||
TestSetup(0.5, 2.0),
|
||||
TestSetup(2.0, 1.0)
|
||||
).forEach { setup ->
|
||||
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||
expDist(setup.mean, it)
|
||||
}
|
||||
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||
|
||||
assertEquals(1.0 / setup.mean, Float64Field.mean(sampledValues), 1e-2)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun samplingRayleighTest() = runTest {
|
||||
val generator = RandomGenerator.default(1)
|
||||
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2)
|
||||
}
|
||||
|
||||
listOf(
|
||||
TestSetup(0.5, 1.0),
|
||||
TestSetup(2.0, 1.0)
|
||||
).forEach { setup ->
|
||||
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||
rayleighDist(setup.mean, it)
|
||||
}
|
||||
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||
|
||||
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
||||
}
|
||||
//
|
||||
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
|
||||
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
|
||||
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||
//
|
||||
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
||||
//
|
||||
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
|
||||
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
||||
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
||||
//
|
||||
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user