Minor edits. Tests added. | STUD-7
This commit is contained in:
parent
d0d250c67f
commit
dbc5488eb2
@ -6,33 +6,39 @@
|
||||
package space.kscience.kmath.samplers
|
||||
|
||||
import space.kscience.kmath.chains.BlockingDoubleChain
|
||||
import space.kscience.kmath.distributions.Distribution1D
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.structures.Float64Buffer
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) for sampling
|
||||
* target distribution [distribution].
|
||||
* target distribution [targetDist].
|
||||
*
|
||||
* params:
|
||||
* - targetDist: function close to the density of the sampled distribution;
|
||||
* - initialState: initial value of the chain of sampled values.
|
||||
*/
|
||||
public class MetropolisHastingsSampler(
|
||||
public val distribution: (x : Double) -> Double,
|
||||
public val initState : Double = 0.0,
|
||||
public val targetDist: (arg : Double) -> Double,
|
||||
public val initialState : Double = 0.0,
|
||||
) : BlockingDoubleSampler {
|
||||
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
|
||||
var currState = initState
|
||||
var currentState = initialState
|
||||
fun proposalDist(arg : Double) = NormalDistribution(arg, 0.01)
|
||||
|
||||
override fun nextBufferBlocking(size: Int): Float64Buffer {
|
||||
val u = generator.nextDoubleBuffer(size)
|
||||
val acceptanceProb = generator.nextDoubleBuffer(size)
|
||||
|
||||
return Float64Buffer(size) {index ->
|
||||
val proposalDist = NormalDistribution(currState, 0.01)
|
||||
val newState = proposalDist.sample(RandomGenerator.default(1)).nextBufferBlocking(1).get(0)
|
||||
val acceptanceRatio = distribution(newState) / distribution(currState)
|
||||
if (u[index] <= acceptanceRatio) {
|
||||
currState = newState
|
||||
}
|
||||
currState
|
||||
val newState = proposalDist(currentState).sample(RandomGenerator.default(0)).nextBufferBlocking(5).get(4)
|
||||
val firstComp = targetDist(newState) / targetDist(currentState)
|
||||
val secondComp = proposalDist(newState).probability(currentState) / proposalDist(currentState).probability(newState)
|
||||
val acceptanceRatio = min(1.0, firstComp * secondComp)
|
||||
|
||||
currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState
|
||||
currentState
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,26 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.samplers
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.operations.Float64Field
|
||||
import space.kscience.kmath.random.DefaultGenerator
|
||||
import space.kscience.kmath.stat.invoke
|
||||
import space.kscience.kmath.stat.mean
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class TestMetropolisHastingsSampler {
|
||||
|
||||
@Test
|
||||
fun samplingNormalTest1() {
|
||||
fun myDist(arg : Double) = NormalDistribution(0.0, 1.0).probability(arg)
|
||||
val sampler = MetropolisHastingsSampler(::myDist)
|
||||
|
||||
|
||||
val sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10)
|
||||
assertEquals(0.05, Float64Field.mean(sampledValues), 0.01)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user