From d0d250c67fab9cbe54f88b59b7b679e3d4c53fe7 Mon Sep 17 00:00:00 2001 From: "vasilev.ilia" Date: Mon, 22 Apr 2024 21:57:24 +0300 Subject: [PATCH 1/3] MHS first implementation | STUD-7 --- .../samplers/MetropolisHastingsSampler.kt | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt new file mode 100644 index 000000000..29fea9c13 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt @@ -0,0 +1,42 @@ +/* + * Copyright 2018-2024 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.distributions.NormalDistribution +import space.kscience.kmath.random.RandomGenerator +import space.kscience.kmath.structures.Float64Buffer + +/** + * [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) for sampling + * target distribution [distribution]. + * + */ +public class MetropolisHastingsSampler( + public val distribution: (x : Double) -> Double, + public val initState : Double = 0.0, +) : BlockingDoubleSampler { + override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { + var currState = initState + + override fun nextBufferBlocking(size: Int): Float64Buffer { + val u = generator.nextDoubleBuffer(size) + + return Float64Buffer(size) {index -> + val proposalDist = NormalDistribution(currState, 0.01) + val newState = proposalDist.sample(RandomGenerator.default(1)).nextBufferBlocking(1).get(0) + val acceptanceRatio = distribution(newState) / distribution(currState) + if (u[index] <= acceptanceRatio) { + currState = newState + } + currState + } + } + + override suspend fun fork(): BlockingDoubleChain = BoxMullerSampler.sample(generator.fork()) + } + +} \ No newline at end of file From dbc5488eb2af2842618dbd8dc754b1f3c8a213ca Mon Sep 17 00:00:00 2001 From: "vasilev.ilia" Date: Wed, 24 Apr 2024 23:29:14 +0300 Subject: [PATCH 2/3] Minor edits. Tests added. | STUD-7 --- .../samplers/MetropolisHastingsSampler.kt | 30 +++++++++++-------- .../samplers/TestMetropolisHastingsSampler.kt | 26 ++++++++++++++++ 2 files changed, 44 insertions(+), 12 deletions(-) create mode 100644 kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt index 29fea9c13..74dfa822f 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt @@ -6,33 +6,39 @@ package space.kscience.kmath.samplers import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.distributions.Distribution1D import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.structures.Float64Buffer +import kotlin.math.* /** * [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) for sampling - * target distribution [distribution]. + * target distribution [targetDist]. * + * params: + * - targetDist: function close to the density of the sampled distribution; + * - initialState: initial value of the chain of sampled values. */ public class MetropolisHastingsSampler( - public val distribution: (x : Double) -> Double, - public val initState : Double = 0.0, + public val targetDist: (arg : Double) -> Double, + public val initialState : Double = 0.0, ) : BlockingDoubleSampler { override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { - var currState = initState + var currentState = initialState + fun proposalDist(arg : Double) = NormalDistribution(arg, 0.01) override fun nextBufferBlocking(size: Int): Float64Buffer { - val u = generator.nextDoubleBuffer(size) + val acceptanceProb = generator.nextDoubleBuffer(size) return Float64Buffer(size) {index -> - val proposalDist = NormalDistribution(currState, 0.01) - val newState = proposalDist.sample(RandomGenerator.default(1)).nextBufferBlocking(1).get(0) - val acceptanceRatio = distribution(newState) / distribution(currState) - if (u[index] <= acceptanceRatio) { - currState = newState - } - currState + val newState = proposalDist(currentState).sample(RandomGenerator.default(0)).nextBufferBlocking(5).get(4) + val firstComp = targetDist(newState) / targetDist(currentState) + val secondComp = proposalDist(newState).probability(currentState) / proposalDist(currentState).probability(newState) + val acceptanceRatio = min(1.0, firstComp * secondComp) + + currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState + currentState } } diff --git a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt new file mode 100644 index 000000000..7bc1eb877 --- /dev/null +++ b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt @@ -0,0 +1,26 @@ +/* + * Copyright 2018-2024 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.samplers +import space.kscience.kmath.distributions.NormalDistribution +import space.kscience.kmath.operations.Float64Field +import space.kscience.kmath.random.DefaultGenerator +import space.kscience.kmath.stat.invoke +import space.kscience.kmath.stat.mean +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestMetropolisHastingsSampler { + + @Test + fun samplingNormalTest1() { + fun myDist(arg : Double) = NormalDistribution(0.0, 1.0).probability(arg) + val sampler = MetropolisHastingsSampler(::myDist) + + + val sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10) + assertEquals(0.05, Float64Field.mean(sampledValues), 0.01) + } +} \ No newline at end of file From 3417d8cdc188a8151f2a3c90c3c9a40c515bdc41 Mon Sep 17 00:00:00 2001 From: "vasilev.ilya" Date: Mon, 20 May 2024 01:12:27 +0300 Subject: [PATCH 3/3] Minor edits. Tests added. | STUD-7 --- .../samplers/MetropolisHastingsSampler.kt | 18 +++--- .../samplers/TestMetropolisHastingsSampler.kt | 60 +++++++++++++++++-- 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt index 74dfa822f..d8ded96eb 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt @@ -13,36 +13,38 @@ import space.kscience.kmath.structures.Float64Buffer import kotlin.math.* /** - * [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) for sampling + * [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling * target distribution [targetDist]. * + * The normal distribution is used as the proposal function. + * * params: * - targetDist: function close to the density of the sampled distribution; - * - initialState: initial value of the chain of sampled values. + * - initialState: initial value of the chain of sampled values; + * - proposalStd: standard deviation of the proposal function. */ public class MetropolisHastingsSampler( public val targetDist: (arg : Double) -> Double, public val initialState : Double = 0.0, + public val proposalStd : Double = 1.0, ) : BlockingDoubleSampler { override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { var currentState = initialState - fun proposalDist(arg : Double) = NormalDistribution(arg, 0.01) + fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd) override fun nextBufferBlocking(size: Int): Float64Buffer { val acceptanceProb = generator.nextDoubleBuffer(size) return Float64Buffer(size) {index -> - val newState = proposalDist(currentState).sample(RandomGenerator.default(0)).nextBufferBlocking(5).get(4) - val firstComp = targetDist(newState) / targetDist(currentState) - val secondComp = proposalDist(newState).probability(currentState) / proposalDist(currentState).probability(newState) - val acceptanceRatio = min(1.0, firstComp * secondComp) + val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0) + val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState)) currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState currentState } } - override suspend fun fork(): BlockingDoubleChain = BoxMullerSampler.sample(generator.fork()) + override suspend fun fork(): BlockingDoubleChain = sample(generator.fork()) } } \ No newline at end of file diff --git a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt index 7bc1eb877..9ab44f87d 100644 --- a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt +++ b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt @@ -9,18 +9,68 @@ import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.random.DefaultGenerator import space.kscience.kmath.stat.invoke import space.kscience.kmath.stat.mean +import kotlin.math.exp +import kotlin.math.pow import kotlin.test.Test import kotlin.test.assertEquals class TestMetropolisHastingsSampler { @Test - fun samplingNormalTest1() { - fun myDist(arg : Double) = NormalDistribution(0.0, 1.0).probability(arg) - val sampler = MetropolisHastingsSampler(::myDist) + fun samplingNormalTest() { + fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg) + var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0) + var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) + assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2) - val sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10) - assertEquals(0.05, Float64Field.mean(sampledValues), 0.01) + fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg) + sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0) + sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) + + assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2) + } + + @Test + fun samplingExponentialTest() { + fun expDist(arg : Double, param : Double) : Double { + if (arg < 0.0) { return 0.0 } + return param * exp(-param * arg) + } + + fun expDist1(arg : Double) = expDist(arg, 0.5) + var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0) + var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) + + assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2) + + fun expDist2(arg : Double) = expDist(arg, 2.0) + sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0) + sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) + + assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2) + + } + + @Test + fun samplingRayleighTest() { + fun rayleighDist(arg : Double, sigma : Double) : Double { + if (arg < 0.0) { return 0.0 } + + val expArg = (arg / sigma).pow(2) + return arg * exp(-expArg / 2.0) / sigma.pow(2) + } + + fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0) + var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0) + var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000) + + assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2) + + fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0) + sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0) + sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000) + + assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2) } } \ No newline at end of file