Minimal refactor of existing random API, move samplers implementations to samplers package, implement Sampler<T> by all the Samplers
This commit is contained in:
parent
bc59f8b287
commit
28062cb096
@ -19,6 +19,7 @@ package scientifik.kmath.chains
|
|||||||
import kotlinx.coroutines.InternalCoroutinesApi
|
import kotlinx.coroutines.InternalCoroutinesApi
|
||||||
import kotlinx.coroutines.flow.Flow
|
import kotlinx.coroutines.flow.Flow
|
||||||
import kotlinx.coroutines.flow.FlowCollector
|
import kotlinx.coroutines.flow.FlowCollector
|
||||||
|
import kotlinx.coroutines.flow.flow
|
||||||
import kotlinx.coroutines.sync.Mutex
|
import kotlinx.coroutines.sync.Mutex
|
||||||
import kotlinx.coroutines.sync.withLock
|
import kotlinx.coroutines.sync.withLock
|
||||||
|
|
||||||
@ -27,7 +28,7 @@ import kotlinx.coroutines.sync.withLock
|
|||||||
* A not-necessary-Markov chain of some type
|
* A not-necessary-Markov chain of some type
|
||||||
* @param R - the chain element type
|
* @param R - the chain element type
|
||||||
*/
|
*/
|
||||||
interface Chain<out R>: Flow<R> {
|
interface Chain<out R> : Flow<R> {
|
||||||
/**
|
/**
|
||||||
* Generate next value, changing state if needed
|
* Generate next value, changing state if needed
|
||||||
*/
|
*/
|
||||||
@ -39,13 +40,11 @@ interface Chain<out R>: Flow<R> {
|
|||||||
fun fork(): Chain<R>
|
fun fork(): Chain<R>
|
||||||
|
|
||||||
@OptIn(InternalCoroutinesApi::class)
|
@OptIn(InternalCoroutinesApi::class)
|
||||||
override suspend fun collect(collector: FlowCollector<R>) {
|
override suspend fun collect(collector: FlowCollector<R>): Unit = flow {
|
||||||
kotlinx.coroutines.flow.flow {
|
while (true)
|
||||||
while (true){
|
emit(next())
|
||||||
emit(next())
|
|
||||||
}
|
}.collect(collector)
|
||||||
}.collect(collector)
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
}
|
}
|
||||||
@ -66,24 +65,18 @@ class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
|||||||
* A stateless Markov chain
|
* A stateless Markov chain
|
||||||
*/
|
*/
|
||||||
class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||||
|
private val mutex: Mutex = Mutex()
|
||||||
private val mutex = Mutex()
|
|
||||||
|
|
||||||
private var value: R? = null
|
private var value: R? = null
|
||||||
|
|
||||||
fun value() = value
|
fun value() = value
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R = mutex.withLock {
|
||||||
mutex.withLock {
|
val newValue = gen(value ?: seed())
|
||||||
val newValue = gen(value ?: seed())
|
value = newValue
|
||||||
value = newValue
|
return newValue
|
||||||
return newValue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> = MarkovChain(seed = { value ?: seed() }, gen = gen)
|
||||||
return MarkovChain(seed = { value ?: seed() }, gen = gen)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -97,24 +90,18 @@ class StatefulChain<S, out R>(
|
|||||||
private val forkState: ((S) -> S),
|
private val forkState: ((S) -> S),
|
||||||
private val gen: suspend S.(R) -> R
|
private val gen: suspend S.(R) -> R
|
||||||
) : Chain<R> {
|
) : Chain<R> {
|
||||||
|
|
||||||
private val mutex = Mutex()
|
private val mutex = Mutex()
|
||||||
|
|
||||||
private var value: R? = null
|
private var value: R? = null
|
||||||
|
|
||||||
fun value() = value
|
fun value() = value
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R = mutex.withLock {
|
||||||
mutex.withLock {
|
val newValue = state.gen(value ?: state.seed())
|
||||||
val newValue = state.gen(value ?: state.seed())
|
value = newValue
|
||||||
value = newValue
|
return newValue
|
||||||
return newValue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> = StatefulChain(forkState(state), seed, forkState, gen)
|
||||||
return StatefulChain(forkState(state), seed, forkState, gen)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -123,9 +110,7 @@ class StatefulChain<S, out R>(
|
|||||||
class ConstantChain<out T>(val value: T) : Chain<T> {
|
class ConstantChain<out T>(val value: T) : Chain<T> {
|
||||||
override suspend fun next(): T = value
|
override suspend fun next(): T = value
|
||||||
|
|
||||||
override fun fork(): Chain<T> {
|
override fun fork(): Chain<T> = this
|
||||||
return this
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -143,9 +128,8 @@ fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
|||||||
fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> {
|
fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> {
|
||||||
override suspend fun next(): T {
|
override suspend fun next(): T {
|
||||||
var next: T
|
var next: T
|
||||||
do {
|
do next = this@filter.next()
|
||||||
next = this@filter.next()
|
while (!block(next))
|
||||||
} while (!block(next))
|
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,7 +147,9 @@ fun <T, R> Chain<T>.collect(mapper: suspend (Chain<T>) -> R): Chain<R> = object
|
|||||||
fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
||||||
object : Chain<R> {
|
object : Chain<R> {
|
||||||
override suspend fun next(): R = state.mapper(this@collectWithState)
|
override suspend fun next(): R = state.mapper(this@collectWithState)
|
||||||
override fun fork(): Chain<R> = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
|
|
||||||
|
override fun fork(): Chain<R> =
|
||||||
|
this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -171,6 +157,5 @@ fun <T, S, R> Chain<T>.collectWithState(state: S, stateFork: (S) -> S, mapper: s
|
|||||||
*/
|
*/
|
||||||
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
|
fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
|
||||||
override suspend fun next(): R = block(this@zip.next(), other.next())
|
override suspend fun next(): R = block(this@zip.next(), other.next())
|
||||||
|
|
||||||
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
|
||||||
}
|
}
|
||||||
|
@ -1,19 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng
|
|
||||||
|
|
||||||
interface UniformRandomProvider {
|
|
||||||
fun nextBytes(bytes: ByteArray)
|
|
||||||
|
|
||||||
fun nextBytes(
|
|
||||||
bytes: ByteArray,
|
|
||||||
start: Int,
|
|
||||||
len: Int
|
|
||||||
)
|
|
||||||
|
|
||||||
fun nextInt(): Int
|
|
||||||
fun nextInt(n: Int): Int
|
|
||||||
fun nextLong(): Long
|
|
||||||
fun nextLong(n: Long): Long
|
|
||||||
fun nextBoolean(): Boolean
|
|
||||||
fun nextFloat(): Float
|
|
||||||
fun nextDouble(): Double
|
|
||||||
}
|
|
@ -1,7 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
|
|
||||||
interface SharedStateSampler<R> {
|
|
||||||
fun withUniformRandomProvider(rng: UniformRandomProvider): R
|
|
||||||
}
|
|
@ -1,101 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
import kotlin.math.ln
|
|
||||||
import kotlin.math.pow
|
|
||||||
|
|
||||||
class AhrensDieterExponentialSampler : SamplerBase,
|
|
||||||
SharedStateContinuousSampler {
|
|
||||||
private val mean: Double
|
|
||||||
private val rng: UniformRandomProvider
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
) : super(null) {
|
|
||||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
|
||||||
this.rng = rng
|
|
||||||
this.mean = mean
|
|
||||||
}
|
|
||||||
|
|
||||||
private constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
source: AhrensDieterExponentialSampler
|
|
||||||
) : super(null) {
|
|
||||||
this.rng = rng
|
|
||||||
mean = source.mean
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(): Double {
|
|
||||||
// Step 1:
|
|
||||||
var a = 0.0
|
|
||||||
var u: Double = rng.nextDouble()
|
|
||||||
|
|
||||||
// Step 2 and 3:
|
|
||||||
while (u < 0.5) {
|
|
||||||
a += EXPONENTIAL_SA_QI.get(
|
|
||||||
0
|
|
||||||
)
|
|
||||||
u *= 2.0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 4 (now u >= 0.5):
|
|
||||||
u += u - 1
|
|
||||||
|
|
||||||
// Step 5:
|
|
||||||
if (u <= EXPONENTIAL_SA_QI.get(
|
|
||||||
0
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
return mean * (a + u)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 6:
|
|
||||||
var i = 0 // Should be 1, be we iterate before it in while using 0.
|
|
||||||
var u2: Double = rng.nextDouble()
|
|
||||||
var umin = u2
|
|
||||||
|
|
||||||
// Step 7 and 8:
|
|
||||||
do {
|
|
||||||
++i
|
|
||||||
u2 = rng.nextDouble()
|
|
||||||
if (u2 < umin) umin = u2
|
|
||||||
// Step 8:
|
|
||||||
} while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1.
|
|
||||||
return mean * (a + umin * EXPONENTIAL_SA_QI[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String = "Ahrens-Dieter Exponential deviate [$rng]"
|
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
|
||||||
AhrensDieterExponentialSampler(rng, this)
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
private val EXPONENTIAL_SA_QI = DoubleArray(16)
|
|
||||||
|
|
||||||
fun of(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
): SharedStateContinuousSampler =
|
|
||||||
AhrensDieterExponentialSampler(
|
|
||||||
rng,
|
|
||||||
mean
|
|
||||||
)
|
|
||||||
|
|
||||||
init {
|
|
||||||
/**
|
|
||||||
* Filling EXPONENTIAL_SA_QI table.
|
|
||||||
* Note that we don't want qi = 0 in the table.
|
|
||||||
*/
|
|
||||||
val ln2 = ln(2.0)
|
|
||||||
var qi = 0.0
|
|
||||||
|
|
||||||
EXPONENTIAL_SA_QI.indices.forEach { i ->
|
|
||||||
qi += ln2.pow(i + 1.0) / InternalUtils.factorial(
|
|
||||||
i + 1
|
|
||||||
)
|
|
||||||
EXPONENTIAL_SA_QI[i] = qi
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,50 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
import kotlin.math.*
|
|
||||||
|
|
||||||
class BoxMullerNormalizedGaussianSampler(
|
|
||||||
private val rng: UniformRandomProvider
|
|
||||||
) :
|
|
||||||
NormalizedGaussianSampler,
|
|
||||||
SharedStateContinuousSampler {
|
|
||||||
private var nextGaussian: Double = Double.NaN
|
|
||||||
|
|
||||||
override fun sample(): Double {
|
|
||||||
val random: Double
|
|
||||||
|
|
||||||
if (nextGaussian.isNaN()) {
|
|
||||||
// Generate a pair of Gaussian numbers.
|
|
||||||
val x = rng.nextDouble()
|
|
||||||
val y = rng.nextDouble()
|
|
||||||
val alpha = 2 * PI * x
|
|
||||||
val r = sqrt(-2 * ln(y))
|
|
||||||
|
|
||||||
// Return the first element of the generated pair.
|
|
||||||
random = r * cos(alpha)
|
|
||||||
|
|
||||||
// Keep second element of the pair for next invocation.
|
|
||||||
nextGaussian = r * sin(alpha)
|
|
||||||
} else {
|
|
||||||
// Use the second element of the pair (generated at the
|
|
||||||
// previous invocation).
|
|
||||||
random = nextGaussian
|
|
||||||
|
|
||||||
// Both elements of the pair have been used.
|
|
||||||
nextGaussian = Double.NaN
|
|
||||||
}
|
|
||||||
|
|
||||||
return random
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String = "Box-Muller normalized Gaussian deviate [$rng]"
|
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
|
||||||
BoxMullerNormalizedGaussianSampler(rng)
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun <S> of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? =
|
|
||||||
BoxMullerNormalizedGaussianSampler(rng) as S
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,5 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
interface ContinuousSampler {
|
|
||||||
fun sample(): Double
|
|
||||||
}
|
|
@ -1,5 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
interface DiscreteSampler {
|
|
||||||
fun sample(): Int
|
|
||||||
}
|
|
@ -1,55 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
|
|
||||||
class GaussianSampler :
|
|
||||||
SharedStateContinuousSampler {
|
|
||||||
private val mean: Double
|
|
||||||
private val standardDeviation: Double
|
|
||||||
private val normalized: NormalizedGaussianSampler
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
normalized: NormalizedGaussianSampler,
|
|
||||||
mean: Double,
|
|
||||||
standardDeviation: Double
|
|
||||||
) {
|
|
||||||
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
|
||||||
this.normalized = normalized
|
|
||||||
this.mean = mean
|
|
||||||
this.standardDeviation = standardDeviation
|
|
||||||
}
|
|
||||||
|
|
||||||
private constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
source: GaussianSampler
|
|
||||||
) {
|
|
||||||
mean = source.mean
|
|
||||||
standardDeviation = source.standardDeviation
|
|
||||||
normalized =
|
|
||||||
InternalUtils.newNormalizedGaussianSampler(
|
|
||||||
source.normalized,
|
|
||||||
rng
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(): Double = standardDeviation * normalized.sample() + mean
|
|
||||||
|
|
||||||
override fun toString(): String = "Gaussian deviate [$normalized]"
|
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler {
|
|
||||||
return GaussianSampler(rng, this)
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
fun of(
|
|
||||||
normalized: NormalizedGaussianSampler,
|
|
||||||
mean: Double,
|
|
||||||
standardDeviation: Double
|
|
||||||
): SharedStateContinuousSampler =
|
|
||||||
GaussianSampler(
|
|
||||||
normalized,
|
|
||||||
mean,
|
|
||||||
standardDeviation
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,251 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
import kotlin.math.*
|
|
||||||
|
|
||||||
class LargeMeanPoissonSampler :
|
|
||||||
SharedStateDiscreteSampler {
|
|
||||||
private val rng: UniformRandomProvider
|
|
||||||
private val exponential: SharedStateContinuousSampler
|
|
||||||
private val gaussian: SharedStateContinuousSampler
|
|
||||||
private val factorialLog: InternalUtils.FactorialLog
|
|
||||||
private val lambda: Double
|
|
||||||
private val logLambda: Double
|
|
||||||
private val logLambdaFactorial: Double
|
|
||||||
private val delta: Double
|
|
||||||
private val halfDelta: Double
|
|
||||||
private val twolpd: Double
|
|
||||||
private val p1: Double
|
|
||||||
private val p2: Double
|
|
||||||
private val c1: Double
|
|
||||||
private val smallMeanPoissonSampler: SharedStateDiscreteSampler
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
) {
|
|
||||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
|
||||||
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
|
||||||
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
|
||||||
this.rng = rng
|
|
||||||
gaussian =
|
|
||||||
ZigguratNormalizedGaussianSampler(rng)
|
|
||||||
exponential =
|
|
||||||
AhrensDieterExponentialSampler.of(
|
|
||||||
rng,
|
|
||||||
1.0
|
|
||||||
)
|
|
||||||
// Plain constructor uses the uncached function.
|
|
||||||
factorialLog = NO_CACHE_FACTORIAL_LOG!!
|
|
||||||
// Cache values used in the algorithm
|
|
||||||
lambda = floor(mean)
|
|
||||||
logLambda = ln(lambda)
|
|
||||||
logLambdaFactorial = getFactorialLog(lambda.toInt())
|
|
||||||
delta = sqrt(lambda * ln(32 * lambda / PI + 1))
|
|
||||||
halfDelta = delta / 2
|
|
||||||
twolpd = 2 * lambda + delta
|
|
||||||
c1 = 1 / (8 * lambda)
|
|
||||||
val a1: Double = sqrt(PI * twolpd) * exp(c1)
|
|
||||||
val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd)
|
|
||||||
val aSum = a1 + a2 + 1
|
|
||||||
p1 = a1 / aSum
|
|
||||||
p2 = a2 / aSum
|
|
||||||
|
|
||||||
// The algorithm requires a Poisson sample from the remaining lambda fraction.
|
|
||||||
val lambdaFractional = mean - lambda
|
|
||||||
smallMeanPoissonSampler =
|
|
||||||
if (lambdaFractional < Double.MIN_VALUE) NO_SMALL_MEAN_POISSON_SAMPLER else // Not used.
|
|
||||||
KempSmallMeanPoissonSampler.of(
|
|
||||||
rng,
|
|
||||||
lambdaFractional
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
state: LargeMeanPoissonSamplerState,
|
|
||||||
lambdaFractional: Double
|
|
||||||
) {
|
|
||||||
require(!(lambdaFractional < 0 || lambdaFractional >= 1)) { "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): $lambdaFractional" }
|
|
||||||
this.rng = rng
|
|
||||||
gaussian =
|
|
||||||
ZigguratNormalizedGaussianSampler(rng)
|
|
||||||
exponential =
|
|
||||||
AhrensDieterExponentialSampler.of(
|
|
||||||
rng,
|
|
||||||
1.0
|
|
||||||
)
|
|
||||||
// Plain constructor uses the uncached function.
|
|
||||||
factorialLog = NO_CACHE_FACTORIAL_LOG!!
|
|
||||||
// Use the state to initialise the algorithm
|
|
||||||
lambda = state.lambdaRaw
|
|
||||||
logLambda = state.logLambda
|
|
||||||
logLambdaFactorial = state.logLambdaFactorial
|
|
||||||
delta = state.delta
|
|
||||||
halfDelta = state.halfDelta
|
|
||||||
twolpd = state.twolpd
|
|
||||||
p1 = state.p1
|
|
||||||
p2 = state.p2
|
|
||||||
c1 = state.c1
|
|
||||||
|
|
||||||
// The algorithm requires a Poisson sample from the remaining lambda fraction.
|
|
||||||
smallMeanPoissonSampler =
|
|
||||||
if (lambdaFractional < Double.MIN_VALUE)
|
|
||||||
NO_SMALL_MEAN_POISSON_SAMPLER
|
|
||||||
else // Not used.
|
|
||||||
KempSmallMeanPoissonSampler.of(
|
|
||||||
rng,
|
|
||||||
lambdaFractional
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param rng Generator of uniformly distributed random numbers.
|
|
||||||
* @param source Source to copy.
|
|
||||||
*/
|
|
||||||
private constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
source: LargeMeanPoissonSampler
|
|
||||||
) {
|
|
||||||
this.rng = rng
|
|
||||||
gaussian = source.gaussian.withUniformRandomProvider(rng)!!
|
|
||||||
exponential = source.exponential.withUniformRandomProvider(rng)!!
|
|
||||||
// Reuse the cache
|
|
||||||
factorialLog = source.factorialLog
|
|
||||||
lambda = source.lambda
|
|
||||||
logLambda = source.logLambda
|
|
||||||
logLambdaFactorial = source.logLambdaFactorial
|
|
||||||
delta = source.delta
|
|
||||||
halfDelta = source.halfDelta
|
|
||||||
twolpd = source.twolpd
|
|
||||||
p1 = source.p1
|
|
||||||
p2 = source.p2
|
|
||||||
c1 = source.c1
|
|
||||||
|
|
||||||
// Share the state of the small sampler
|
|
||||||
smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng)!!
|
|
||||||
}
|
|
||||||
|
|
||||||
/** {@inheritDoc} */
|
|
||||||
override fun sample(): Int {
|
|
||||||
// This will never be null. It may be a no-op delegate that returns zero.
|
|
||||||
val y2: Int = smallMeanPoissonSampler.sample()
|
|
||||||
var x: Double
|
|
||||||
var y: Double
|
|
||||||
var v: Double
|
|
||||||
var a: Int
|
|
||||||
var t: Double
|
|
||||||
var qr: Double
|
|
||||||
var qa: Double
|
|
||||||
while (true) {
|
|
||||||
// Step 1:
|
|
||||||
val u = rng.nextDouble()
|
|
||||||
|
|
||||||
if (u <= p1) {
|
|
||||||
// Step 2:
|
|
||||||
val n = gaussian.sample()
|
|
||||||
x = n * sqrt(lambda + halfDelta) - 0.5
|
|
||||||
if (x > delta || x < -lambda) continue
|
|
||||||
y = if (x < 0) floor(x) else ceil(x)
|
|
||||||
val e = exponential.sample()
|
|
||||||
v = -e - 0.5 * n * n + c1
|
|
||||||
} else {
|
|
||||||
// Step 3:
|
|
||||||
if (u > p1 + p2) {
|
|
||||||
y = lambda
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
x = delta + twolpd / delta * exponential.sample()
|
|
||||||
y = ceil(x)
|
|
||||||
v = -exponential.sample() - delta * (x + 1) / twolpd
|
|
||||||
}
|
|
||||||
// The Squeeze Principle
|
|
||||||
// Step 4.1:
|
|
||||||
a = if (x < 0) 1 else 0
|
|
||||||
t = y * (y + 1) / (2 * lambda)
|
|
||||||
|
|
||||||
// Step 4.2
|
|
||||||
if (v < -t && a == 0) {
|
|
||||||
y += lambda
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 4.3:
|
|
||||||
qr = t * ((2 * y + 1) / (6 * lambda) - 1)
|
|
||||||
qa = qr - t * t / (3 * (lambda + a * (y + 1)))
|
|
||||||
|
|
||||||
// Step 4.4:
|
|
||||||
if (v < qa) {
|
|
||||||
y += lambda
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 4.5:
|
|
||||||
if (v > qr) continue
|
|
||||||
|
|
||||||
// Step 4.6:
|
|
||||||
if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) {
|
|
||||||
y += lambda
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private fun getFactorialLog(n: Int): Double = factorialLog.value(n)
|
|
||||||
|
|
||||||
override fun toString(): String = "Large Mean Poisson deviate [$rng]"
|
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
|
||||||
LargeMeanPoissonSampler(rng, this)
|
|
||||||
|
|
||||||
val state: LargeMeanPoissonSamplerState
|
|
||||||
get() = LargeMeanPoissonSamplerState(
|
|
||||||
lambda, logLambda, logLambdaFactorial,
|
|
||||||
delta, halfDelta, twolpd, p1, p2, c1
|
|
||||||
)
|
|
||||||
|
|
||||||
class LargeMeanPoissonSamplerState(
|
|
||||||
val lambdaRaw: Double,
|
|
||||||
val logLambda: Double,
|
|
||||||
val logLambdaFactorial: Double,
|
|
||||||
val delta: Double,
|
|
||||||
val halfDelta: Double,
|
|
||||||
val twolpd: Double,
|
|
||||||
val p1: Double,
|
|
||||||
val p2: Double,
|
|
||||||
val c1: Double
|
|
||||||
) {
|
|
||||||
fun getLambda(): Int = lambdaRaw.toInt()
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
private const val MAX_MEAN = 0.5 * Int.MAX_VALUE
|
|
||||||
private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null
|
|
||||||
|
|
||||||
private val NO_SMALL_MEAN_POISSON_SAMPLER: SharedStateDiscreteSampler =
|
|
||||||
object : SharedStateDiscreteSampler {
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
|
||||||
// No requirement for RNG
|
|
||||||
this
|
|
||||||
|
|
||||||
override fun sample(): Int =// No Poisson sample
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
fun of(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
): SharedStateDiscreteSampler =
|
|
||||||
LargeMeanPoissonSampler(rng, mean)
|
|
||||||
|
|
||||||
init {
|
|
||||||
// Create without a cache.
|
|
||||||
NO_CACHE_FACTORIAL_LOG =
|
|
||||||
InternalUtils.FactorialLog.create()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,4 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
interface NormalizedGaussianSampler :
|
|
||||||
ContinuousSampler
|
|
@ -1,40 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
|
|
||||||
class PoissonSampler(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
) : SamplerBase(null),
|
|
||||||
SharedStateDiscreteSampler {
|
|
||||||
private val poissonSamplerDelegate: SharedStateDiscreteSampler
|
|
||||||
|
|
||||||
override fun sample(): Int = poissonSamplerDelegate.sample()
|
|
||||||
override fun toString(): String = poissonSamplerDelegate.toString()
|
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler? =
|
|
||||||
// Direct return of the optimised sampler
|
|
||||||
poissonSamplerDelegate.withUniformRandomProvider(rng)
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
const val PIVOT = 40.0
|
|
||||||
|
|
||||||
fun of(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
): SharedStateDiscreteSampler =// Each sampler should check the input arguments.
|
|
||||||
if (mean < PIVOT) SmallMeanPoissonSampler.of(
|
|
||||||
rng,
|
|
||||||
mean
|
|
||||||
) else LargeMeanPoissonSampler.of(
|
|
||||||
rng,
|
|
||||||
mean
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
init {
|
|
||||||
// Delegate all work to specialised samplers.
|
|
||||||
poissonSamplerDelegate =
|
|
||||||
of(rng, mean)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
|
|
||||||
@Deprecated("Since version 1.1. Class intended for internal use only.")
|
|
||||||
open class SamplerBase protected constructor(private val rng: UniformRandomProvider?) {
|
|
||||||
protected fun nextDouble(): Double = rng!!.nextDouble()
|
|
||||||
protected fun nextInt(): Int = rng!!.nextInt()
|
|
||||||
protected fun nextInt(max: Int): Int = rng!!.nextInt(max)
|
|
||||||
protected fun nextLong(): Long = rng!!.nextLong()
|
|
||||||
override fun toString(): String = "rng=$rng"
|
|
||||||
}
|
|
@ -1,7 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.sampling.SharedStateSampler
|
|
||||||
|
|
||||||
interface SharedStateContinuousSampler : ContinuousSampler,
|
|
||||||
SharedStateSampler<SharedStateContinuousSampler?> {
|
|
||||||
}
|
|
@ -1,7 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.sampling.SharedStateSampler
|
|
||||||
|
|
||||||
interface SharedStateDiscreteSampler : DiscreteSampler,
|
|
||||||
SharedStateSampler<SharedStateDiscreteSampler?> { // Composite interface
|
|
||||||
}
|
|
@ -1,60 +0,0 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
import kotlin.math.ceil
|
|
||||||
import kotlin.math.exp
|
|
||||||
|
|
||||||
class SmallMeanPoissonSampler :
|
|
||||||
SharedStateDiscreteSampler {
|
|
||||||
private val p0: Double
|
|
||||||
private val limit: Int
|
|
||||||
private val rng: UniformRandomProvider
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
) {
|
|
||||||
this.rng = rng
|
|
||||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
|
||||||
p0 = exp(-mean)
|
|
||||||
|
|
||||||
limit = (if (p0 > 0) ceil(1000 * mean) else throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
|
||||||
// This excludes NaN values for the mean
|
|
||||||
// else
|
|
||||||
// The returned sample is bounded by 1000 * mean
|
|
||||||
}
|
|
||||||
|
|
||||||
private constructor(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
source: SmallMeanPoissonSampler
|
|
||||||
) {
|
|
||||||
this.rng = rng
|
|
||||||
p0 = source.p0
|
|
||||||
limit = source.limit
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(): Int {
|
|
||||||
var n = 0
|
|
||||||
var r = 1.0
|
|
||||||
|
|
||||||
while (n < limit) {
|
|
||||||
r *= rng.nextDouble()
|
|
||||||
if (r >= p0) n++ else break
|
|
||||||
}
|
|
||||||
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String = "Small Mean Poisson deviate [$rng]"
|
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
|
||||||
SmallMeanPoissonSampler(rng, this)
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
fun of(
|
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
): SharedStateDiscreteSampler =
|
|
||||||
SmallMeanPoissonSampler(rng, mean)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,5 +1,6 @@
|
|||||||
package scientifik.kmath.prob
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.flow.first
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.collect
|
import scientifik.kmath.chains.collect
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
@ -69,6 +70,8 @@ fun <T : Any> Sampler<T>.sampleBuffer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator) = sample(generator).first()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a bunch of samples from real distributions
|
* Generate a bunch of samples from real distributions
|
||||||
*/
|
*/
|
||||||
|
@ -12,7 +12,6 @@ interface NamedDistribution<T> : Distribution<Map<String, T>>
|
|||||||
* A multivariate distribution that has independent distributions for separate axis
|
* A multivariate distribution that has independent distributions for separate axis
|
||||||
*/
|
*/
|
||||||
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
|
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
|
||||||
|
|
||||||
override fun probability(arg: Map<String, T>): Double {
|
override fun probability(arg: Map<String, T>): Double {
|
||||||
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
||||||
}
|
}
|
||||||
|
@ -7,13 +7,11 @@ import kotlin.random.Random
|
|||||||
*/
|
*/
|
||||||
interface RandomGenerator {
|
interface RandomGenerator {
|
||||||
fun nextBoolean(): Boolean
|
fun nextBoolean(): Boolean
|
||||||
|
|
||||||
fun nextDouble(): Double
|
fun nextDouble(): Double
|
||||||
fun nextInt(): Int
|
fun nextInt(): Int
|
||||||
fun nextInt(until: Int): Int
|
fun nextInt(until: Int): Int
|
||||||
fun nextLong(): Long
|
fun nextLong(): Long
|
||||||
fun nextLong(until: Long): Long
|
fun nextLong(until: Long): Long
|
||||||
|
|
||||||
fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
|
fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
|
||||||
fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
|
fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
|
||||||
|
|
||||||
@ -28,21 +26,16 @@ interface RandomGenerator {
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
val default by lazy { DefaultGenerator() }
|
val default by lazy { DefaultGenerator() }
|
||||||
|
|
||||||
fun default(seed: Long) = DefaultGenerator(Random(seed))
|
fun default(seed: Long) = DefaultGenerator(Random(seed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline class DefaultGenerator(val random: Random = Random) : RandomGenerator {
|
inline class DefaultGenerator(private val random: Random = Random) : RandomGenerator {
|
||||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||||
|
|
||||||
override fun nextDouble(): Double = random.nextDouble()
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
override fun nextInt(): Int = random.nextInt()
|
override fun nextInt(): Int = random.nextInt()
|
||||||
override fun nextInt(until: Int): Int = random.nextInt(until)
|
override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||||
|
|
||||||
override fun nextLong(): Long = random.nextLong()
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
|
||||||
override fun nextLong(until: Long): Long = random.nextLong(until)
|
override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||||
|
|
||||||
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
||||||
@ -50,6 +43,5 @@ inline class DefaultGenerator(val random: Random = Random) : RandomGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
||||||
|
|
||||||
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
||||||
}
|
}
|
@ -0,0 +1,68 @@
|
|||||||
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
import scientifik.kmath.prob.chain
|
||||||
|
import kotlin.math.ln
|
||||||
|
import kotlin.math.pow
|
||||||
|
|
||||||
|
class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
|
||||||
|
init {
|
||||||
|
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
|
// Step 1:
|
||||||
|
var a = 0.0
|
||||||
|
var u = nextDouble()
|
||||||
|
|
||||||
|
// Step 2 and 3:
|
||||||
|
while (u < 0.5) {
|
||||||
|
a += EXPONENTIAL_SA_QI[0]
|
||||||
|
u *= 2.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4 (now u >= 0.5):
|
||||||
|
u += u - 1
|
||||||
|
// Step 5:
|
||||||
|
if (u <= EXPONENTIAL_SA_QI[0]) return@chain mean * (a + u)
|
||||||
|
// Step 6:
|
||||||
|
var i = 0 // Should be 1, be we iterate before it in while using 0.
|
||||||
|
var u2 = nextDouble()
|
||||||
|
var umin = u2
|
||||||
|
|
||||||
|
// Step 7 and 8:
|
||||||
|
do {
|
||||||
|
++i
|
||||||
|
u2 = nextDouble()
|
||||||
|
if (u2 < umin) umin = u2
|
||||||
|
// Step 8:
|
||||||
|
} while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1.
|
||||||
|
|
||||||
|
mean * (a + umin * EXPONENTIAL_SA_QI[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private val EXPONENTIAL_SA_QI = DoubleArray(16)
|
||||||
|
|
||||||
|
fun of(mean: Double): Sampler<Double> =
|
||||||
|
AhrensDieterExponentialSampler(mean)
|
||||||
|
|
||||||
|
init {
|
||||||
|
/**
|
||||||
|
* Filling EXPONENTIAL_SA_QI table.
|
||||||
|
* Note that we don't want qi = 0 in the table.
|
||||||
|
*/
|
||||||
|
val ln2 = ln(2.0)
|
||||||
|
var qi = 0.0
|
||||||
|
|
||||||
|
EXPONENTIAL_SA_QI.indices.forEach { i ->
|
||||||
|
qi += ln2.pow(i + 1.0) / InternalUtils.factorial(i + 1)
|
||||||
|
EXPONENTIAL_SA_QI[i] = qi
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
import scientifik.kmath.prob.chain
|
||||||
|
import kotlin.math.*
|
||||||
|
|
||||||
|
class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
|
||||||
|
private var nextGaussian: Double = Double.NaN
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
|
val random: Double
|
||||||
|
|
||||||
|
if (nextGaussian.isNaN()) {
|
||||||
|
// Generate a pair of Gaussian numbers.
|
||||||
|
val x = nextDouble()
|
||||||
|
val y = nextDouble()
|
||||||
|
val alpha = 2 * PI * x
|
||||||
|
val r = sqrt(-2 * ln(y))
|
||||||
|
// Return the first element of the generated pair.
|
||||||
|
random = r * cos(alpha)
|
||||||
|
// Keep second element of the pair for next invocation.
|
||||||
|
nextGaussian = r * sin(alpha)
|
||||||
|
} else {
|
||||||
|
// Use the second element of the pair (generated at the
|
||||||
|
// previous invocation).
|
||||||
|
random = nextGaussian
|
||||||
|
// Both elements of the pair have been used.
|
||||||
|
nextGaussian = Double.NaN
|
||||||
|
}
|
||||||
|
|
||||||
|
random
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.map
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
|
||||||
|
class GaussianSampler(
|
||||||
|
private val mean: Double,
|
||||||
|
private val standardDeviation: Double,
|
||||||
|
private val normalized: NormalizedGaussianSampler
|
||||||
|
) : Sampler<Double> {
|
||||||
|
|
||||||
|
init {
|
||||||
|
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Double> =
|
||||||
|
normalized.sample(generator).map { standardDeviation * it + mean }
|
||||||
|
|
||||||
|
override fun toString(): String = "Gaussian deviate [$normalized]"
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun of(
|
||||||
|
normalized: NormalizedGaussianSampler,
|
||||||
|
mean: Double,
|
||||||
|
standardDeviation: Double
|
||||||
|
) = GaussianSampler(
|
||||||
|
mean,
|
||||||
|
standardDeviation,
|
||||||
|
normalized
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
@ -1,10 +1,10 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
|
|
||||||
internal object InternalGamma {
|
internal object InternalGamma {
|
||||||
const val LANCZOS_G = 607.0 / 128.0
|
private const val LANCZOS_G = 607.0 / 128.0
|
||||||
|
|
||||||
private val LANCZOS_COEFFICIENTS = doubleArrayOf(
|
private val LANCZOS_COEFFICIENTS = doubleArrayOf(
|
||||||
0.99999999999999709182,
|
0.99999999999999709182,
|
@ -1,7 +1,5 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
import scientifik.kmath.commons.rng.sampling.SharedStateSampler
|
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
|
||||||
@ -20,8 +18,8 @@ internal object InternalUtils {
|
|||||||
|
|
||||||
fun factorial(n: Int): Long = FACTORIALS[n]
|
fun factorial(n: Int): Long = FACTORIALS[n]
|
||||||
|
|
||||||
fun validateProbabilities(probabilities: DoubleArray?): Double {
|
fun validateProbabilities(probabilities: DoubleArray): Double {
|
||||||
require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." }
|
require(probabilities.isNotEmpty()) { "Probabilities must not be empty." }
|
||||||
var sumProb = 0.0
|
var sumProb = 0.0
|
||||||
|
|
||||||
probabilities.forEach { prob ->
|
probabilities.forEach { prob ->
|
||||||
@ -33,23 +31,10 @@ internal object InternalUtils {
|
|||||||
return sumProb
|
return sumProb
|
||||||
}
|
}
|
||||||
|
|
||||||
fun validateProbability(probability: Double): Unit =
|
private fun validateProbability(probability: Double): Unit =
|
||||||
require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { "Invalid probability: $probability" }
|
require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) {
|
||||||
|
"Invalid probability: $probability"
|
||||||
fun newNormalizedGaussianSampler(
|
}
|
||||||
sampler: NormalizedGaussianSampler,
|
|
||||||
rng: UniformRandomProvider
|
|
||||||
): NormalizedGaussianSampler {
|
|
||||||
if (sampler !is SharedStateSampler<*>) throw UnsupportedOperationException("The underlying sampler cannot share state")
|
|
||||||
|
|
||||||
val newSampler: Any =
|
|
||||||
(sampler as SharedStateSampler<*>).withUniformRandomProvider(rng) as? NormalizedGaussianSampler
|
|
||||||
?: throw UnsupportedOperationException(
|
|
||||||
"The underlying sampler did not create a normalized Gaussian sampler"
|
|
||||||
)
|
|
||||||
|
|
||||||
return newSampler as NormalizedGaussianSampler
|
|
||||||
}
|
|
||||||
|
|
||||||
class FactorialLog private constructor(
|
class FactorialLog private constructor(
|
||||||
numValues: Int,
|
numValues: Int,
|
||||||
@ -68,23 +53,19 @@ internal object InternalUtils {
|
|||||||
BEGIN_LOG_FACTORIALS, endCopy)
|
BEGIN_LOG_FACTORIALS, endCopy)
|
||||||
}
|
}
|
||||||
// All values to be computed
|
// All values to be computed
|
||||||
else
|
else endCopy = BEGIN_LOG_FACTORIALS
|
||||||
endCopy =
|
|
||||||
BEGIN_LOG_FACTORIALS
|
|
||||||
|
|
||||||
// Compute remaining values.
|
// Compute remaining values.
|
||||||
(endCopy until numValues).forEach { i ->
|
(endCopy until numValues).forEach { i ->
|
||||||
if (i < FACTORIALS.size) logFactorials[i] = ln(
|
if (i < FACTORIALS.size)
|
||||||
FACTORIALS[i].toDouble()) else logFactorials[i] =
|
logFactorials[i] = ln(FACTORIALS[i].toDouble())
|
||||||
logFactorials[i - 1] + ln(i.toDouble())
|
else
|
||||||
|
logFactorials[i] = logFactorials[i - 1] + ln(i.toDouble())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun withCache(cacheSize: Int): FactorialLog =
|
fun withCache(cacheSize: Int): FactorialLog =
|
||||||
FactorialLog(
|
FactorialLog(cacheSize, logFactorials)
|
||||||
cacheSize,
|
|
||||||
logFactorials
|
|
||||||
)
|
|
||||||
|
|
||||||
fun value(n: Int): Double {
|
fun value(n: Int): Double {
|
||||||
if (n < logFactorials.size)
|
if (n < logFactorials.size)
|
||||||
@ -98,10 +79,7 @@ internal object InternalUtils {
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun create(): FactorialLog =
|
fun create(): FactorialLog =
|
||||||
FactorialLog(
|
FactorialLog(0, null)
|
||||||
0,
|
|
||||||
null
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,14 +1,16 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
import scientifik.kmath.prob.chain
|
||||||
import kotlin.math.exp
|
import kotlin.math.exp
|
||||||
|
|
||||||
class KempSmallMeanPoissonSampler private constructor(
|
class KempSmallMeanPoissonSampler private constructor(
|
||||||
private val rng: UniformRandomProvider,
|
|
||||||
private val p0: Double,
|
private val p0: Double,
|
||||||
private val mean: Double
|
private val mean: Double
|
||||||
) : SharedStateDiscreteSampler {
|
) : Sampler<Int> {
|
||||||
override fun sample(): Int {
|
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||||
// Note on the algorithm:
|
// Note on the algorithm:
|
||||||
// - X is the unknown sample deviate (the output of the algorithm)
|
// - X is the unknown sample deviate (the output of the algorithm)
|
||||||
// - x is the current value from the distribution
|
// - x is the current value from the distribution
|
||||||
@ -16,7 +18,7 @@ class KempSmallMeanPoissonSampler private constructor(
|
|||||||
// - u is effectively the cumulative probability that the sample X
|
// - u is effectively the cumulative probability that the sample X
|
||||||
// is equal or above the current value x, p(X>=x)
|
// is equal or above the current value x, p(X>=x)
|
||||||
// So if p(X>=x) > p(X=x) the sample must be above x, otherwise it is x
|
// So if p(X>=x) > p(X=x) the sample must be above x, otherwise it is x
|
||||||
var u = rng.nextDouble()
|
var u = nextDouble()
|
||||||
var x = 0
|
var x = 0
|
||||||
var p = p0
|
var p = p0
|
||||||
|
|
||||||
@ -28,31 +30,20 @@ class KempSmallMeanPoissonSampler private constructor(
|
|||||||
// The algorithm listed in Kemp (1981) does not check that the rolling probability
|
// The algorithm listed in Kemp (1981) does not check that the rolling probability
|
||||||
// is positive. This check is added to ensure no errors when the limit of the summation
|
// is positive. This check is added to ensure no errors when the limit of the summation
|
||||||
// 1 - sum(p(x)) is above 0 due to cumulative error in floating point arithmetic.
|
// 1 - sum(p(x)) is above 0 due to cumulative error in floating point arithmetic.
|
||||||
if (p == 0.0) return x
|
if (p == 0.0) return@chain x
|
||||||
}
|
}
|
||||||
|
|
||||||
return x
|
x
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Kemp Small Mean Poisson deviate [$rng]"
|
override fun toString(): String = "Kemp Small Mean Poisson deviate"
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
|
||||||
KempSmallMeanPoissonSampler(rng, p0, mean)
|
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun of(
|
fun of(mean: Double): KempSmallMeanPoissonSampler {
|
||||||
rng: UniformRandomProvider,
|
|
||||||
mean: Double
|
|
||||||
): SharedStateDiscreteSampler {
|
|
||||||
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
||||||
val p0: Double = exp(-mean)
|
val p0: Double = exp(-mean)
|
||||||
|
|
||||||
// Probability must be positive. As mean increases then p(0) decreases.
|
// Probability must be positive. As mean increases then p(0) decreases.
|
||||||
if (p0 > 0) return KempSmallMeanPoissonSampler(
|
if (p0 > 0) return KempSmallMeanPoissonSampler(p0, mean)
|
||||||
rng,
|
|
||||||
p0,
|
|
||||||
mean
|
|
||||||
)
|
|
||||||
throw IllegalArgumentException("No probability for mean: $mean")
|
throw IllegalArgumentException("No probability for mean: $mean")
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -0,0 +1,132 @@
|
|||||||
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
|
import kotlinx.coroutines.flow.first
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.ConstantChain
|
||||||
|
import scientifik.kmath.chains.SimpleChain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
import scientifik.kmath.prob.next
|
||||||
|
import kotlin.math.*
|
||||||
|
|
||||||
|
class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
|
||||||
|
private val exponential: Sampler<Double> =
|
||||||
|
AhrensDieterExponentialSampler.of(1.0)
|
||||||
|
private val gaussian: Sampler<Double> =
|
||||||
|
ZigguratNormalizedGaussianSampler()
|
||||||
|
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG!!
|
||||||
|
private val lambda: Double = floor(mean)
|
||||||
|
private val logLambda: Double = ln(lambda)
|
||||||
|
private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt())
|
||||||
|
private val delta: Double = sqrt(lambda * ln(32 * lambda / PI + 1))
|
||||||
|
private val halfDelta: Double = delta / 2
|
||||||
|
private val twolpd: Double = 2 * lambda + delta
|
||||||
|
private val c1: Double = 1 / (8 * lambda)
|
||||||
|
private val a1: Double = sqrt(PI * twolpd) * exp(c1)
|
||||||
|
private val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd)
|
||||||
|
private val aSum: Double = a1 + a2 + 1
|
||||||
|
private val p1: Double = a1 / aSum
|
||||||
|
private val p2: Double = a2 / aSum
|
||||||
|
|
||||||
|
private val smallMeanPoissonSampler: Sampler<Int> = if (mean - lambda < Double.MIN_VALUE)
|
||||||
|
NO_SMALL_MEAN_POISSON_SAMPLER
|
||||||
|
else // Not used.
|
||||||
|
KempSmallMeanPoissonSampler.of(mean - lambda)
|
||||||
|
|
||||||
|
init {
|
||||||
|
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||||
|
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
||||||
|
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain {
|
||||||
|
// This will never be null. It may be a no-op delegate that returns zero.
|
||||||
|
val y2 = smallMeanPoissonSampler.next(generator)
|
||||||
|
var x: Double
|
||||||
|
var y: Double
|
||||||
|
var v: Double
|
||||||
|
var a: Int
|
||||||
|
var t: Double
|
||||||
|
var qr: Double
|
||||||
|
var qa: Double
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
// Step 1:
|
||||||
|
val u = generator.nextDouble()
|
||||||
|
|
||||||
|
if (u <= p1) {
|
||||||
|
// Step 2:
|
||||||
|
val n = gaussian.next(generator)
|
||||||
|
x = n * sqrt(lambda + halfDelta) - 0.5
|
||||||
|
if (x > delta || x < -lambda) continue
|
||||||
|
y = if (x < 0) floor(x) else ceil(x)
|
||||||
|
val e = exponential.next(generator)
|
||||||
|
v = -e - 0.5 * n * n + c1
|
||||||
|
} else {
|
||||||
|
// Step 3:
|
||||||
|
if (u > p1 + p2) {
|
||||||
|
y = lambda
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
x = delta + twolpd / delta * exponential.next(generator)
|
||||||
|
y = ceil(x)
|
||||||
|
v = -exponential.next(generator) - delta * (x + 1) / twolpd
|
||||||
|
}
|
||||||
|
|
||||||
|
// The Squeeze Principle
|
||||||
|
// Step 4.1:
|
||||||
|
a = if (x < 0) 1 else 0
|
||||||
|
t = y * (y + 1) / (2 * lambda)
|
||||||
|
|
||||||
|
// Step 4.2
|
||||||
|
if (v < -t && a == 0) {
|
||||||
|
y += lambda
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4.3:
|
||||||
|
qr = t * ((2 * y + 1) / (6 * lambda) - 1)
|
||||||
|
qa = qr - t * t / (3 * (lambda + a * (y + 1)))
|
||||||
|
|
||||||
|
// Step 4.4:
|
||||||
|
if (v < qa) {
|
||||||
|
y += lambda
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4.5:
|
||||||
|
if (v > qr) continue
|
||||||
|
|
||||||
|
// Step 4.6:
|
||||||
|
if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) {
|
||||||
|
y += lambda
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun getFactorialLog(n: Int): Double = factorialLog.value(n)
|
||||||
|
|
||||||
|
override fun toString(): String = "Large Mean Poisson deviate"
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val MAX_MEAN = 0.5 * Int.MAX_VALUE
|
||||||
|
private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null
|
||||||
|
|
||||||
|
private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler<Int> {
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun of(mean: Double): LargeMeanPoissonSampler =
|
||||||
|
LargeMeanPoissonSampler(mean)
|
||||||
|
|
||||||
|
init {
|
||||||
|
// Create without a cache.
|
||||||
|
NO_CACHE_FACTORIAL_LOG =
|
||||||
|
InternalUtils.FactorialLog.create()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,35 +1,42 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
import scientifik.kmath.prob.chain
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
class MarsagliaNormalizedGaussianSampler(private val rng: UniformRandomProvider) :
|
class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
|
||||||
NormalizedGaussianSampler,
|
|
||||||
SharedStateContinuousSampler {
|
|
||||||
private var nextGaussian = Double.NaN
|
private var nextGaussian = Double.NaN
|
||||||
|
|
||||||
override fun sample(): Double {
|
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
if (nextGaussian.isNaN()) {
|
if (nextGaussian.isNaN()) {
|
||||||
|
val alpha: Double
|
||||||
|
var x: Double
|
||||||
|
|
||||||
// Rejection scheme for selecting a pair that lies within the unit circle.
|
// Rejection scheme for selecting a pair that lies within the unit circle.
|
||||||
while (true) {
|
while (true) {
|
||||||
// Generate a pair of numbers within [-1 , 1).
|
// Generate a pair of numbers within [-1 , 1).
|
||||||
val x = 2.0 * rng.nextDouble() - 1.0
|
x = 2.0 * generator.nextDouble() - 1.0
|
||||||
val y = 2.0 * rng.nextDouble() - 1.0
|
val y = 2.0 * generator.nextDouble() - 1.0
|
||||||
val r2 = x * x + y * y
|
val r2 = x * x + y * y
|
||||||
|
|
||||||
if (r2 < 1 && r2 > 0) {
|
if (r2 < 1 && r2 > 0) {
|
||||||
// Pair (x, y) is within unit circle.
|
// Pair (x, y) is within unit circle.
|
||||||
val alpha = sqrt(-2 * ln(r2) / r2)
|
alpha = sqrt(-2 * ln(r2) / r2)
|
||||||
|
|
||||||
// Keep second element of the pair for next invocation.
|
// Keep second element of the pair for next invocation.
|
||||||
nextGaussian = alpha * y
|
nextGaussian = alpha * y
|
||||||
|
|
||||||
// Return the first element of the generated pair.
|
// Return the first element of the generated pair.
|
||||||
return alpha * x
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pair is not within the unit circle: Generate another one.
|
// Pair is not within the unit circle: Generate another one.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the first element of the generated pair.
|
||||||
|
alpha * x
|
||||||
} else {
|
} else {
|
||||||
// Use the second element of the pair (generated at the
|
// Use the second element of the pair (generated at the
|
||||||
// previous invocation).
|
// previous invocation).
|
||||||
@ -37,18 +44,9 @@ class MarsagliaNormalizedGaussianSampler(private val rng: UniformRandomProvider)
|
|||||||
|
|
||||||
// Both elements of the pair have been used.
|
// Both elements of the pair have been used.
|
||||||
nextGaussian = Double.NaN
|
nextGaussian = Double.NaN
|
||||||
return r
|
r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate [$rng]"
|
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
|
||||||
MarsagliaNormalizedGaussianSampler(rng)
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun <S> of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? =
|
|
||||||
MarsagliaNormalizedGaussianSampler(rng) as S
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -0,0 +1,5 @@
|
|||||||
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
|
||||||
|
interface NormalizedGaussianSampler : Sampler<Double>
|
@ -0,0 +1,29 @@
|
|||||||
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
|
||||||
|
|
||||||
|
class PoissonSampler(
|
||||||
|
mean: Double
|
||||||
|
) : Sampler<Int> {
|
||||||
|
private val poissonSamplerDelegate: Sampler<Int>
|
||||||
|
|
||||||
|
init {
|
||||||
|
// Delegate all work to specialised samplers.
|
||||||
|
poissonSamplerDelegate = of(mean)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
||||||
|
override fun toString(): String = poissonSamplerDelegate.toString()
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val PIVOT = 40.0
|
||||||
|
|
||||||
|
fun of(mean: Double) =// Each sampler should check the input arguments.
|
||||||
|
if (mean < PIVOT) SmallMeanPoissonSampler.of(
|
||||||
|
mean
|
||||||
|
) else LargeMeanPoissonSampler.of(mean)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,43 @@
|
|||||||
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.SimpleChain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
import scientifik.kmath.prob.chain
|
||||||
|
import kotlin.math.ceil
|
||||||
|
import kotlin.math.exp
|
||||||
|
|
||||||
|
class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
|
||||||
|
private val p0: Double
|
||||||
|
private val limit: Int
|
||||||
|
|
||||||
|
init {
|
||||||
|
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||||
|
p0 = exp(-mean)
|
||||||
|
|
||||||
|
limit = (if (p0 > 0)
|
||||||
|
ceil(1000 * mean)
|
||||||
|
else
|
||||||
|
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||||
|
var n = 0
|
||||||
|
var r = 1.0
|
||||||
|
|
||||||
|
while (n < limit) {
|
||||||
|
r *= nextDouble()
|
||||||
|
if (r >= p0) n++ else break
|
||||||
|
}
|
||||||
|
|
||||||
|
n
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "Small Mean Poisson deviate"
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun of(mean: Double): SmallMeanPoissonSampler =
|
||||||
|
SmallMeanPoissonSampler(mean)
|
||||||
|
}
|
||||||
|
}
|
@ -1,38 +1,76 @@
|
|||||||
package scientifik.kmath.commons.rng.sampling.distribution
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.SimpleChain
|
||||||
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.Sampler
|
||||||
|
import scientifik.kmath.prob.chain
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider) :
|
class ZigguratNormalizedGaussianSampler() :
|
||||||
NormalizedGaussianSampler,
|
NormalizedGaussianSampler, Sampler<Double> {
|
||||||
SharedStateContinuousSampler {
|
|
||||||
|
private fun sampleOne(generator: RandomGenerator): Double {
|
||||||
|
val j = generator.nextLong()
|
||||||
|
val i = (j and LAST.toLong()).toInt()
|
||||||
|
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) }
|
||||||
|
|
||||||
|
override fun toString(): String = "Ziggurat normalized Gaussian deviate"
|
||||||
|
|
||||||
|
private fun fix(
|
||||||
|
generator: RandomGenerator,
|
||||||
|
hz: Long,
|
||||||
|
iz: Int
|
||||||
|
): Double {
|
||||||
|
var x: Double
|
||||||
|
var y: Double
|
||||||
|
x = hz * W[iz]
|
||||||
|
|
||||||
|
return if (iz == 0) {
|
||||||
|
// Base strip.
|
||||||
|
// This branch is called about 5.7624515E-4 times per sample.
|
||||||
|
do {
|
||||||
|
y = -ln(generator.nextDouble())
|
||||||
|
x = -ln(generator.nextDouble()) * ONE_OVER_R
|
||||||
|
} while (y + y < x * x)
|
||||||
|
|
||||||
|
val out = R + x
|
||||||
|
if (hz > 0) out else -out
|
||||||
|
} else {
|
||||||
|
// Wedge of other strips.
|
||||||
|
// This branch is called about 0.027323 times per sample.
|
||||||
|
// else
|
||||||
|
// Try again.
|
||||||
|
// This branch is called about 0.012362 times per sample.
|
||||||
|
if (F[iz] + generator.nextDouble() * (F[iz - 1] - F[iz]) < gauss(
|
||||||
|
x
|
||||||
|
)
|
||||||
|
) x else sampleOne(generator)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private const val R = 3.442619855899
|
private const val R: Double = 3.442619855899
|
||||||
private const val ONE_OVER_R: Double = 1 / R
|
private const val ONE_OVER_R: Double = 1 / R
|
||||||
private const val V = 9.91256303526217e-3
|
private const val V = 9.91256303526217e-3
|
||||||
private val MAX: Double = 2.0.pow(63.0)
|
private val MAX: Double = 2.0.pow(63.0)
|
||||||
private val ONE_OVER_MAX: Double = 1.0 / MAX
|
private val ONE_OVER_MAX: Double = 1.0 / MAX
|
||||||
private const val LEN = 128
|
private const val LEN: Int = 128
|
||||||
private const val LAST: Int = LEN - 1
|
private const val LAST: Int = LEN - 1
|
||||||
private val K = LongArray(LEN)
|
private val K: LongArray = LongArray(LEN)
|
||||||
private val W = DoubleArray(LEN)
|
private val W: DoubleArray = DoubleArray(LEN)
|
||||||
private val F = DoubleArray(LEN)
|
private val F: DoubleArray = DoubleArray(LEN)
|
||||||
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun <S> of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? =
|
|
||||||
ZigguratNormalizedGaussianSampler(rng) as S
|
|
||||||
|
|
||||||
init {
|
init {
|
||||||
// Filling the tables.
|
// Filling the tables.
|
||||||
var d =
|
var d = R
|
||||||
R
|
|
||||||
var t = d
|
var t = d
|
||||||
var fd =
|
var fd =
|
||||||
gauss(
|
gauss(d)
|
||||||
d
|
|
||||||
)
|
|
||||||
val q = V / fd
|
val q = V / fd
|
||||||
K[0] = (d / q * MAX).toLong()
|
K[0] = (d / q * MAX).toLong()
|
||||||
K[1] = 0
|
K[1] = 0
|
||||||
@ -54,46 +92,4 @@ class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sample(): Double {
|
|
||||||
val j = rng.nextLong()
|
|
||||||
val i = (j and LAST.toLong()).toInt()
|
|
||||||
return if (abs(j) < K[i]) j * W[i] else fix(j, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String = "Ziggurat normalized Gaussian deviate [$rng]"
|
|
||||||
|
|
||||||
private fun fix(
|
|
||||||
hz: Long,
|
|
||||||
iz: Int
|
|
||||||
): Double {
|
|
||||||
var x: Double
|
|
||||||
var y: Double
|
|
||||||
x = hz * W[iz]
|
|
||||||
|
|
||||||
return if (iz == 0) {
|
|
||||||
// Base strip.
|
|
||||||
// This branch is called about 5.7624515E-4 times per sample.
|
|
||||||
do {
|
|
||||||
y = -ln(rng.nextDouble())
|
|
||||||
x = -ln(rng.nextDouble()) * ONE_OVER_R
|
|
||||||
} while (y + y < x * x)
|
|
||||||
|
|
||||||
val out = R + x
|
|
||||||
if (hz > 0) out else -out
|
|
||||||
} else {
|
|
||||||
// Wedge of other strips.
|
|
||||||
// This branch is called about 0.027323 times per sample.
|
|
||||||
// else
|
|
||||||
// Try again.
|
|
||||||
// This branch is called about 0.012362 times per sample.
|
|
||||||
if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(
|
|
||||||
x
|
|
||||||
)
|
|
||||||
) x else sample()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
|
||||||
ZigguratNormalizedGaussianSampler(rng)
|
|
||||||
}
|
}
|
@ -1,66 +0,0 @@
|
|||||||
package scientifik.kmath.prob
|
|
||||||
|
|
||||||
import scientifik.kmath.commons.rng.sampling.distribution.ContinuousSampler
|
|
||||||
import scientifik.kmath.commons.rng.sampling.distribution.DiscreteSampler
|
|
||||||
import scientifik.kmath.commons.rng.sampling.distribution.GaussianSampler
|
|
||||||
import scientifik.kmath.commons.rng.sampling.distribution.PoissonSampler
|
|
||||||
import kotlin.math.PI
|
|
||||||
import kotlin.math.exp
|
|
||||||
import kotlin.math.pow
|
|
||||||
import kotlin.math.sqrt
|
|
||||||
|
|
||||||
fun Distribution.Companion.normal(
|
|
||||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
|
||||||
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
|
||||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
|
||||||
val provider = generator.asUniformRandomProvider()
|
|
||||||
return normalSampler(method, provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun probability(arg: Double): Double {
|
|
||||||
return exp(-arg.pow(2) / 2) / sqrt(PI * 2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.normal(
|
|
||||||
mean: Double,
|
|
||||||
sigma: Double,
|
|
||||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
|
||||||
): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() {
|
|
||||||
private val sigma2 = sigma.pow(2)
|
|
||||||
private val norm = sigma * sqrt(PI * 2)
|
|
||||||
|
|
||||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
|
||||||
val provider = generator.asUniformRandomProvider()
|
|
||||||
val normalizedSampler = normalSampler(method, provider)
|
|
||||||
return GaussianSampler(
|
|
||||||
normalizedSampler,
|
|
||||||
mean,
|
|
||||||
sigma
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun probability(arg: Double): Double {
|
|
||||||
return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.poisson(
|
|
||||||
lambda: Double
|
|
||||||
): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() {
|
|
||||||
|
|
||||||
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
|
|
||||||
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
|
||||||
}
|
|
||||||
|
|
||||||
private val computedProb: HashMap<Int, Double> = hashMapOf(0 to exp(-lambda))
|
|
||||||
|
|
||||||
override fun probability(arg: Int): Double {
|
|
||||||
require(arg >= 0) { "The argument must be >= 0" }
|
|
||||||
|
|
||||||
return if (arg > 40)
|
|
||||||
exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda)
|
|
||||||
else
|
|
||||||
computedProb.getOrPut(arg) { probability(arg - 1) * lambda / arg }
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,11 +1,10 @@
|
|||||||
package scientifik.kmath.prob
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
import org.apache.commons.rng.simple.RandomSource
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
import scientifik.kmath.commons.rng.UniformRandomProvider
|
|
||||||
|
|
||||||
class RandomSourceGenerator(val source: RandomSource, seed: Long?) :
|
class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
|
||||||
RandomGenerator {
|
RandomGenerator {
|
||||||
internal val random = seed?.let {
|
private val random = seed?.let {
|
||||||
RandomSource.create(source, seed)
|
RandomSource.create(source, seed)
|
||||||
} ?: RandomSource.create(source)
|
} ?: RandomSource.create(source)
|
||||||
|
|
||||||
@ -24,24 +23,6 @@ class RandomSourceGenerator(val source: RandomSource, seed: Long?) :
|
|||||||
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
|
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represent this [RandomGenerator] as commons-rng [UniformRandomProvider] preserving and mirroring its current state.
|
|
||||||
* Getting new value from one of those changes the state of another.
|
|
||||||
*/
|
|
||||||
fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this is RandomSourceGenerator) {
|
|
||||||
object : UniformRandomProvider {
|
|
||||||
override fun nextBytes(bytes: ByteArray) = random.nextBytes(bytes)
|
|
||||||
override fun nextBytes(bytes: ByteArray, start: Int, len: Int) = random.nextBytes(bytes, start, len)
|
|
||||||
override fun nextInt(): Int = random.nextInt()
|
|
||||||
override fun nextInt(n: Int): Int = random.nextInt(n)
|
|
||||||
override fun nextLong(): Long = random.nextLong()
|
|
||||||
override fun nextLong(n: Long): Long = random.nextLong(n)
|
|
||||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
|
||||||
override fun nextFloat(): Float = random.nextFloat()
|
|
||||||
override fun nextDouble(): Double = random.nextDouble()
|
|
||||||
}
|
|
||||||
} else RandomGeneratorProvider(this)
|
|
||||||
|
|
||||||
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
||||||
RandomSourceGenerator(source, seed)
|
RandomSourceGenerator(source, seed)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user