Add Metropolis-Hastings sampler.
Minor fixes.
This commit is contained in:
parent
e0997ccf9c
commit
57d1cd8c87
@ -3,6 +3,7 @@
|
|||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
- Metropolis-Hastings sampler
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
|
9
gradle/libs.versions.toml
Normal file
9
gradle/libs.versions.toml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
[versions]
|
||||||
|
|
||||||
|
commons-rng = "1.6"
|
||||||
|
|
||||||
|
|
||||||
|
[libraries]
|
||||||
|
|
||||||
|
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
|
||||||
|
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}
|
@ -40,13 +40,13 @@ public interface BlockingBufferChain<out T> : BlockingChain<T>, BufferChain<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public suspend inline fun <reified T : Any> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
|
public suspend inline fun <reified T> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
|
||||||
nextBuffer(size)
|
nextBuffer(size)
|
||||||
} else {
|
} else {
|
||||||
Buffer(size) { next() }
|
Buffer(size) { next() }
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <reified T : Any> BlockingChain<T>.nextBufferBlocking(
|
public inline fun <reified T> BlockingChain<T>.nextBufferBlocking(
|
||||||
size: Int,
|
size: Int,
|
||||||
): Buffer<T> = if (this is BlockingBufferChain) {
|
): Buffer<T> = if (this is BlockingBufferChain) {
|
||||||
nextBufferBlocking(size)
|
nextBufferBlocking(size)
|
||||||
|
@ -23,7 +23,7 @@ public interface Chain<out T> : Flow<T> {
|
|||||||
public suspend fun next(): T
|
public suspend fun next(): T
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a copy of current chain state. Consuming resulting chain does not affect initial chain.
|
* Create a copy of the current chain state. Consuming the resulting chain does not affect the initial chain.
|
||||||
*/
|
*/
|
||||||
public suspend fun fork(): Chain<T>
|
public suspend fun fork(): Chain<T>
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ public class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
|||||||
/**
|
/**
|
||||||
* A stateless Markov chain
|
* A stateless Markov chain
|
||||||
*/
|
*/
|
||||||
public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
public class MarkovChain<out R>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||||
private val mutex: Mutex = Mutex()
|
private val mutex: Mutex = Mutex()
|
||||||
private var value: R? = null
|
private var value: R? = null
|
||||||
|
|
||||||
@ -71,8 +71,8 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
|
|||||||
*/
|
*/
|
||||||
public class StatefulChain<S, out R>(
|
public class StatefulChain<S, out R>(
|
||||||
private val state: S,
|
private val state: S,
|
||||||
private val seed: S.() -> R,
|
private val seed: suspend S.() -> R,
|
||||||
private val forkState: ((S) -> S),
|
private val forkState: suspend ((S) -> S),
|
||||||
private val gen: suspend S.(R) -> R,
|
private val gen: suspend S.(R) -> R,
|
||||||
) : Chain<R> {
|
) : Chain<R> {
|
||||||
private val mutex: Mutex = Mutex()
|
private val mutex: Mutex = Mutex()
|
||||||
@ -97,6 +97,11 @@ public class ConstantChain<out T>(public val value: T) : Chain<T> {
|
|||||||
override suspend fun fork(): Chain<T> = this
|
override suspend fun fork(): Chain<T> = this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discard a fixed number of samples
|
||||||
|
*/
|
||||||
|
public suspend inline fun <reified T> Chain<T>.discard(number: Int): Chain<T> = apply { nextBuffer<T>(number) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||||
* since mapped chain consumes tokens. Accepts regular transformation function.
|
* since mapped chain consumes tokens. Accepts regular transformation function.
|
||||||
|
@ -7,21 +7,17 @@ kscience {
|
|||||||
js()
|
js()
|
||||||
native()
|
native()
|
||||||
wasm()
|
wasm()
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
useCoroutines()
|
||||||
|
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
|
||||||
api(projects.kmathCoroutines)
|
api(projects.kmathCoroutines)
|
||||||
//implementation(spclibs.atomicfu)
|
//implementation(spclibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
getByName("jvmMain") {
|
jvmMain {
|
||||||
dependencies {
|
api(libs.commons.rng.simple)
|
||||||
api("org.apache.commons:commons-rng-sampling:1.3")
|
api(libs.commons.rng.sampling)
|
||||||
api("org.apache.commons:commons-rng-simple:1.3")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,46 +5,72 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.samplers
|
package space.kscience.kmath.samplers
|
||||||
|
|
||||||
import space.kscience.kmath.chains.BlockingDoubleChain
|
import space.kscience.kmath.chains.Chain
|
||||||
import space.kscience.kmath.distributions.Distribution1D
|
import space.kscience.kmath.chains.StatefulChain
|
||||||
import space.kscience.kmath.distributions.NormalDistribution
|
import space.kscience.kmath.operations.Group
|
||||||
|
import space.kscience.kmath.operations.algebra
|
||||||
import space.kscience.kmath.random.RandomGenerator
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
import space.kscience.kmath.structures.Float64Buffer
|
import space.kscience.kmath.stat.Sampler
|
||||||
import kotlin.math.*
|
import space.kscience.kmath.structures.Float64
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
|
* [Metropolis–Hastings algorithm](https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm) for sampling
|
||||||
* target distribution [targetDist].
|
* target distribution [targetPdf].
|
||||||
*
|
*
|
||||||
* The normal distribution is used as the proposal function.
|
* @param stepSampler a sampler for proposal point (offset to previous point)
|
||||||
*
|
|
||||||
* params:
|
|
||||||
* - targetDist: function close to the density of the sampled distribution;
|
|
||||||
* - initialState: initial value of the chain of sampled values;
|
|
||||||
* - proposalStd: standard deviation of the proposal function.
|
|
||||||
*/
|
*/
|
||||||
public class MetropolisHastingsSampler(
|
public class MetropolisHastingsSampler<T>(
|
||||||
public val targetDist: (arg : Double) -> Double,
|
public val algebra: Group<T>,
|
||||||
public val initialState : Double = 0.0,
|
public val startPoint: T,
|
||||||
public val proposalStd : Double = 1.0,
|
public val stepSampler: Sampler<T>,
|
||||||
) : BlockingDoubleSampler {
|
public val targetPdf: suspend (T) -> Float64,
|
||||||
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
|
) : Sampler<T> {
|
||||||
var currentState = initialState
|
|
||||||
fun proposalDist(arg : Double) = NormalDistribution(arg, proposalStd)
|
|
||||||
|
|
||||||
override fun nextBufferBlocking(size: Int): Float64Buffer {
|
//TODO consider API for conditional step probability
|
||||||
val acceptanceProb = generator.nextDoubleBuffer(size)
|
|
||||||
|
|
||||||
return Float64Buffer(size) {index ->
|
override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
|
||||||
val newState = proposalDist(currentState).sample(generator).nextBufferBlocking(1).get(0)
|
state = stepSampler.sample(generator),
|
||||||
val acceptanceRatio = min(1.0, targetDist(newState) / targetDist(currentState))
|
seed = { startPoint },
|
||||||
|
forkState = Chain<T>::fork
|
||||||
currentState = if (acceptanceProb[index] <= acceptanceRatio) newState else currentState
|
) { previousPoint: T ->
|
||||||
currentState
|
val proposalPoint = with(algebra) { previousPoint + next() }
|
||||||
|
val ratio = targetPdf(proposalPoint) / targetPdf(previousPoint)
|
||||||
|
if (ratio >= 1.0) {
|
||||||
|
proposalPoint
|
||||||
|
} else {
|
||||||
|
val acceptanceProbability = generator.nextDouble()
|
||||||
|
if (acceptanceProbability <= ratio) {
|
||||||
|
proposalPoint
|
||||||
|
} else {
|
||||||
|
previousPoint
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
|
|
||||||
}
|
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Metropolis-Hastings sampler for univariate [Float64] values
|
||||||
|
*/
|
||||||
|
public fun univariate(
|
||||||
|
startPoint: Float64,
|
||||||
|
stepSampler: Sampler<Float64>,
|
||||||
|
targetPdf: suspend (Float64) -> Float64,
|
||||||
|
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
|
||||||
|
algebra = Float64.algebra,
|
||||||
|
startPoint = startPoint,
|
||||||
|
stepSampler = stepSampler,
|
||||||
|
targetPdf = targetPdf
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A Metropolis-Hastings sampler for univariate [Float64] values with normal step distribution
|
||||||
|
*/
|
||||||
|
public fun univariateNormal(
|
||||||
|
startPoint: Float64,
|
||||||
|
stepSigma: Float64,
|
||||||
|
targetPdf: suspend (Float64) -> Float64,
|
||||||
|
): MetropolisHastingsSampler<Double> = univariate(startPoint, GaussianSampler(0.0, stepSigma), targetPdf)
|
||||||
|
}
|
||||||
}
|
}
|
@ -17,7 +17,7 @@ import kotlin.jvm.JvmName
|
|||||||
/**
|
/**
|
||||||
* Sampler that generates chains of values of type [T].
|
* Sampler that generates chains of values of type [T].
|
||||||
*/
|
*/
|
||||||
public fun interface Sampler<out T : Any> {
|
public fun interface Sampler<out T> {
|
||||||
/**
|
/**
|
||||||
* Generates a chain of samples.
|
* Generates a chain of samples.
|
||||||
*
|
*
|
||||||
|
@ -4,73 +4,96 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
package space.kscience.kmath.samplers
|
package space.kscience.kmath.samplers
|
||||||
|
|
||||||
|
import kotlinx.coroutines.test.runTest
|
||||||
|
import space.kscience.kmath.chains.discard
|
||||||
|
import space.kscience.kmath.chains.nextBuffer
|
||||||
import space.kscience.kmath.distributions.NormalDistribution
|
import space.kscience.kmath.distributions.NormalDistribution
|
||||||
import space.kscience.kmath.operations.Float64Field
|
import space.kscience.kmath.operations.Float64Field
|
||||||
import space.kscience.kmath.random.DefaultGenerator
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
import space.kscience.kmath.stat.invoke
|
import space.kscience.kmath.stat.invoke
|
||||||
import space.kscience.kmath.stat.mean
|
import space.kscience.kmath.stat.mean
|
||||||
|
import kotlin.math.PI
|
||||||
import kotlin.math.exp
|
import kotlin.math.exp
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class TestMetropolisHastingsSampler {
|
class TestMetropolisHastingsSampler {
|
||||||
|
|
||||||
|
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
||||||
|
|
||||||
|
val sample = 1e6.toInt()
|
||||||
|
val burnIn = sample / 5
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun samplingNormalTest() {
|
fun samplingNormalTest() = runTest {
|
||||||
fun normalDist1(arg : Double) = NormalDistribution(0.5, 1.0).probability(arg)
|
val generator = RandomGenerator.default(1)
|
||||||
var sampler = MetropolisHastingsSampler(::normalDist1, proposalStd = 1.0)
|
|
||||||
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
|
||||||
|
|
||||||
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)
|
listOf(
|
||||||
|
TestSetup(0.5, 0.0),
|
||||||
|
TestSetup(68.13, 60.0),
|
||||||
|
).forEach {
|
||||||
|
val distribution = NormalDistribution(it.mean, 1.0)
|
||||||
|
val sampler = MetropolisHastingsSampler
|
||||||
|
.univariateNormal(it.startPoint, it.sigma, distribution::probability)
|
||||||
|
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||||
|
|
||||||
fun normalDist2(arg : Double) = NormalDistribution(68.13, 1.0).probability(arg)
|
assertEquals(it.mean, Float64Field.mean(sampledValues), 1e-2)
|
||||||
sampler = MetropolisHastingsSampler(::normalDist2, initialState = 63.0, proposalStd = 1.0)
|
}
|
||||||
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
|
||||||
|
|
||||||
assertEquals(68.13, Float64Field.mean(sampledValues), 1e-2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun samplingExponentialTest() {
|
fun samplingExponentialTest() = runTest {
|
||||||
fun expDist(arg : Double, param : Double) : Double {
|
val generator = RandomGenerator.default(1)
|
||||||
if (arg < 0.0) { return 0.0 }
|
|
||||||
return param * exp(-param * arg)
|
fun expDist(lambda: Double, arg: Double): Double = if (arg <= 0.0) 0.0 else lambda * exp(-arg * lambda)
|
||||||
|
|
||||||
|
listOf(
|
||||||
|
TestSetup(0.5, 2.0),
|
||||||
|
TestSetup(2.0, 1.0)
|
||||||
|
).forEach { setup ->
|
||||||
|
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||||
|
expDist(setup.mean, it)
|
||||||
}
|
}
|
||||||
|
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||||
|
|
||||||
fun expDist1(arg : Double) = expDist(arg, 0.5)
|
assertEquals(1.0 / setup.mean, Float64Field.mean(sampledValues), 1e-2)
|
||||||
var sampler = MetropolisHastingsSampler(::expDist1, initialState = 2.0, proposalStd = 1.0)
|
}
|
||||||
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
|
||||||
|
|
||||||
assertEquals(2.0, Float64Field.mean(sampledValues), 1e-2)
|
|
||||||
|
|
||||||
fun expDist2(arg : Double) = expDist(arg, 2.0)
|
|
||||||
sampler = MetropolisHastingsSampler(::expDist2, initialState = 9.0, proposalStd = 1.0)
|
|
||||||
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
|
||||||
|
|
||||||
assertEquals(0.5, Float64Field.mean(sampledValues), 1e-2)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun samplingRayleighTest() {
|
fun samplingRayleighTest() = runTest {
|
||||||
fun rayleighDist(arg : Double, sigma : Double) : Double {
|
val generator = RandomGenerator.default(1)
|
||||||
if (arg < 0.0) { return 0.0 }
|
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
||||||
|
0.0
|
||||||
val expArg = (arg / sigma).pow(2)
|
} else {
|
||||||
return arg * exp(-expArg / 2.0) / sigma.pow(2)
|
arg * exp(-(arg / sigma).pow(2) / 2.0) / sigma.pow(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun rayleighDist1(arg : Double) = rayleighDist(arg, 1.0)
|
listOf(
|
||||||
var sampler = MetropolisHastingsSampler(::rayleighDist1, initialState = 2.0, proposalStd = 1.0)
|
TestSetup(0.5, 1.0),
|
||||||
var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
TestSetup(2.0, 1.0)
|
||||||
|
).forEach { setup ->
|
||||||
|
val sampler = MetropolisHastingsSampler.univariateNormal(setup.startPoint, setup.sigma) {
|
||||||
|
rayleighDist(setup.mean, it)
|
||||||
|
}
|
||||||
|
val sampledValues = sampler.sample(generator).discard(burnIn).nextBuffer(sample)
|
||||||
|
|
||||||
assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
}
|
||||||
fun rayleighDist2(arg : Double) = rayleighDist(arg, 2.0)
|
//
|
||||||
sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
|
||||||
sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
|
||||||
|
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||||
assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
//
|
||||||
|
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
||||||
|
//
|
||||||
|
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
|
||||||
|
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
||||||
|
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
||||||
|
//
|
||||||
|
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user