Add Metropolis-Hastings sampler.

Minor fixes.
This commit is contained in:
Alexander Nozik 2024-08-04 21:26:51 +03:00
parent 57d1cd8c87
commit 48d0ee8126

View File

@ -24,8 +24,8 @@ class TestMetropolisHastingsSampler {
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5) data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
val sample = 1e6.toInt() private val sample = 1e6.toInt()
val burnIn = sample / 5 private val burnIn = sample / 5
@Test @Test
fun samplingNormalTest() = runTest { fun samplingNormalTest() = runTest {
@ -66,6 +66,7 @@ class TestMetropolisHastingsSampler {
@Test @Test
fun samplingRayleighTest() = runTest { fun samplingRayleighTest() = runTest {
val generator = RandomGenerator.default(1) val generator = RandomGenerator.default(1)
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) { fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
0.0 0.0
} else { } else {
@ -83,17 +84,5 @@ class TestMetropolisHastingsSampler {
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2) assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
} }
//
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
//
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
//
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
//
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
} }
} }