forked from kscience/kmath
Add Metropolis-Hastings sampler.
Minor fixes.
This commit is contained in:
parent
57d1cd8c87
commit
48d0ee8126
@ -24,8 +24,8 @@ class TestMetropolisHastingsSampler {
|
|||||||
|
|
||||||
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
||||||
|
|
||||||
val sample = 1e6.toInt()
|
private val sample = 1e6.toInt()
|
||||||
val burnIn = sample / 5
|
private val burnIn = sample / 5
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun samplingNormalTest() = runTest {
|
fun samplingNormalTest() = runTest {
|
||||||
@ -66,6 +66,7 @@ class TestMetropolisHastingsSampler {
|
|||||||
@Test
|
@Test
|
||||||
fun samplingRayleighTest() = runTest {
|
fun samplingRayleighTest() = runTest {
|
||||||
val generator = RandomGenerator.default(1)
|
val generator = RandomGenerator.default(1)
|
||||||
|
|
||||||
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
||||||
0.0
|
0.0
|
||||||
} else {
|
} else {
|
||||||
@ -83,17 +84,5 @@ class TestMetropolisHastingsSampler {
|
|||||||
|
|
||||||
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
|
|
||||||
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
|
|
||||||
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
|
||||||
//
|
|
||||||
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
|
||||||
//
|
|
||||||
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
|
|
||||||
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
|
||||||
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
|
||||||
//
|
|
||||||
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user