Compare commits

..

5 Commits

Author SHA1 Message Date
57d1cd8c87 Add Metropolis-Hastings sampler.
Minor fixes.
2024-08-04 15:01:47 +03:00
SPC-code
e0997ccf9c
Merge pull request #533 from Vasilev-Ilya/STUD-7_metropolis_hastings_sampler
Draft: STUD-7: Metropolis-Hastings sampler implementation
2024-08-03 21:35:29 +03:00
vasilev.ilya
3417d8cdc1 Minor edits. Tests added. | STUD-7 2024-05-20 01:12:27 +03:00
dbc5488eb2 Minor edits. Tests added. | STUD-7 2024-04-24 23:29:14 +03:00
d0d250c67f MHS first implementation | STUD-7 2024-04-22 22:01:15 +03:00
8 changed files with 204 additions and 18 deletions

View File

@ -3,6 +3,7 @@
## Unreleased
### Added
- Metropolis-Hastings sampler
### Changed

View File

@ -0,0 +1,9 @@
[versions]
commons-rng = "1.6"
[libraries]
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}

View File

@ -40,13 +40,13 @@ public interface BlockingBufferChain<out T> : BlockingChain<T>, BufferChain<T> {
}
public suspend inline fun <reified T : Any> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
public suspend inline fun <reified T> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
nextBuffer(size)
} else {
Buffer(size) { next() }
}
public inline fun <reified T : Any> BlockingChain<T>.nextBufferBlocking(
public inline fun <reified T> BlockingChain<T>.nextBufferBlocking(
size: Int,
): Buffer<T> = if (this is BlockingBufferChain) {
nextBufferBlocking(size)

View File

@ -23,7 +23,7 @@ public interface Chain<out T> : Flow<T> {
public suspend fun next(): T
/**
* Create a copy of current chain state. Consuming resulting chain does not affect initial chain.
* Create a copy of the current chain state. Consuming the resulting chain does not affect the initial chain.
*/
public suspend fun fork(): Chain<T>
@ -47,7 +47,7 @@ public class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
/**
* A stateless Markov chain
*/
public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
public class MarkovChain<out R>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
private val mutex: Mutex = Mutex()
private var value: R? = null
@ -71,8 +71,8 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
*/
public class StatefulChain<S, out R>(
private val state: S,
private val seed: S.() -> R,
private val forkState: ((S) -> S),
private val seed: suspend S.() -> R,
private val forkState: suspend ((S) -> S),
private val gen: suspend S.(R) -> R,
) : Chain<R> {
private val mutex: Mutex = Mutex()
@ -97,6 +97,11 @@ public class ConstantChain<out T>(public val value: T) : Chain<T> {
override suspend fun fork(): Chain<T> = this
}
/**
 * Advances this chain by [number] elements and throws the produced values away.
 *
 * The elements are obtained by requesting a buffer of [number] items through [nextBuffer];
 * the buffer itself is dropped immediately. Typical use is skipping a burn-in prefix of a chain.
 *
 * @param number how many elements to consume
 * @return this very chain (not a copy), positioned after the discarded elements
 */
public suspend inline fun <reified T> Chain<T>.discard(number: Int): Chain<T> {
    // The buffer is requested only for its side effect of advancing the chain.
    nextBuffer<T>(number)
    return this
}
/**
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
* since mapped chain consumes tokens. Accepts regular transformation function.

View File

@ -7,21 +7,17 @@ kscience {
js()
native()
wasm()
}
kotlin.sourceSets {
useCoroutines()
commonMain {
dependencies {
api(projects.kmathCoroutines)
//implementation(spclibs.atomicfu)
}
api(projects.kmathCoroutines)
//implementation(spclibs.atomicfu)
}
getByName("jvmMain") {
dependencies {
api("org.apache.commons:commons-rng-sampling:1.3")
api("org.apache.commons:commons-rng-simple:1.3")
}
jvmMain {
api(libs.commons.rng.simple)
api(libs.commons.rng.sampling)
}
}

View File

@ -0,0 +1,76 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.chains.StatefulChain
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.algebra
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.Float64
/**
 * A [Sampler] implementing the
 * [Metropolis-Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm)
 * for drawing values from the target distribution [targetPdf].
 *
 * Every element of the produced chain is obtained from the previous one: a step taken from [stepSampler] is
 * added to the previous point using [algebra], and the resulting candidate is either accepted or the previous
 * point is repeated. The decision uses only the ratio of [targetPdf] at the candidate and at the previous point;
 * no proposal-density term is applied (see the TODO in the class body).
 *
 * NOTE(review): forking the chain forks the step chain, while the accept/reject draw reads the generator
 * passed to [sample] directly — confirm whether forked chains are expected to be fully independent.
 *
 * @param algebra a group providing the addition used to combine a point with a step
 * @param startPoint the seed value of every chain created by [sample]
 * @param stepSampler a sampler for proposal point (offset to previous point)
 * @param targetPdf the density being sampled; it is evaluated twice per generated element
 */
public class MetropolisHastingsSampler<T>(
    public val algebra: Group<T>,
    public val startPoint: T,
    public val stepSampler: Sampler<T>,
    public val targetPdf: suspend (T) -> Float64,
) : Sampler<T> {

    //TODO consider API for conditional step probability

    /**
     * Creates a chain seeded with [startPoint]. The chain state is the chain of steps produced by [stepSampler]
     * from the same [generator]; [generator] is additionally used for the accept/reject draw.
     */
    override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
        state = stepSampler.sample(generator),
        seed = { startPoint },
        forkState = { stepChain -> stepChain.fork() },
        gen = { currentPoint: T ->
            // `next()` is resolved on the state, i.e. on the chain of proposal steps.
            val step = next()
            val candidate = with(algebra) { currentPoint + step }

            // The candidate density is evaluated first, then the density of the current point.
            val candidateDensity = targetPdf(candidate)
            val currentDensity = targetPdf(currentPoint)
            val ratio = candidateDensity / currentDensity

            // Short-circuit: the uniform number is drawn only when the ratio is not >= 1.
            // A NaN ratio fails both comparisons, so the current point is kept in that case.
            val accepted = ratio >= 1.0 || generator.nextDouble() <= ratio
            if (accepted) candidate else currentPoint
        },
    )

    public companion object {

        /**
         * A Metropolis-Hastings sampler for univariate [Float64] values.
         *
         * @param startPoint the seed value of the chain
         * @param stepSampler a sampler of offsets added to the previous value
         * @param targetPdf the density being sampled
         */
        public fun univariate(
            startPoint: Float64,
            stepSampler: Sampler<Float64>,
            targetPdf: suspend (Float64) -> Float64,
        ): MetropolisHastingsSampler<Double> {
            val float64Algebra = Float64.algebra
            return MetropolisHastingsSampler(float64Algebra, startPoint, stepSampler, targetPdf)
        }

        /**
         * A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution.
         *
         * @param startPoint the seed value of the chain
         * @param stepSigma second argument of the [GaussianSampler] generating steps (its first argument is `0.0`)
         * @param targetPdf the density being sampled
         */
        public fun univariateNormal(
            startPoint: Float64,
            stepSigma: Float64,
            targetPdf: suspend (Float64) -> Float64,
        ): MetropolisHastingsSampler<Double> = univariate(
            startPoint = startPoint,
            stepSampler = GaussianSampler(0.0, stepSigma),
            targetPdf = targetPdf,
        )
    }
}

View File

@ -17,7 +17,7 @@ import kotlin.jvm.JvmName
/**
* Sampler that generates chains of values of type [T].
*/
public fun interface Sampler<out T : Any> {
public fun interface Sampler<out T> {
/**
* Generates a chain of samples.
*

View File

@ -0,0 +1,99 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.samplers
import kotlinx.coroutines.test.runTest
import space.kscience.kmath.chains.discard
import space.kscience.kmath.chains.nextBuffer
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.stat.invoke
import space.kscience.kmath.stat.mean
import kotlin.math.PI
import kotlin.math.exp
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.test.Test
import kotlin.test.assertEquals
/**
 * Checks that the sample mean of chains produced by [MetropolisHastingsSampler.univariateNormal] matches the
 * analytical mean of several target densities.
 */
class TestMetropolisHastingsSampler {

    /**
     * One test case.
     *
     * @param mean the distribution parameter of the case: the mean in the normal test, the rate in the
     * exponential test and the scale in the Rayleigh test
     * @param startPoint the seed value of the chain
     * @param sigma the step sigma passed to the sampler
     */
    data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)

    // Number of values averaged in every case.
    val sample = 1e6.toInt()

    // Number of leading values dropped before averaging.
    val burnIn = sample / 5

    @Test
    fun samplingNormalTest() = runTest {
        val generator = RandomGenerator.default(1)
        val setups = listOf(
            TestSetup(0.5, 0.0),
            TestSetup(68.13, 60.0),
        )
        for ((expectedMean, start, stepSigma) in setups) {
            val target = NormalDistribution(expectedMean, 1.0)
            val chain = MetropolisHastingsSampler
                .univariateNormal(start, stepSigma) { target.probability(it) }
                .sample(generator)
            val sampledValues = chain.discard(burnIn).nextBuffer(sample)
            assertEquals(expectedMean, Float64Field.mean(sampledValues), 1e-2)
        }
    }

    @Test
    fun samplingExponentialTest() = runTest {
        val generator = RandomGenerator.default(1)

        // Exponential density with rate `lambda`; zero for non-positive arguments.
        fun expDist(lambda: Double, arg: Double): Double {
            if (arg <= 0.0) return 0.0
            return lambda * exp(-arg * lambda)
        }

        val setups = listOf(
            TestSetup(0.5, 2.0),
            TestSetup(2.0, 1.0),
        )
        for ((rate, start, stepSigma) in setups) {
            val chain = MetropolisHastingsSampler
                .univariateNormal(start, stepSigma) { expDist(rate, it) }
                .sample(generator)
            val sampledValues = chain.discard(burnIn).nextBuffer(sample)
            // The mean of an exponential distribution is the inverse of its rate.
            assertEquals(1.0 / rate, Float64Field.mean(sampledValues), 1e-2)
        }
    }

    @Test
    fun samplingRayleighTest() = runTest {
        val generator = RandomGenerator.default(1)

        // Rayleigh density with scale `sigma`; zero for negative arguments.
        fun rayleighDist(sigma: Double, arg: Double): Double {
            if (arg < 0.0) return 0.0
            return arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2)
        }

        val setups = listOf(
            TestSetup(0.5, 1.0),
            TestSetup(2.0, 1.0),
        )
        for ((scale, start, stepSigma) in setups) {
            val chain = MetropolisHastingsSampler
                .univariateNormal(start, stepSigma) { rayleighDist(scale, it) }
                .sample(generator)
            val sampledValues = chain.discard(burnIn).nextBuffer(sample)
            // The mean of a Rayleigh distribution is scale * sqrt(pi / 2).
            assertEquals(scale * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
        }
    }
}