Add Metropolis-Hastings sampler.

Minor fixes.
This commit is contained in:
Alexander Nozik 2024-08-04 15:01:47 +03:00
parent e0997ccf9c
commit 57d1cd8c87
8 changed files with 152 additions and 92 deletions

View File

@ -3,6 +3,7 @@
## Unreleased
### Added
- Metropolis-Hastings sampler
### Changed

View File

@ -0,0 +1,9 @@
[versions]
commons-rng = "1.6"
[libraries]
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}

View File

@ -40,13 +40,13 @@ public interface BlockingBufferChain<out T> : BlockingChain<T>, BufferChain<T> {
}
public suspend inline fun <reified T : Any> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
public suspend inline fun <reified T> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
nextBuffer(size)
} else {
Buffer(size) { next() }
}
public inline fun <reified T : Any> BlockingChain<T>.nextBufferBlocking(
public inline fun <reified T> BlockingChain<T>.nextBufferBlocking(
size: Int,
): Buffer<T> = if (this is BlockingBufferChain) {
nextBufferBlocking(size)

View File

@ -23,7 +23,7 @@ public interface Chain<out T> : Flow<T> {
public suspend fun next(): T
/**
* Create a copy of current chain state. Consuming resulting chain does not affect initial chain.
* Create a copy of the current chain state. Consuming the resulting chain does not affect the initial chain.
*/
public suspend fun fork(): Chain<T>
@ -47,7 +47,7 @@ public class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
/**
* A stateless Markov chain
*/
public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
public class MarkovChain<out R>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
private val mutex: Mutex = Mutex()
private var value: R? = null
@ -71,8 +71,8 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
*/
public class StatefulChain<S, out R>(
private val state: S,
private val seed: S.() -> R,
private val forkState: ((S) -> S),
private val seed: suspend S.() -> R,
private val forkState: suspend ((S) -> S),
private val gen: suspend S.(R) -> R,
) : Chain<R> {
private val mutex: Mutex = Mutex()
@ -97,6 +97,11 @@ public class ConstantChain<out T>(public val value: T) : Chain<T> {
override suspend fun fork(): Chain<T> = this
}
/**
* Discard a fixed number of samples
*/
public suspend inline fun <reified T> Chain<T>.discard(number: Int): Chain<T> = apply { nextBuffer<T>(number) }
/**
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
* since mapped chain consumes tokens. Accepts regular transformation function.

View File

@ -7,21 +7,17 @@ kscience {
js()
native()
wasm()
}
kotlin.sourceSets {
useCoroutines()
commonMain {
dependencies {
api(projects.kmathCoroutines)
//implementation(spclibs.atomicfu)
}
}
getByName("jvmMain") {
dependencies {
api("org.apache.commons:commons-rng-sampling:1.3")
api("org.apache.commons:commons-rng-simple:1.3")
}
jvmMain {
api(libs.commons.rng.simple)
api(libs.commons.rng.sampling)
}
}

View File

@ -5,46 +5,72 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.distributions.Distribution1D
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.chains.StatefulChain
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.algebra
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.structures.Float64Buffer
import kotlin.math.*
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.Float64
/**
* [MetropolisHastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
* target distribution [targetDist].
* target distribution [targetPdf].
*
* The normal distribution is used as the proposal function.
*
* params:
* - targetDist: function close to the density of the sampled distribution;
* - initialState: initial value of the chain of sampled values;
* - proposalStd: standard deviation of the proposal function.
* @param stepSampler a sampler for proposal point (offset to previous point)
*/
public class MetropolisHastingsSampler(
public val targetDist: (arg : Double) -> Double,
public val initialState : Double = 0.0,
public val proposalStd : Double = 1.0,
) : BlockingDoubleSampler {
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
var currentState = initialState
fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd)
public class MetropolisHastingsSampler<T>(
public val algebra: Group<T>,
public val startPoint: T,
public val stepSampler: Sampler<T>,
public val targetPdf: suspend (T) -> Float64,
) : Sampler<T> {
override fun nextBufferBlocking(size: Int): Float64Buffer {
val acceptanceProb = generator.nextDoubleBuffer(size)
//TODO consider API for conditional step probability
return Float64Buffer(size) {index ->
val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0)
val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState))
currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState
currentState
override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
state = stepSampler.sample(generator),
seed = { startPoint },
forkState = Chain<T>::fork
) { previousPoint: T ->
val proposalPoint = with(algebra) { previousPoint + next() }
val ratio = targetPdf(proposalPoint) / targetPdf(previousPoint)
if (ratio >= 1.0) {
proposalPoint
} else {
val acceptanceProbability = generator.nextDouble()
if (acceptanceProbability <= ratio) {
proposalPoint
} else {
previousPoint
}
}
}
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
}
public companion object {
/**
* A Metropolis-Hastings sampler for univariate [Float64] values
*/
public fun univariate(
startPoint: Float64,
stepSampler: Sampler<Float64>,
targetPdf: suspend (Float64) -> Float64,
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
algebra = Float64.algebra,
startPoint = startPoint,
stepSampler = stepSampler,
targetPdf = targetPdf
)
/**
* A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution
*/
public fun univariateNormal(
startPoint: Float64,
stepSigma: Float64,
targetPdf: suspend (Float64) -> Float64,
): MetropolisHastingsSampler<Double> = univariate(startPoint, GaussianSampler(0.0, stepSigma), targetPdf)
}
}

View File

@ -17,7 +17,7 @@ import kotlin.jvm.JvmName
/**
* Sampler that generates chains of values of type [T].
*/
public fun interface Sampler<out T : Any> {
public fun interface Sampler<out T> {
/**
* Generates a chain of samples.
*

View File

@ -4,73 +4,96 @@
*/
package space.kscience.kmath.samplers
import kotlinx.coroutines.test.runTest
import space.kscience.kmath.chains.discard
import space.kscience.kmath.chains.nextBuffer
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.random.DefaultGenerator
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.stat.invoke
import space.kscience.kmath.stat.mean
import kotlin.math.PI
import kotlin.math.exp
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.test.Test
import kotlin.test.assertEquals
class TestMetropolisHastingsSampler {
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
val sample = 1e6.toInt()
val burnIn = sample / 5
@Test
fun samplingNormalTest() {
fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg)
var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0)
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
fun samplingNormalTest() = runTest {
val generator = RandomGenerator.default(1)
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)
listOf(
TestSetup(0.5, 0.0),
TestSetup(68.13, 60.0),
).forEach {
val distribution = NormalDistribution(it.mean, 1.0)
val sampler = MetropolisHastingsSampler
.univariateNormal(it.startPoint, it.sigma, distribution::probability)
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg)
sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2)
assertEquals(it.mean, Float64Field.mean(sampledValues), 1e-2)
}
}
@Test
fun samplingExponentialTest() {
fun expDist(arg : Double, param : Double) : Double {
if (arg < 0.0) { return 0.0 }
return param * exp(-param * arg)
fun samplingExponentialTest() = runTest {
val generator = RandomGenerator.default(1)
fun expDist(lambda: Double, arg: Double): Double = if (arg <= 0.0) 0.0 else lambda * exp(-arg * lambda)
listOf(
TestSetup(0.5, 2.0),
TestSetup(2.0, 1.0)
).forEach { setup ->
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
expDist(setup.mean, it)
}
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
fun expDist1(arg : Double) = expDist(arg, 0.5)
var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0)
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2)
fun expDist2(arg : Double) = expDist(arg, 2.0)
sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)
assertEquals(1.0 / setup.mean, Float64Field.mean(sampledValues), 1e-2)
}
}
@Test
fun samplingRayleighTest() {
fun rayleighDist(arg : Double, sigma : Double) : Double {
if (arg < 0.0) { return 0.0 }
val expArg = (arg / sigma).pow(2)
return arg * exp(-expArg / 2.0) / sigma.pow(2)
fun samplingRayleighTest() = runTest {
val generator = RandomGenerator.default(1)
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
0.0
} else {
arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2)
}
fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0)
var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0)
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
listOf(
TestSetup(0.5, 1.0),
TestSetup(2.0, 1.0)
).forEach { setup ->
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
rayleighDist(setup.mean, it)
}
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0)
sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
}
//
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
//
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
//
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
//
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
}
}