forked from kscience/kmath
Minor edits. Tests added. | STUD-7
This commit is contained in:
parent
dbc5488eb2
commit
3417d8cdc1
@ -13,36 +13,38 @@ import space.kscience.kmath.structures.Float64Buffer
|
|||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) for sampling
|
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
|
||||||
* target distribution [targetDist].
|
* target distribution [targetDist].
|
||||||
*
|
*
|
||||||
|
* The normal distribution is used as the proposal function.
|
||||||
|
*
|
||||||
* params:
|
* params:
|
||||||
* - targetDist: function close to the density of the sampled distribution;
|
* - targetDist: function close to the density of the sampled distribution;
|
||||||
* - initialState: initial value of the chain of sampled values.
|
* - initialState: initial value of the chain of sampled values;
|
||||||
|
* - proposalStd: standard deviation of the proposal function.
|
||||||
*/
|
*/
|
||||||
public class MetropolisHastingsSampler(
|
public class MetropolisHastingsSampler(
|
||||||
public val targetDist: (arg : Double) -> Double,
|
public val targetDist: (arg : Double) -> Double,
|
||||||
public val initialState : Double = 0.0,
|
public val initialState : Double = 0.0,
|
||||||
|
public val proposalStd : Double = 1.0,
|
||||||
) : BlockingDoubleSampler {
|
) : BlockingDoubleSampler {
|
||||||
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
|
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
|
||||||
var currentState = initialState
|
var currentState = initialState
|
||||||
fun proposalDist(arg : Double) = NormalDistribution(arg, 0.01)
|
fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd)
|
||||||
|
|
||||||
override fun nextBufferBlocking(size: Int): Float64Buffer {
|
override fun nextBufferBlocking(size: Int): Float64Buffer {
|
||||||
val acceptanceProb = generator.nextDoubleBuffer(size)
|
val acceptanceProb = generator.nextDoubleBuffer(size)
|
||||||
|
|
||||||
return Float64Buffer(size) {index ->
|
return Float64Buffer(size) {index ->
|
||||||
val newState = proposalDist(currentState).sample(RandomGenerator.default(0)).nextBufferBlocking(5).get(4)
|
val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0)
|
||||||
val firstComp = targetDist(newState) / targetDist(currentState)
|
val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState))
|
||||||
val secondComp = proposalDist(newState).probability(currentState) / proposalDist(currentState).probability(newState)
|
|
||||||
val acceptanceRatio = min(1.0, firstComp * secondComp)
|
|
||||||
|
|
||||||
currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState
|
currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState
|
||||||
currentState
|
currentState
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun fork(): BlockingDoubleChain = BoxMullerSampler.sample(generator.fork())
|
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -9,18 +9,68 @@ import space.kscience.kmath.operations.Float64Field
|
|||||||
import space.kscience.kmath.random.DefaultGenerator
|
import space.kscience.kmath.random.DefaultGenerator
|
||||||
import space.kscience.kmath.stat.invoke
|
import space.kscience.kmath.stat.invoke
|
||||||
import space.kscience.kmath.stat.mean
|
import space.kscience.kmath.stat.mean
|
||||||
|
import kotlin.math.exp
|
||||||
|
import kotlin.math.pow
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class TestMetropolisHastingsSampler {
|
class TestMetropolisHastingsSampler {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun samplingNormalTest1() {
|
fun samplingNormalTest() {
|
||||||
fun myDist(arg : Double) = NormalDistribution(0.0, 1.0).probability(arg)
|
fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg)
|
||||||
val sampler = MetropolisHastingsSampler(::myDist)
|
var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0)
|
||||||
|
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||||
|
|
||||||
|
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
|
||||||
val sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10)
|
fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg)
|
||||||
assertEquals(0.05, Float64Field.mean(sampledValues), 0.01)
|
sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0)
|
||||||
|
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||||
|
|
||||||
|
assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun samplingExponentialTest() {
|
||||||
|
fun expDist(arg : Double, param : Double) : Double {
|
||||||
|
if (arg < 0.0) { return 0.0 }
|
||||||
|
return param * exp(-param * arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun expDist1(arg : Double) = expDist(arg, 0.5)
|
||||||
|
var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0)
|
||||||
|
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||||
|
|
||||||
|
assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
|
||||||
|
fun expDist2(arg : Double) = expDist(arg, 2.0)
|
||||||
|
sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0)
|
||||||
|
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||||
|
|
||||||
|
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun samplingRayleighTest() {
|
||||||
|
fun rayleighDist(arg : Double, sigma : Double) : Double {
|
||||||
|
if (arg < 0.0) { return 0.0 }
|
||||||
|
|
||||||
|
val expArg = (arg / sigma).pow(2)
|
||||||
|
return arg * exp(-expArg / 2.0) / sigma.pow(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0)
|
||||||
|
var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0)
|
||||||
|
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||||
|
|
||||||
|
assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
|
||||||
|
fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0)
|
||||||
|
sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
||||||
|
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
||||||
|
|
||||||
|
assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user