From 57d1cd8c87a43aa5406459a654b8d17316805acc Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 4 Aug 2024 15:01:47 +0300 Subject: [PATCH] Add Metropolis-Hastings sampler. Minor fixes. --- CHANGELOG.md | 1 + gradle/libs.versions.toml | 9 ++ .../kscience/kmath/chains/BlockingChain.kt | 4 +- .../space/kscience/kmath/chains/Chain.kt | 13 +- kmath-stat/build.gradle.kts | 18 ++- .../samplers/MetropolisHastingsSampler.kt | 86 +++++++++----- .../space/kscience/kmath/samplers/Sampler.kt | 2 +- .../samplers/TestMetropolisHastingsSampler.kt | 111 +++++++++++------- 8 files changed, 152 insertions(+), 92 deletions(-) create mode 100644 gradle/libs.versions.toml diff --git a/CHANGELOG.md b/CHANGELOG.md index 6bf936c10..6c2e987f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Unreleased ### Added +- Metropolis-Hastings sampler ### Changed diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml new file mode 100644 index 000000000..9796ed176 --- /dev/null +++ b/gradle/libs.versions.toml @@ -0,0 +1,9 @@ +[versions] + +commons-rng = "1.6" + + +[libraries] + +commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"} +commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"} \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt index f843e8af8..d58da67a5 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt @@ -40,13 +40,13 @@ public interface BlockingBufferChain : BlockingChain, BufferChain { } -public suspend inline fun Chain.nextBuffer(size: Int): Buffer = if (this is BufferChain) { +public suspend inline fun Chain.nextBuffer(size: Int): Buffer = if (this is BufferChain) { nextBuffer(size) } else { Buffer(size) { next() } } -public inline fun BlockingChain.nextBufferBlocking( +public inline fun BlockingChain.nextBufferBlocking( size: Int, ): Buffer = if (this is BlockingBufferChain) { nextBufferBlocking(size) diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt index 0158c6752..3927795e3 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt @@ -23,7 +23,7 @@ public interface Chain : Flow { public suspend fun next(): T /** - * Create a copy of current chain state. Consuming resulting chain does not affect initial chain. + * Create a copy of the current chain state. Consuming the resulting chain does not affect the initial chain. */ public suspend fun fork(): Chain @@ -47,7 +47,7 @@ public class SimpleChain(private val gen: suspend () -> R) : Chain { /** * A stateless Markov chain */ -public class MarkovChain(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain { +public class MarkovChain(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain { private val mutex: Mutex = Mutex() private var value: R? = null @@ -71,8 +71,8 @@ public class MarkovChain(private val seed: suspend () -> R, private */ public class StatefulChain( private val state: S, - private val seed: S.() -> R, - private val forkState: ((S) -> S), + private val seed: suspend S.() -> R, + private val forkState: suspend ((S) -> S), private val gen: suspend S.(R) -> R, ) : Chain { private val mutex: Mutex = Mutex() @@ -97,6 +97,11 @@ public class ConstantChain(public val value: T) : Chain { override suspend fun fork(): Chain = this } +/** + * Discard a fixed number of samples + */ +public suspend inline fun Chain.discard(number: Int): Chain = apply { nextBuffer(number) } + /** * Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed * since mapped chain consumes tokens. Accepts regular transformation function. diff --git a/kmath-stat/build.gradle.kts b/kmath-stat/build.gradle.kts index 8116a8925..5552ed9a6 100644 --- a/kmath-stat/build.gradle.kts +++ b/kmath-stat/build.gradle.kts @@ -7,21 +7,17 @@ kscience { js() native() wasm() -} -kotlin.sourceSets { + useCoroutines() + commonMain { - dependencies { - api(projects.kmathCoroutines) - //implementation(spclibs.atomicfu) - } + api(projects.kmathCoroutines) + //implementation(spclibs.atomicfu) } - getByName("jvmMain") { - dependencies { - api("org.apache.commons:commons-rng-sampling:1.3") - api("org.apache.commons:commons-rng-simple:1.3") - } + jvmMain { + api(libs.commons.rng.simple) + api(libs.commons.rng.sampling) } } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt index d8ded96eb..4e00bbad2 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt @@ -5,46 +5,72 @@ package space.kscience.kmath.samplers -import space.kscience.kmath.chains.BlockingDoubleChain -import space.kscience.kmath.distributions.Distribution1D -import space.kscience.kmath.distributions.NormalDistribution +import space.kscience.kmath.chains.Chain +import space.kscience.kmath.chains.StatefulChain +import space.kscience.kmath.operations.Group +import space.kscience.kmath.operations.algebra import space.kscience.kmath.random.RandomGenerator -import space.kscience.kmath.structures.Float64Buffer -import kotlin.math.* +import space.kscience.kmath.stat.Sampler +import space.kscience.kmath.structures.Float64 /** * [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling - * target distribution [targetDist]. + * target distribution [targetPdf]. * - * The normal distribution is used as the proposal function. - * - * params: - * - targetDist: function close to the density of the sampled distribution; - * - initialState: initial value of the chain of sampled values; - * - proposalStd: standard deviation of the proposal function. + * @param stepSampler a sampler for proposal point (offset to previous point) */ -public class MetropolisHastingsSampler( - public val targetDist: (arg : Double) -> Double, - public val initialState : Double = 0.0, - public val proposalStd : Double = 1.0, -) : BlockingDoubleSampler { - override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { - var currentState = initialState - fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd) +public class MetropolisHastingsSampler( + public val algebra: Group, + public val startPoint: T, + public val stepSampler: Sampler, + public val targetPdf: suspend (T) -> Float64, +) : Sampler { - override fun nextBufferBlocking(size: Int): Float64Buffer { - val acceptanceProb = generator.nextDoubleBuffer(size) + //TODO consider API for conditional step probability - return Float64Buffer(size) {index -> - val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0) - val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState)) - - currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState - currentState + override fun sample(generator: RandomGenerator): Chain = StatefulChain, T>( + state = stepSampler.sample(generator), + seed = { startPoint }, + forkState = Chain::fork + ) { previousPoint: T -> + val proposalPoint = with(algebra) { previousPoint + next() } + val ratio = targetPdf(proposalPoint) / targetPdf(previousPoint) + if (ratio >= 1.0) { + proposalPoint + } else { + val acceptanceProbability = generator.nextDouble() + if (acceptanceProbability <= ratio) { + proposalPoint + } else { + previousPoint } } - - override suspend fun fork(): BlockingDoubleChain = sample(generator.fork()) } + + public companion object { + + /** + * A Metropolis-Hastings sampler for univariate [Float64] values + */ + public fun univariate( + startPoint: Float64, + stepSampler: Sampler, + targetPdf: suspend (Float64) -> Float64, + ): MetropolisHastingsSampler = MetropolisHastingsSampler( + algebra = Float64.algebra, + startPoint = startPoint, + stepSampler = stepSampler, + targetPdf = targetPdf + ) + + /** + * A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution + */ + public fun univariateNormal( + startPoint: Float64, + stepSigma: Float64, + targetPdf: suspend (Float64) -> Float64, + ): MetropolisHastingsSampler = univariate(startPoint, GaussianSampler(0.0, stepSigma), targetPdf) + } } \ No newline at end of file diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/Sampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/Sampler.kt index 44e099a52..9bc1a6bc2 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/Sampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/Sampler.kt @@ -17,7 +17,7 @@ import kotlin.jvm.JvmName /** * Sampler that generates chains of values of type [T]. */ -public fun interface Sampler { +public fun interface Sampler { /** * Generates a chain of samples. * diff --git a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt index 9ab44f87d..b03b0c4c1 100644 --- a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt +++ b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt @@ -4,73 +4,96 @@ */ package space.kscience.kmath.samplers + +import kotlinx.coroutines.test.runTest +import space.kscience.kmath.chains.discard +import space.kscience.kmath.chains.nextBuffer import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.operations.Float64Field -import space.kscience.kmath.random.DefaultGenerator +import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.stat.invoke import space.kscience.kmath.stat.mean +import kotlin.math.PI import kotlin.math.exp import kotlin.math.pow +import kotlin.math.sqrt import kotlin.test.Test import kotlin.test.assertEquals class TestMetropolisHastingsSampler { + data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5) + + val sample = 1e6.toInt() + val burnIn = sample / 5 + @Test - fun samplingNormalTest() { - fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg) - var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0) - var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) + fun samplingNormalTest() = runTest { + val generator = RandomGenerator.default(1) - assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2) + listOf( + TestSetup(0.5, 0.0), + TestSetup(68.13, 60.0), + ).forEach { + val distribution = NormalDistribution(it.mean, 1.0) + val sampler = MetropolisHastingsSampler + .univariateNormal(it.startPoint, it.sigma, distribution::probability) + val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample) - fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg) - sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0) - sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) - - assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2) + assertEquals(it.mean, Float64Field.mean(sampledValues), 1e-2) + } } @Test - fun samplingExponentialTest() { - fun expDist(arg : Double, param : Double) : Double { - if (arg < 0.0) { return 0.0 } - return param * exp(-param * arg) + fun samplingExponentialTest() = runTest { + val generator = RandomGenerator.default(1) + + fun expDist(lambda: Double, arg: Double): Double = if (arg <= 0.0) 0.0 else lambda * exp(-arg * lambda) + + listOf( + TestSetup(0.5, 2.0), + TestSetup(2.0, 1.0) + ).forEach { setup -> + val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) { + expDist(setup.mean, it) + } + val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample) + + assertEquals(1.0 / setup.mean, Float64Field.mean(sampledValues), 1e-2) } - - fun expDist1(arg : Double) = expDist(arg, 0.5) - var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0) - var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) - - assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2) - - fun expDist2(arg : Double) = expDist(arg, 2.0) - sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0) - sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) - - assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2) - } @Test - fun samplingRayleighTest() { - fun rayleighDist(arg : Double, sigma : Double) : Double { - if (arg < 0.0) { return 0.0 } - - val expArg = (arg / sigma).pow(2) - return arg * exp(-expArg / 2.0) / sigma.pow(2) + fun samplingRayleighTest() = runTest { + val generator = RandomGenerator.default(1) + fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) { + 0.0 + } else { + arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2) } - fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0) - var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0) - var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) + listOf( + TestSetup(0.5, 1.0), + TestSetup(2.0, 1.0) + ).forEach { setup -> + val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) { + rayleighDist(setup.mean, it) + } + val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample) - assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2) - - fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0) - sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0) - sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000) - - assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2) + assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2) + } +// +// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0) +// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0) +// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) +// +// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2) +// +// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0) +// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0) +// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000) +// +// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2) } } \ No newline at end of file