Add Metropolis-Hastings sampler.

Minor fixes.
This commit is contained in:
Alexander Nozik 2024-08-04 15:01:47 +03:00
parent e0997ccf9c
commit 57d1cd8c87
8 changed files with 152 additions and 92 deletions

View File

@ -3,6 +3,7 @@
## Unreleased ## Unreleased
### Added ### Added
- Metropolis-Hastings sampler
### Changed ### Changed

View File

@ -0,0 +1,9 @@
[versions]
commons-rng = "1.6"
[libraries]
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}

View File

@ -40,13 +40,13 @@ public interface BlockingBufferChain<out T> : BlockingChain<T>, BufferChain<T> {
} }
public suspend inline fun <reified T : Any> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) { public suspend inline fun <reified T> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
nextBuffer(size) nextBuffer(size)
} else { } else {
Buffer(size) { next() } Buffer(size) { next() }
} }
public inline fun <reified T : Any> BlockingChain<T>.nextBufferBlocking( public inline fun <reified T> BlockingChain<T>.nextBufferBlocking(
size: Int, size: Int,
): Buffer<T> = if (this is BlockingBufferChain) { ): Buffer<T> = if (this is BlockingBufferChain) {
nextBufferBlocking(size) nextBufferBlocking(size)

View File

@ -23,7 +23,7 @@ public interface Chain<out T> : Flow<T> {
public suspend fun next(): T public suspend fun next(): T
/** /**
* Create a copy of current chain state. Consuming resulting chain does not affect initial chain. * Create a copy of the current chain state. Consuming the resulting chain does not affect the initial chain.
*/ */
public suspend fun fork(): Chain<T> public suspend fun fork(): Chain<T>
@ -47,7 +47,7 @@ public class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
/** /**
* A stateless Markov chain * A stateless Markov chain
*/ */
public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> { public class MarkovChain<out R>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
private val mutex: Mutex = Mutex() private val mutex: Mutex = Mutex()
private var value: R? = null private var value: R? = null
@ -71,8 +71,8 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
*/ */
public class StatefulChain<S, out R>( public class StatefulChain<S, out R>(
private val state: S, private val state: S,
private val seed: S.() -> R, private val seed: suspend S.() -> R,
private val forkState: ((S) -> S), private val forkState: suspend ((S) -> S),
private val gen: suspend S.(R) -> R, private val gen: suspend S.(R) -> R,
) : Chain<R> { ) : Chain<R> {
private val mutex: Mutex = Mutex() private val mutex: Mutex = Mutex()
@ -97,6 +97,11 @@ public class ConstantChain<out T>(public val value: T) : Chain<T> {
override suspend fun fork(): Chain<T> = this override suspend fun fork(): Chain<T> = this
} }
/**
* Discard a fixed number of samples
*/
public suspend inline fun <reified T> Chain<T>.discard(number: Int): Chain<T> = apply { nextBuffer<T>(number) }
/** /**
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed * Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
* since mapped chain consumes tokens. Accepts regular transformation function. * since mapped chain consumes tokens. Accepts regular transformation function.

View File

@ -7,21 +7,17 @@ kscience {
js() js()
native() native()
wasm() wasm()
}
kotlin.sourceSets { useCoroutines()
commonMain { commonMain {
dependencies {
api(projects.kmathCoroutines) api(projects.kmathCoroutines)
//implementation(spclibs.atomicfu) //implementation(spclibs.atomicfu)
} }
}
getByName("jvmMain") { jvmMain {
dependencies { api(libs.commons.rng.simple)
api("org.apache.commons:commons-rng-sampling:1.3") api(libs.commons.rng.sampling)
api("org.apache.commons:commons-rng-simple:1.3")
}
} }
} }

View File

@ -5,46 +5,72 @@
package space.kscience.kmath.samplers package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain import space.kscience.kmath.chains.Chain
import space.kscience.kmath.distributions.Distribution1D import space.kscience.kmath.chains.StatefulChain
import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.algebra
import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.structures.Float64Buffer import space.kscience.kmath.stat.Sampler
import kotlin.math.* import space.kscience.kmath.structures.Float64
/** /**
* [MetropolisHastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling * [MetropolisHastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
* target distribution [targetDist]. * target distribution [targetPdf].
* *
* The normal distribution is used as the proposal function. * @param stepSampler a sampler for proposal point (offset to previous point)
*
* params:
* - targetDist: function close to the density of the sampled distribution;
* - initialState: initial value of the chain of sampled values;
* - proposalStd: standard deviation of the proposal function.
*/ */
public class MetropolisHastingsSampler( public class MetropolisHastingsSampler<T>(
public val targetDist: (arg : Double) -> Double, public val algebra: Group<T>,
public val initialState : Double = 0.0, public val startPoint: T,
public val proposalStd : Double = 1.0, public val stepSampler: Sampler<T>,
) : BlockingDoubleSampler { public val targetPdf: suspend (T) -> Float64,
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { ) : Sampler<T> {
var currentState = initialState
fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd)
override fun nextBufferBlocking(size: Int): Float64Buffer { //TODO consider API for conditional step probability
val acceptanceProb = generator.nextDoubleBuffer(size)
return Float64Buffer(size) {index -> override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0) state = stepSampler.sample(generator),
val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState)) seed = { startPoint },
forkState = Chain<T>::fork
currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState ) { previousPoint: T ->
currentState val proposalPoint = with(algebra) { previousPoint + next() }
val ratio = targetPdf(proposalPoint) / targetPdf(previousPoint)
if (ratio >= 1.0) {
proposalPoint
} else {
val acceptanceProbability = generator.nextDouble()
if (acceptanceProbability <= ratio) {
proposalPoint
} else {
previousPoint
}
} }
} }
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
}
public companion object {
/**
* A Metropolis-Hastings sampler for univariate [Float64] values
*/
public fun univariate(
startPoint: Float64,
stepSampler: Sampler<Float64>,
targetPdf: suspend (Float64) -> Float64,
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
algebra = Float64.algebra,
startPoint = startPoint,
stepSampler = stepSampler,
targetPdf = targetPdf
)
/**
* A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution
*/
public fun univariateNormal(
startPoint: Float64,
stepSigma: Float64,
targetPdf: suspend (Float64) -> Float64,
): MetropolisHastingsSampler<Double> = univariate(startPoint, GaussianSampler(0.0, stepSigma), targetPdf)
}
} }

View File

@ -17,7 +17,7 @@ import kotlin.jvm.JvmName
/** /**
* Sampler that generates chains of values of type [T]. * Sampler that generates chains of values of type [T].
*/ */
public fun interface Sampler<out T : Any> { public fun interface Sampler<out T> {
/** /**
* Generates a chain of samples. * Generates a chain of samples.
* *

View File

@ -4,73 +4,96 @@
*/ */
package space.kscience.kmath.samplers package space.kscience.kmath.samplers
import kotlinx.coroutines.test.runTest
import space.kscience.kmath.chains.discard
import space.kscience.kmath.chains.nextBuffer
import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.random.DefaultGenerator import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.stat.invoke import space.kscience.kmath.stat.invoke
import space.kscience.kmath.stat.mean import space.kscience.kmath.stat.mean
import kotlin.math.PI
import kotlin.math.exp import kotlin.math.exp
import kotlin.math.pow import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class TestMetropolisHastingsSampler { class TestMetropolisHastingsSampler {
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
val sample = 1e6.toInt()
val burnIn = sample / 5
@Test @Test
fun samplingNormalTest() { fun samplingNormalTest() = runTest {
fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg) val generator = RandomGenerator.default(1)
var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0)
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2) listOf(
TestSetup(0.5, 0.0),
TestSetup(68.13, 60.0),
).forEach {
val distribution = NormalDistribution(it.mean, 1.0)
val sampler = MetropolisHastingsSampler
.univariateNormal(it.startPoint, it.sigma, distribution::probability)
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg) assertEquals(it.mean, Float64Field.mean(sampledValues), 1e-2)
sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0) }
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2)
} }
@Test @Test
fun samplingExponentialTest() { fun samplingExponentialTest() = runTest {
fun expDist(arg : Double, param : Double) : Double { val generator = RandomGenerator.default(1)
if (arg < 0.0) { return 0.0 }
return param * exp(-param * arg) fun expDist(lambda: Double, arg: Double): Double = if (arg <= 0.0) 0.0 else lambda * exp(-arg * lambda)
listOf(
TestSetup(0.5, 2.0),
TestSetup(2.0, 1.0)
).forEach { setup ->
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
expDist(setup.mean, it)
} }
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
fun expDist1(arg : Double) = expDist(arg, 0.5) assertEquals(1.0 / setup.mean, Float64Field.mean(sampledValues), 1e-2)
var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0) }
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2)
fun expDist2(arg : Double) = expDist(arg, 2.0)
sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)
} }
@Test @Test
fun samplingRayleighTest() { fun samplingRayleighTest() = runTest {
fun rayleighDist(arg : Double, sigma : Double) : Double { val generator = RandomGenerator.default(1)
if (arg < 0.0) { return 0.0 } fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
0.0
val expArg = (arg / sigma).pow(2) } else {
return arg * exp(-expArg / 2.0) / sigma.pow(2) arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2)
} }
fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0) listOf(
var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0) TestSetup(0.5, 1.0),
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) TestSetup(2.0, 1.0)
).forEach { setup ->
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
rayleighDist(setup.mean, it)
}
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2) assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
}
fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0) //
sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0) // fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000) // var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2) //
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
//
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
//
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
} }
} }