Compare commits
5 Commits
3e8f44166c
...
57d1cd8c87
Author | SHA1 | Date | |
---|---|---|---|
57d1cd8c87 | |||
|
e0997ccf9c | ||
|
3417d8cdc1 | ||
dbc5488eb2 | |||
d0d250c67f |
@ -3,6 +3,7 @@
|
|||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
- Metropolis-Hastings sampler
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
|
9
gradle/libs.versions.toml
Normal file
9
gradle/libs.versions.toml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
[versions]
|
||||||
|
|
||||||
|
commons-rng = "1.6"
|
||||||
|
|
||||||
|
|
||||||
|
[libraries]
|
||||||
|
|
||||||
|
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
|
||||||
|
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}
|
@ -40,13 +40,13 @@ public interface BlockingBufferChain<out T> : BlockingChain<T>, BufferChain<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public suspend inline fun <reified T : Any> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
|
public suspend inline fun <reified T> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
|
||||||
nextBuffer(size)
|
nextBuffer(size)
|
||||||
} else {
|
} else {
|
||||||
Buffer(size) { next() }
|
Buffer(size) { next() }
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <reified T : Any> BlockingChain<T>.nextBufferBlocking(
|
public inline fun <reified T> BlockingChain<T>.nextBufferBlocking(
|
||||||
size: Int,
|
size: Int,
|
||||||
): Buffer<T> = if (this is BlockingBufferChain) {
|
): Buffer<T> = if (this is BlockingBufferChain) {
|
||||||
nextBufferBlocking(size)
|
nextBufferBlocking(size)
|
||||||
|
@ -23,7 +23,7 @@ public interface Chain<out T> : Flow<T> {
|
|||||||
public suspend fun next(): T
|
public suspend fun next(): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a copy of current chain state. Consuming resulting chain does not affect initial chain.
|
* Create a copy of the current chain state. Consuming the resulting chain does not affect the initial chain.
|
||||||
*/
|
*/
|
||||||
public suspend fun fork(): Chain<T>
|
public suspend fun fork(): Chain<T>
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ public class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
|||||||
/**
|
/**
|
||||||
* A stateless Markov chain
|
* A stateless Markov chain
|
||||||
*/
|
*/
|
||||||
public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
public class MarkovChain<out R>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||||
private val mutex: Mutex = Mutex()
|
private val mutex: Mutex = Mutex()
|
||||||
private var value: R? = null
|
private var value: R? = null
|
||||||
|
|
||||||
@ -71,8 +71,8 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
|
|||||||
*/
|
*/
|
||||||
public class StatefulChain<S, out R>(
|
public class StatefulChain<S, out R>(
|
||||||
private val state: S,
|
private val state: S,
|
||||||
private val seed: S.() -> R,
|
private val seed: suspend S.() -> R,
|
||||||
private val forkState: ((S) -> S),
|
private val forkState: suspend ((S) -> S),
|
||||||
private val gen: suspend S.(R) -> R,
|
private val gen: suspend S.(R) -> R,
|
||||||
) : Chain<R> {
|
) : Chain<R> {
|
||||||
private val mutex: Mutex = Mutex()
|
private val mutex: Mutex = Mutex()
|
||||||
@ -97,6 +97,11 @@ public class ConstantChain<out T>(public val value: T) : Chain<T> {
|
|||||||
override suspend fun fork(): Chain<T> = this
|
override suspend fun fork(): Chain<T> = this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discard a fixed number of samples
|
||||||
|
*/
|
||||||
|
public suspend inline fun <reified T> Chain<T>.discard(number: Int): Chain<T> = apply { nextBuffer<T>(number) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||||
* since mapped chain consumes tokens. Accepts regular transformation function.
|
* since mapped chain consumes tokens. Accepts regular transformation function.
|
||||||
|
@ -7,21 +7,17 @@ kscience {
|
|||||||
js()
|
js()
|
||||||
native()
|
native()
|
||||||
wasm()
|
wasm()
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
useCoroutines()
|
||||||
|
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
|
||||||
api(projects.kmathCoroutines)
|
api(projects.kmathCoroutines)
|
||||||
//implementation(spclibs.atomicfu)
|
//implementation(spclibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
getByName("jvmMain") {
|
jvmMain {
|
||||||
dependencies {
|
api(libs.commons.rng.simple)
|
||||||
api("org.apache.commons:commons-rng-sampling:1.3")
|
api(libs.commons.rng.sampling)
|
||||||
api("org.apache.commons:commons-rng-simple:1.3")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,76 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2024 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.samplers
|
||||||
|
|
||||||
|
import space.kscience.kmath.chains.Chain
|
||||||
|
import space.kscience.kmath.chains.StatefulChain
|
||||||
|
import space.kscience.kmath.operations.Group
|
||||||
|
import space.kscience.kmath.operations.algebra
|
||||||
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
|
import space.kscience.kmath.stat.Sampler
|
||||||
|
import space.kscience.kmath.structures.Float64
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
|
||||||
|
* target distribution [targetPdf].
|
||||||
|
*
|
||||||
|
* @param stepSampler a sampler for proposal point (offset to previous point)
|
||||||
|
*/
|
||||||
|
public class MetropolisHastingsSampler<T>(
|
||||||
|
public val algebra: Group<T>,
|
||||||
|
public val startPoint: T,
|
||||||
|
public val stepSampler: Sampler<T>,
|
||||||
|
public val targetPdf: suspend (T) -> Float64,
|
||||||
|
) : Sampler<T> {
|
||||||
|
|
||||||
|
//TODO consider API for conditional step probability
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
|
||||||
|
state = stepSampler.sample(generator),
|
||||||
|
seed = { startPoint },
|
||||||
|
forkState = Chain<T>::fork
|
||||||
|
) { previousPoint: T ->
|
||||||
|
val proposalPoint = with(algebra) { previousPoint + next() }
|
||||||
|
val ratio = targetPdf(proposalPoint) / targetPdf(previousPoint)
|
||||||
|
if (ratio >= 1.0) {
|
||||||
|
proposalPoint
|
||||||
|
} else {
|
||||||
|
val acceptanceProbability = generator.nextDouble()
|
||||||
|
if (acceptanceProbability <= ratio) {
|
||||||
|
proposalPoint
|
||||||
|
} else {
|
||||||
|
previousPoint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Metropolis-Hastings sampler for univariate [Float64] values
|
||||||
|
*/
|
||||||
|
public fun univariate(
|
||||||
|
startPoint: Float64,
|
||||||
|
stepSampler: Sampler<Float64>,
|
||||||
|
targetPdf: suspend (Float64) -> Float64,
|
||||||
|
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
|
||||||
|
algebra = Float64.algebra,
|
||||||
|
startPoint = startPoint,
|
||||||
|
stepSampler = stepSampler,
|
||||||
|
targetPdf = targetPdf
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution
|
||||||
|
*/
|
||||||
|
public fun univariateNormal(
|
||||||
|
startPoint: Float64,
|
||||||
|
stepSigma: Float64,
|
||||||
|
targetPdf: suspend (Float64) -> Float64,
|
||||||
|
): MetropolisHastingsSampler<Double> = univariate(startPoint, GaussianSampler(0.0, stepSigma), targetPdf)
|
||||||
|
}
|
||||||
|
}
|
@ -17,7 +17,7 @@ import kotlin.jvm.JvmName
|
|||||||
/**
|
/**
|
||||||
* Sampler that generates chains of values of type [T].
|
* Sampler that generates chains of values of type [T].
|
||||||
*/
|
*/
|
||||||
public fun interface Sampler<out T : Any> {
|
public fun interface Sampler<out T> {
|
||||||
/**
|
/**
|
||||||
* Generates a chain of samples.
|
* Generates a chain of samples.
|
||||||
*
|
*
|
||||||
|
@ -0,0 +1,99 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2024 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.samplers
|
||||||
|
|
||||||
|
import kotlinx.coroutines.test.runTest
|
||||||
|
import space.kscience.kmath.chains.discard
|
||||||
|
import space.kscience.kmath.chains.nextBuffer
|
||||||
|
import space.kscience.kmath.distributions.NormalDistribution
|
||||||
|
import space.kscience.kmath.operations.Float64Field
|
||||||
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
|
import space.kscience.kmath.stat.invoke
|
||||||
|
import space.kscience.kmath.stat.mean
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.math.exp
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class TestMetropolisHastingsSampler {
|
||||||
|
|
||||||
|
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
||||||
|
|
||||||
|
val sample = 1e6.toInt()
|
||||||
|
val burnIn = sample / 5
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun samplingNormalTest() = runTest {
|
||||||
|
val generator = RandomGenerator.default(1)
|
||||||
|
|
||||||
|
listOf(
|
||||||
|
TestSetup(0.5, 0.0),
|
||||||
|
TestSetup(68.13, 60.0),
|
||||||
|
).forEach {
|
||||||
|
val distribution = NormalDistribution(it.mean, 1.0)
|
||||||
|
val sampler = MetropolisHastingsSampler
|
||||||
|
.univariateNormal(it.startPoint, it.sigma, distribution::probability)
|
||||||
|
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||||
|
|
||||||
|
assertEquals(it.mean, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun samplingExponentialTest() = runTest {
|
||||||
|
val generator = RandomGenerator.default(1)
|
||||||
|
|
||||||
|
fun expDist(lambda: Double, arg: Double): Double = if (arg <= 0.0) 0.0 else lambda * exp(-arg * lambda)
|
||||||
|
|
||||||
|
listOf(
|
||||||
|
TestSetup(0.5, 2.0),
|
||||||
|
TestSetup(2.0, 1.0)
|
||||||
|
).forEach { setup ->
|
||||||
|
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||||
|
expDist(setup.mean, it)
|
||||||
|
}
|
||||||
|
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||||
|
|
||||||
|
assertEquals(1.0 / setup.mean, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun samplingRayleighTest() = runTest {
|
||||||
|
val generator = RandomGenerator.default(1)
|
||||||
|
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
listOf(
|
||||||
|
TestSetup(0.5, 1.0),
|
||||||
|
TestSetup(2.0, 1.0)
|
||||||
|
).forEach { setup ->
|
||||||
|
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||||
|
rayleighDist(setup.mean, it)
|
||||||
|
}
|
||||||
|
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||||
|
|
||||||
|
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
|
||||||
|
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
|
||||||
|
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||||
|
//
|
||||||
|
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
//
|
||||||
|
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
|
||||||
|
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
||||||
|
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
||||||
|
//
|
||||||
|
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user