Minor edits. Tests added. | STUD-7

This commit is contained in:
vasilev.ilia 2024-04-24 23:29:14 +03:00
parent d0d250c67f
commit dbc5488eb2
2 changed files with 44 additions and 12 deletions

View File

@ -6,33 +6,39 @@
package space.kscience.kmath.samplers package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.distributions.Distribution1D
import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.structures.Float64Buffer import space.kscience.kmath.structures.Float64Buffer
import kotlin.math.*
/** /**
* [MetropolisHastings algorithm](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) for sampling * [MetropolisHastings algorithm](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) for sampling
* target distribution [distribution]. * target distribution [targetDist].
* *
* params:
* - targetDist: function close to the density of the sampled distribution;
* - initialState: initial value of the chain of sampled values.
*/ */
public class MetropolisHastingsSampler( public class MetropolisHastingsSampler(
public val distribution: (x : Double) -> Double, public val targetDist: (arg : Double) -> Double,
public val initState : Double = 0.0, public val initialState : Double = 0.0,
) : BlockingDoubleSampler { ) : BlockingDoubleSampler {
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
var currState = initState var currentState = initialState
fun proposalDist(arg : Double) = NormalDistribution(arg, 0.01)
override fun nextBufferBlocking(size: Int): Float64Buffer { override fun nextBufferBlocking(size: Int): Float64Buffer {
val u = generator.nextDoubleBuffer(size) val acceptanceProb = generator.nextDoubleBuffer(size)
return Float64Buffer(size) {index -> return Float64Buffer(size) {index ->
val proposalDist = NormalDistribution(currState, 0.01) val newState = proposalDist(currentState).sample(RandomGenerator.default(0)).nextBufferBlocking(5).get(4)
val newState = proposalDist.sample(RandomGenerator.default(1)).nextBufferBlocking(1).get(0) val firstComp = targetDist(newState) / targetDist(currentState)
val acceptanceRatio = distribution(newState) / distribution(currState) val secondComp = proposalDist(newState).probability(currentState) / proposalDist(currentState).probability(newState)
if (u[index] <= acceptanceRatio) { val acceptanceRatio = min(1.0, firstComp * secondComp)
currState = newState
} currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState
currState currentState
} }
} }

View File

@ -0,0 +1,26 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.samplers
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.random.DefaultGenerator
import space.kscience.kmath.stat.invoke
import space.kscience.kmath.stat.mean
import kotlin.test.Test
import kotlin.test.assertEquals
class TestMetropolisHastingsSampler {
@Test
fun samplingNormalTest1() {
fun myDist(arg : Double) = NormalDistribution(0.0, 1.0).probability(arg)
val sampler = MetropolisHastingsSampler(::myDist)
val sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10)
assertEquals(0.05, Float64Field.mean(sampledValues), 0.01)
}
}