Implement Commons RNG-like samplers in kmath-prob module for Multiplatform #164
2
kmath-commons-rng-part/build.gradle.kts
Normal file
2
kmath-commons-rng-part/build.gradle.kts
Normal file
@ -0,0 +1,2 @@
|
||||
plugins { id("scientifik.mpp") }
|
||||
kotlin.sourceSets { commonMain.get().dependencies { api(project(":kmath-coroutines")) } }
|
@ -0,0 +1,19 @@
|
||||
package scientifik.commons.rng
|
||||
|
||||
interface UniformRandomProvider {
|
||||
fun nextBytes(bytes: ByteArray)
|
||||
|
||||
fun nextBytes(
|
||||
bytes: ByteArray,
|
||||
start: Int,
|
||||
len: Int
|
||||
)
|
||||
|
||||
fun nextInt(): Int
|
||||
fun nextInt(n: Int): Int
|
||||
fun nextLong(): Long
|
||||
fun nextLong(n: Long): Long
|
||||
fun nextBoolean(): Boolean
|
||||
fun nextFloat(): Float
|
||||
fun nextDouble(): Double
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package scientifik.commons.rng.sampling
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
|
||||
interface SharedStateSampler<R> {
|
||||
fun withUniformRandomProvider(rng: UniformRandomProvider): R
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.pow
|
||||
|
||||
class AhrensDieterExponentialSampler : SamplerBase,
|
||||
SharedStateContinuousSampler {
|
||||
private val mean: Double
|
||||
private val rng: UniformRandomProvider
|
||||
|
||||
constructor(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
) : super(null) {
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
this.rng = rng
|
||||
this.mean = mean
|
||||
}
|
||||
|
||||
private constructor(
|
||||
rng: UniformRandomProvider,
|
||||
source: AhrensDieterExponentialSampler
|
||||
) : super(null) {
|
||||
this.rng = rng
|
||||
mean = source.mean
|
||||
}
|
||||
|
||||
override fun sample(): Double {
|
||||
// Step 1:
|
||||
var a = 0.0
|
||||
var u: Double = rng.nextDouble()
|
||||
|
||||
// Step 2 and 3:
|
||||
while (u < 0.5) {
|
||||
a += EXPONENTIAL_SA_QI.get(
|
||||
0
|
||||
)
|
||||
u *= 2.0
|
||||
}
|
||||
|
||||
// Step 4 (now u >= 0.5):
|
||||
u += u - 1
|
||||
|
||||
// Step 5:
|
||||
if (u <= EXPONENTIAL_SA_QI.get(
|
||||
0
|
||||
)
|
||||
) {
|
||||
return mean * (a + u)
|
||||
}
|
||||
|
||||
// Step 6:
|
||||
var i = 0 // Should be 1, be we iterate before it in while using 0.
|
||||
var u2: Double = rng.nextDouble()
|
||||
var umin = u2
|
||||
|
||||
// Step 7 and 8:
|
||||
do {
|
||||
++i
|
||||
u2 = rng.nextDouble()
|
||||
if (u2 < umin) umin = u2
|
||||
// Step 8:
|
||||
} while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1.
|
||||
return mean * (a + umin * EXPONENTIAL_SA_QI[0])
|
||||
}
|
||||
|
||||
override fun toString(): String = "Ahrens-Dieter Exponential deviate [$rng]"
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
||||
AhrensDieterExponentialSampler(rng, this)
|
||||
|
||||
companion object {
|
||||
private val EXPONENTIAL_SA_QI = DoubleArray(16)
|
||||
|
||||
fun of(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
): SharedStateContinuousSampler = AhrensDieterExponentialSampler(rng, mean)
|
||||
|
||||
init {
|
||||
/**
|
||||
* Filling EXPONENTIAL_SA_QI table.
|
||||
* Note that we don't want qi = 0 in the table.
|
||||
*/
|
||||
val ln2 = ln(2.0)
|
||||
var qi = 0.0
|
||||
|
||||
EXPONENTIAL_SA_QI.indices.forEach { i ->
|
||||
qi += ln2.pow(i + 1.0) / InternalUtils.factorial(i + 1)
|
||||
EXPONENTIAL_SA_QI[i] = qi
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import kotlin.math.*
|
||||
|
||||
class BoxMullerNormalizedGaussianSampler(
|
||||
private val rng: UniformRandomProvider
|
||||
) :
|
||||
NormalizedGaussianSampler,
|
||||
SharedStateContinuousSampler {
|
||||
private var nextGaussian: Double = Double.NaN
|
||||
|
||||
override fun sample(): Double {
|
||||
val random: Double
|
||||
|
||||
if (nextGaussian.isNaN()) {
|
||||
// Generate a pair of Gaussian numbers.
|
||||
val x = rng.nextDouble()
|
||||
val y = rng.nextDouble()
|
||||
val alpha = 2 * PI * x
|
||||
val r = sqrt(-2 * ln(y))
|
||||
|
||||
// Return the first element of the generated pair.
|
||||
random = r * cos(alpha)
|
||||
|
||||
// Keep second element of the pair for next invocation.
|
||||
nextGaussian = r * sin(alpha)
|
||||
} else {
|
||||
// Use the second element of the pair (generated at the
|
||||
// previous invocation).
|
||||
random = nextGaussian
|
||||
|
||||
// Both elements of the pair have been used.
|
||||
nextGaussian = Double.NaN
|
||||
}
|
||||
|
||||
return random
|
||||
}
|
||||
|
||||
override fun toString(): String = "Box-Muller normalized Gaussian deviate [$rng]"
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
||||
BoxMullerNormalizedGaussianSampler(rng)
|
||||
|
||||
companion object {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun <S> of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? =
|
||||
BoxMullerNormalizedGaussianSampler(rng) as S
|
||||
}
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
interface ContinuousSampler {
|
||||
fun sample(): Double
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
interface DiscreteSampler {
|
||||
fun sample(): Int
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
|
||||
class GaussianSampler : SharedStateContinuousSampler {
|
||||
private val mean: Double
|
||||
private val standardDeviation: Double
|
||||
private val normalized: NormalizedGaussianSampler
|
||||
|
||||
constructor(
|
||||
normalized: NormalizedGaussianSampler,
|
||||
mean: Double,
|
||||
standardDeviation: Double
|
||||
) {
|
||||
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||
this.normalized = normalized
|
||||
this.mean = mean
|
||||
this.standardDeviation = standardDeviation
|
||||
}
|
||||
|
||||
private constructor(
|
||||
rng: UniformRandomProvider,
|
||||
source: GaussianSampler
|
||||
) {
|
||||
mean = source.mean
|
||||
standardDeviation = source.standardDeviation
|
||||
normalized = InternalUtils.newNormalizedGaussianSampler(source.normalized, rng)
|
||||
}
|
||||
|
||||
override fun sample(): Double = standardDeviation * normalized.sample() + mean
|
||||
|
||||
override fun toString(): String = "Gaussian deviate [$normalized]"
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler {
|
||||
return GaussianSampler(rng, this)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun of(
|
||||
normalized: NormalizedGaussianSampler,
|
||||
mean: Double,
|
||||
standardDeviation: Double
|
||||
): SharedStateContinuousSampler = GaussianSampler(normalized, mean, standardDeviation)
|
||||
}
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.ln
|
||||
|
||||
internal object InternalGamma {
|
||||
const val LANCZOS_G = 607.0 / 128.0
|
||||
|
||||
private val LANCZOS_COEFFICIENTS = doubleArrayOf(
|
||||
0.99999999999999709182,
|
||||
57.156235665862923517,
|
||||
-59.597960355475491248,
|
||||
14.136097974741747174,
|
||||
-0.49191381609762019978,
|
||||
.33994649984811888699e-4,
|
||||
.46523628927048575665e-4,
|
||||
-.98374475304879564677e-4,
|
||||
.15808870322491248884e-3,
|
||||
-.21026444172410488319e-3,
|
||||
.21743961811521264320e-3,
|
||||
-.16431810653676389022e-3,
|
||||
.84418223983852743293e-4,
|
||||
-.26190838401581408670e-4,
|
||||
.36899182659531622704e-5
|
||||
)
|
||||
|
||||
private val HALF_LOG_2_PI: Double = 0.5 * ln(2.0 * PI)
|
||||
|
||||
fun logGamma(x: Double): Double {
|
||||
// Stripped-down version of the same method defined in "Commons Math":
|
||||
// Unused "if" branches (for when x < 8) have been removed here since
|
||||
// this method is only used (by class "InternalUtils") in order to
|
||||
// compute log(n!) for x > 20.
|
||||
val sum = lanczos(x)
|
||||
val tmp = x + LANCZOS_G + 0.5
|
||||
return (x + 0.5) * ln(tmp) - tmp + HALF_LOG_2_PI + ln(sum / x)
|
||||
}
|
||||
|
||||
private fun lanczos(x: Double): Double {
|
||||
val sum = (LANCZOS_COEFFICIENTS.size - 1 downTo 1).sumByDouble { LANCZOS_COEFFICIENTS[it] / (x + it) }
|
||||
return sum + LANCZOS_COEFFICIENTS[0]
|
||||
}
|
||||
}
|
@ -0,0 +1,92 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import scientifik.commons.rng.sampling.SharedStateSampler
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.min
|
||||
|
||||
internal object InternalUtils {
|
||||
private val FACTORIALS = longArrayOf(
|
||||
1L, 1L, 2L,
|
||||
6L, 24L, 120L,
|
||||
720L, 5040L, 40320L,
|
||||
362880L, 3628800L, 39916800L,
|
||||
479001600L, 6227020800L, 87178291200L,
|
||||
1307674368000L, 20922789888000L, 355687428096000L,
|
||||
6402373705728000L, 121645100408832000L, 2432902008176640000L
|
||||
)
|
||||
|
||||
private const val BEGIN_LOG_FACTORIALS = 2
|
||||
|
||||
fun factorial(n: Int): Long = FACTORIALS[n]
|
||||
|
||||
fun validateProbabilities(probabilities: DoubleArray?): Double {
|
||||
require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." }
|
||||
var sumProb = 0.0
|
||||
|
||||
probabilities.forEach { prob ->
|
||||
validateProbability(prob)
|
||||
sumProb += prob
|
||||
}
|
||||
|
||||
require(!(sumProb.isInfinite() || sumProb <= 0)) { "Invalid sum of probabilities: $sumProb" }
|
||||
return sumProb
|
||||
}
|
||||
|
||||
fun validateProbability(probability: Double): Unit =
|
||||
require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { "Invalid probability: $probability" }
|
||||
|
||||
fun newNormalizedGaussianSampler(
|
||||
sampler: NormalizedGaussianSampler,
|
||||
rng: UniformRandomProvider
|
||||
): NormalizedGaussianSampler {
|
||||
if (sampler !is SharedStateSampler<*>) throw UnsupportedOperationException("The underlying sampler cannot share state")
|
||||
|
||||
val newSampler: Any =
|
||||
(sampler as SharedStateSampler<*>).withUniformRandomProvider(rng) as? NormalizedGaussianSampler
|
||||
?: throw UnsupportedOperationException(
|
||||
"The underlying sampler did not create a normalized Gaussian sampler"
|
||||
)
|
||||
|
||||
return newSampler as NormalizedGaussianSampler
|
||||
}
|
||||
|
||||
class FactorialLog private constructor(
|
||||
numValues: Int,
|
||||
cache: DoubleArray?
|
||||
) {
|
||||
private val logFactorials: DoubleArray = DoubleArray(numValues)
|
||||
|
||||
init {
|
||||
val endCopy: Int
|
||||
|
||||
if (cache != null && cache.size > BEGIN_LOG_FACTORIALS) {
|
||||
// Copy available values.
|
||||
endCopy = min(cache.size, numValues)
|
||||
cache.copyInto(logFactorials, BEGIN_LOG_FACTORIALS, BEGIN_LOG_FACTORIALS, endCopy)
|
||||
}
|
||||
// All values to be computed
|
||||
else
|
||||
endCopy = BEGIN_LOG_FACTORIALS
|
||||
|
||||
// Compute remaining values.
|
||||
(endCopy until numValues).forEach { i ->
|
||||
if (i < FACTORIALS.size) logFactorials[i] = ln(FACTORIALS[i].toDouble()) else logFactorials[i] =
|
||||
logFactorials[i - 1] + ln(i.toDouble())
|
||||
}
|
||||
}
|
||||
|
||||
fun withCache(cacheSize: Int): FactorialLog = FactorialLog(cacheSize, logFactorials)
|
||||
|
||||
fun value(n: Int): Double {
|
||||
if (n < logFactorials.size)
|
||||
return logFactorials[n]
|
||||
|
||||
return if (n < FACTORIALS.size) ln(FACTORIALS[n].toDouble()) else InternalGamma.logGamma(n + 1.0)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun create(): FactorialLog = FactorialLog(0, null)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import kotlin.math.exp
|
||||
|
||||
class KempSmallMeanPoissonSampler private constructor(
|
||||
private val rng: UniformRandomProvider,
|
||||
private val p0: Double,
|
||||
private val mean: Double
|
||||
) : SharedStateDiscreteSampler {
|
||||
override fun sample(): Int {
|
||||
// Note on the algorithm:
|
||||
// - X is the unknown sample deviate (the output of the algorithm)
|
||||
// - x is the current value from the distribution
|
||||
// - p is the probability of the current value x, p(X=x)
|
||||
// - u is effectively the cumulative probability that the sample X
|
||||
// is equal or above the current value x, p(X>=x)
|
||||
// So if p(X>=x) > p(X=x) the sample must be above x, otherwise it is x
|
||||
var u = rng.nextDouble()
|
||||
var x = 0
|
||||
var p = p0
|
||||
|
||||
while (u > p) {
|
||||
u -= p
|
||||
// Compute the next probability using a recurrence relation.
|
||||
// p(x+1) = p(x) * mean / (x+1)
|
||||
p *= mean / ++x
|
||||
// The algorithm listed in Kemp (1981) does not check that the rolling probability
|
||||
// is positive. This check is added to ensure no errors when the limit of the summation
|
||||
// 1 - sum(p(x)) is above 0 due to cumulative error in floating point arithmetic.
|
||||
if (p == 0.0) return x
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
override fun toString(): String = "Kemp Small Mean Poisson deviate [$rng]"
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
||||
KempSmallMeanPoissonSampler(rng, p0, mean)
|
||||
|
||||
companion object {
|
||||
fun of(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
): SharedStateDiscreteSampler {
|
||||
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
||||
val p0: Double = exp(-mean)
|
||||
|
||||
// Probability must be positive. As mean increases then p(0) decreases.
|
||||
if (p0 > 0) return KempSmallMeanPoissonSampler(rng, p0, mean)
|
||||
throw IllegalArgumentException("No probability for mean: $mean")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,233 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import kotlin.math.*
|
||||
|
||||
class LargeMeanPoissonSampler : SharedStateDiscreteSampler {
|
||||
private val rng: UniformRandomProvider
|
||||
private val exponential: SharedStateContinuousSampler
|
||||
private val gaussian: SharedStateContinuousSampler
|
||||
private val factorialLog: InternalUtils.FactorialLog
|
||||
private val lambda: Double
|
||||
private val logLambda: Double
|
||||
private val logLambdaFactorial: Double
|
||||
private val delta: Double
|
||||
private val halfDelta: Double
|
||||
private val twolpd: Double
|
||||
private val p1: Double
|
||||
private val p2: Double
|
||||
private val c1: Double
|
||||
private val smallMeanPoissonSampler: SharedStateDiscreteSampler
|
||||
|
||||
constructor(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
) {
|
||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
||||
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
||||
this.rng = rng
|
||||
gaussian = ZigguratNormalizedGaussianSampler(rng)
|
||||
exponential = AhrensDieterExponentialSampler.of(rng, 1.0)
|
||||
// Plain constructor uses the uncached function.
|
||||
factorialLog = NO_CACHE_FACTORIAL_LOG!!
|
||||
// Cache values used in the algorithm
|
||||
lambda = floor(mean)
|
||||
logLambda = ln(lambda)
|
||||
logLambdaFactorial = getFactorialLog(lambda.toInt())
|
||||
delta = sqrt(lambda * ln(32 * lambda / PI + 1))
|
||||
halfDelta = delta / 2
|
||||
twolpd = 2 * lambda + delta
|
||||
c1 = 1 / (8 * lambda)
|
||||
val a1: Double = sqrt(PI * twolpd) * exp(c1)
|
||||
val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd)
|
||||
val aSum = a1 + a2 + 1
|
||||
p1 = a1 / aSum
|
||||
p2 = a2 / aSum
|
||||
|
||||
// The algorithm requires a Poisson sample from the remaining lambda fraction.
|
||||
val lambdaFractional = mean - lambda
|
||||
smallMeanPoissonSampler =
|
||||
if (lambdaFractional < Double.MIN_VALUE) NO_SMALL_MEAN_POISSON_SAMPLER else // Not used.
|
||||
KempSmallMeanPoissonSampler.of(rng, lambdaFractional)
|
||||
}
|
||||
|
||||
internal constructor(
|
||||
rng: UniformRandomProvider,
|
||||
state: LargeMeanPoissonSamplerState,
|
||||
lambdaFractional: Double
|
||||
) {
|
||||
require(!(lambdaFractional < 0 || lambdaFractional >= 1)) { "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): $lambdaFractional" }
|
||||
this.rng = rng
|
||||
gaussian = ZigguratNormalizedGaussianSampler(rng)
|
||||
exponential = AhrensDieterExponentialSampler.of(rng, 1.0)
|
||||
// Plain constructor uses the uncached function.
|
||||
factorialLog = NO_CACHE_FACTORIAL_LOG!!
|
||||
// Use the state to initialise the algorithm
|
||||
lambda = state.lambdaRaw
|
||||
logLambda = state.logLambda
|
||||
logLambdaFactorial = state.logLambdaFactorial
|
||||
delta = state.delta
|
||||
halfDelta = state.halfDelta
|
||||
twolpd = state.twolpd
|
||||
p1 = state.p1
|
||||
p2 = state.p2
|
||||
c1 = state.c1
|
||||
|
||||
// The algorithm requires a Poisson sample from the remaining lambda fraction.
|
||||
smallMeanPoissonSampler =
|
||||
if (lambdaFractional < Double.MIN_VALUE)
|
||||
NO_SMALL_MEAN_POISSON_SAMPLER
|
||||
else // Not used.
|
||||
KempSmallMeanPoissonSampler.of(rng, lambdaFractional)
|
||||
}
|
||||
|
||||
/**
|
||||
* @param rng Generator of uniformly distributed random numbers.
|
||||
* @param source Source to copy.
|
||||
*/
|
||||
private constructor(
|
||||
rng: UniformRandomProvider,
|
||||
source: LargeMeanPoissonSampler
|
||||
) {
|
||||
this.rng = rng
|
||||
gaussian = source.gaussian.withUniformRandomProvider(rng)!!
|
||||
exponential = source.exponential.withUniformRandomProvider(rng)!!
|
||||
// Reuse the cache
|
||||
factorialLog = source.factorialLog
|
||||
lambda = source.lambda
|
||||
logLambda = source.logLambda
|
||||
logLambdaFactorial = source.logLambdaFactorial
|
||||
delta = source.delta
|
||||
halfDelta = source.halfDelta
|
||||
twolpd = source.twolpd
|
||||
p1 = source.p1
|
||||
p2 = source.p2
|
||||
c1 = source.c1
|
||||
|
||||
// Share the state of the small sampler
|
||||
smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng)!!
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
override fun sample(): Int {
|
||||
// This will never be null. It may be a no-op delegate that returns zero.
|
||||
val y2: Int = smallMeanPoissonSampler.sample()
|
||||
var x: Double
|
||||
var y: Double
|
||||
var v: Double
|
||||
var a: Int
|
||||
var t: Double
|
||||
var qr: Double
|
||||
var qa: Double
|
||||
while (true) {
|
||||
// Step 1:
|
||||
val u = rng.nextDouble()
|
||||
|
||||
if (u <= p1) {
|
||||
// Step 2:
|
||||
val n = gaussian.sample()
|
||||
x = n * sqrt(lambda + halfDelta) - 0.5
|
||||
if (x > delta || x < -lambda) continue
|
||||
y = if (x < 0) floor(x) else ceil(x)
|
||||
val e = exponential.sample()
|
||||
v = -e - 0.5 * n * n + c1
|
||||
} else {
|
||||
// Step 3:
|
||||
if (u > p1 + p2) {
|
||||
y = lambda
|
||||
break
|
||||
}
|
||||
|
||||
x = delta + twolpd / delta * exponential.sample()
|
||||
y = ceil(x)
|
||||
v = -exponential.sample() - delta * (x + 1) / twolpd
|
||||
}
|
||||
// The Squeeze Principle
|
||||
// Step 4.1:
|
||||
a = if (x < 0) 1 else 0
|
||||
t = y * (y + 1) / (2 * lambda)
|
||||
|
||||
// Step 4.2
|
||||
if (v < -t && a == 0) {
|
||||
y += lambda
|
||||
break
|
||||
}
|
||||
|
||||
// Step 4.3:
|
||||
qr = t * ((2 * y + 1) / (6 * lambda) - 1)
|
||||
qa = qr - t * t / (3 * (lambda + a * (y + 1)))
|
||||
|
||||
// Step 4.4:
|
||||
if (v < qa) {
|
||||
y += lambda
|
||||
break
|
||||
}
|
||||
|
||||
// Step 4.5:
|
||||
if (v > qr) continue
|
||||
|
||||
// Step 4.6:
|
||||
if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) {
|
||||
y += lambda
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt()
|
||||
}
|
||||
|
||||
|
||||
private fun getFactorialLog(n: Int): Double = factorialLog.value(n)
|
||||
|
||||
override fun toString(): String = "Large Mean Poisson deviate [$rng]"
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
||||
LargeMeanPoissonSampler(rng, this)
|
||||
|
||||
val state: LargeMeanPoissonSamplerState
|
||||
get() = LargeMeanPoissonSamplerState(
|
||||
lambda, logLambda, logLambdaFactorial,
|
||||
delta, halfDelta, twolpd, p1, p2, c1
|
||||
)
|
||||
|
||||
class LargeMeanPoissonSamplerState(
|
||||
val lambdaRaw: Double,
|
||||
val logLambda: Double,
|
||||
val logLambdaFactorial: Double,
|
||||
val delta: Double,
|
||||
val halfDelta: Double,
|
||||
val twolpd: Double,
|
||||
val p1: Double,
|
||||
val p2: Double,
|
||||
val c1: Double
|
||||
) {
|
||||
fun getLambda(): Int = lambdaRaw.toInt()
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val MAX_MEAN = 0.5 * Int.MAX_VALUE
|
||||
private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null
|
||||
|
||||
private val NO_SMALL_MEAN_POISSON_SAMPLER: SharedStateDiscreteSampler =
|
||||
object : SharedStateDiscreteSampler {
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
||||
// No requirement for RNG
|
||||
this
|
||||
|
||||
override fun sample(): Int =// No Poisson sample
|
||||
0
|
||||
}
|
||||
|
||||
fun of(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
): SharedStateDiscreteSampler = LargeMeanPoissonSampler(rng, mean)
|
||||
|
||||
init {
|
||||
// Create without a cache.
|
||||
NO_CACHE_FACTORIAL_LOG =
|
||||
InternalUtils.FactorialLog.create()
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.sqrt
|
||||
|
||||
class MarsagliaNormalizedGaussianSampler(private val rng: UniformRandomProvider) :
|
||||
NormalizedGaussianSampler,
|
||||
SharedStateContinuousSampler {
|
||||
private var nextGaussian = Double.NaN
|
||||
|
||||
override fun sample(): Double {
|
||||
if (nextGaussian.isNaN()) {
|
||||
// Rejection scheme for selecting a pair that lies within the unit circle.
|
||||
while (true) {
|
||||
// Generate a pair of numbers within [-1 , 1).
|
||||
val x = 2.0 * rng.nextDouble() - 1.0
|
||||
val y = 2.0 * rng.nextDouble() - 1.0
|
||||
val r2 = x * x + y * y
|
||||
if (r2 < 1 && r2 > 0) {
|
||||
// Pair (x, y) is within unit circle.
|
||||
val alpha = sqrt(-2 * ln(r2) / r2)
|
||||
|
||||
// Keep second element of the pair for next invocation.
|
||||
nextGaussian = alpha * y
|
||||
|
||||
// Return the first element of the generated pair.
|
||||
return alpha * x
|
||||
}
|
||||
|
||||
// Pair is not within the unit circle: Generate another one.
|
||||
}
|
||||
} else {
|
||||
// Use the second element of the pair (generated at the
|
||||
// previous invocation).
|
||||
val r = nextGaussian
|
||||
|
||||
// Both elements of the pair have been used.
|
||||
nextGaussian = Double.NaN
|
||||
return r
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate [$rng]"
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
||||
MarsagliaNormalizedGaussianSampler(rng)
|
||||
|
||||
companion object {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun <S> of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? =
|
||||
MarsagliaNormalizedGaussianSampler(rng) as S
|
||||
}
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
interface NormalizedGaussianSampler : ContinuousSampler
|
@ -0,0 +1,32 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
|
||||
class PoissonSampler(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
) : SamplerBase(null), SharedStateDiscreteSampler {
|
||||
private val poissonSamplerDelegate: SharedStateDiscreteSampler
|
||||
|
||||
override fun sample(): Int = poissonSamplerDelegate.sample()
|
||||
override fun toString(): String = poissonSamplerDelegate.toString()
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler? =
|
||||
// Direct return of the optimised sampler
|
||||
poissonSamplerDelegate.withUniformRandomProvider(rng)
|
||||
|
||||
companion object {
|
||||
const val PIVOT = 40.0
|
||||
|
||||
fun of(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
): SharedStateDiscreteSampler =// Each sampler should check the input arguments.
|
||||
if (mean < PIVOT) SmallMeanPoissonSampler.of(rng, mean) else LargeMeanPoissonSampler.of(rng, mean)
|
||||
}
|
||||
|
||||
init {
|
||||
// Delegate all work to specialised samplers.
|
||||
poissonSamplerDelegate = of(rng, mean)
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
|
||||
@Deprecated("Since version 1.1. Class intended for internal use only.")
|
||||
open class SamplerBase protected constructor(private val rng: UniformRandomProvider?) {
|
||||
protected fun nextDouble(): Double = rng!!.nextDouble()
|
||||
protected fun nextInt(): Int = rng!!.nextInt()
|
||||
protected fun nextInt(max: Int): Int = rng!!.nextInt(max)
|
||||
protected fun nextLong(): Long = rng!!.nextLong()
|
||||
override fun toString(): String = "rng=$rng"
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.sampling.SharedStateSampler
|
||||
|
||||
interface SharedStateContinuousSampler : ContinuousSampler,
|
||||
SharedStateSampler<SharedStateContinuousSampler?> {
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.sampling.SharedStateSampler
|
||||
|
||||
interface SharedStateDiscreteSampler : DiscreteSampler,
|
||||
SharedStateSampler<SharedStateDiscreteSampler?> { // Composite interface
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import kotlin.math.ceil
|
||||
import kotlin.math.exp
|
||||
|
||||
class SmallMeanPoissonSampler : SharedStateDiscreteSampler {
|
||||
private val p0: Double
|
||||
private val limit: Int
|
||||
private val rng: UniformRandomProvider
|
||||
|
||||
constructor(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
) {
|
||||
this.rng = rng
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
p0 = exp(-mean)
|
||||
|
||||
limit = (if (p0 > 0) ceil(1000 * mean) else throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
||||
// This excludes NaN values for the mean
|
||||
// else
|
||||
// The returned sample is bounded by 1000 * mean
|
||||
}
|
||||
|
||||
private constructor(
|
||||
rng: UniformRandomProvider,
|
||||
source: SmallMeanPoissonSampler
|
||||
) {
|
||||
this.rng = rng
|
||||
p0 = source.p0
|
||||
limit = source.limit
|
||||
}
|
||||
|
||||
override fun sample(): Int {
|
||||
var n = 0
|
||||
var r = 1.0
|
||||
|
||||
while (n < limit) {
|
||||
r *= rng.nextDouble()
|
||||
if (r >= p0) n++ else break
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
override fun toString(): String = "Small Mean Poisson deviate [$rng]"
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler =
|
||||
SmallMeanPoissonSampler(rng, this)
|
||||
|
||||
companion object {
|
||||
fun of(
|
||||
rng: UniformRandomProvider,
|
||||
mean: Double
|
||||
): SharedStateDiscreteSampler = SmallMeanPoissonSampler(rng, mean)
|
||||
}
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
package scientifik.commons.rng.sampling.distribution
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import kotlin.math.*
|
||||
|
||||
class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider) :
|
||||
NormalizedGaussianSampler,
|
||||
SharedStateContinuousSampler {
|
||||
|
||||
companion object {
|
||||
private const val R = 3.442619855899
|
||||
private const val ONE_OVER_R: Double = 1 / R
|
||||
private const val V = 9.91256303526217e-3
|
||||
private val MAX: Double = 2.0.pow(63.0)
|
||||
private val ONE_OVER_MAX: Double = 1.0 / MAX
|
||||
private const val LEN = 128
|
||||
private const val LAST: Int = LEN - 1
|
||||
private val K = LongArray(LEN)
|
||||
private val W = DoubleArray(LEN)
|
||||
private val F = DoubleArray(LEN)
|
||||
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun <S> of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? =
|
||||
ZigguratNormalizedGaussianSampler(rng) as S
|
||||
|
||||
init {
|
||||
// Filling the tables.
|
||||
var d = R
|
||||
var t = d
|
||||
var fd = gauss(d)
|
||||
val q = V / fd
|
||||
K[0] = (d / q * MAX).toLong()
|
||||
K[1] = 0
|
||||
W[0] = q * ONE_OVER_MAX
|
||||
W[LAST] = d * ONE_OVER_MAX
|
||||
F[0] = 1.0
|
||||
F[LAST] = fd
|
||||
|
||||
(LAST - 1 downTo 1).forEach { i ->
|
||||
d = sqrt(-2 * ln(V / d + fd))
|
||||
fd = gauss(d)
|
||||
K[i + 1] = (d / t * MAX).toLong()
|
||||
t = d
|
||||
F[i] = fd
|
||||
W[i] = d * ONE_OVER_MAX
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun sample(): Double {
|
||||
val j = rng.nextLong()
|
||||
val i = (j and LAST.toLong()).toInt()
|
||||
return if (abs(j) < K[i]) j * W[i] else fix(j, i)
|
||||
}
|
||||
|
||||
override fun toString(): String = "Ziggurat normalized Gaussian deviate [$rng]"
|
||||
|
||||
private fun fix(
|
||||
hz: Long,
|
||||
iz: Int
|
||||
): Double {
|
||||
var x: Double
|
||||
var y: Double
|
||||
x = hz * W[iz]
|
||||
|
||||
return if (iz == 0) {
|
||||
// Base strip.
|
||||
// This branch is called about 5.7624515E-4 times per sample.
|
||||
do {
|
||||
y = -ln(rng.nextDouble())
|
||||
x = -ln(rng.nextDouble()) * ONE_OVER_R
|
||||
} while (y + y < x * x)
|
||||
|
||||
val out = R + x
|
||||
if (hz > 0) out else -out
|
||||
} else {
|
||||
// Wedge of other strips.
|
||||
// This branch is called about 0.027323 times per sample.
|
||||
// else
|
||||
// Try again.
|
||||
// This branch is called about 0.012362 times per sample.
|
||||
if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x)) x else sample()
|
||||
}
|
||||
}
|
||||
|
||||
override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler =
|
||||
ZigguratNormalizedGaussianSampler(rng)
|
||||
}
|
@ -6,8 +6,10 @@ kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-coroutines"))
|
||||
api(project(":kmath-commons-rng-part"))
|
||||
}
|
||||
}
|
||||
|
||||
jvmMain {
|
||||
dependencies {
|
||||
api("org.apache.commons:commons-rng-sampling:1.3")
|
||||
|
@ -0,0 +1,53 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
import scientifik.commons.rng.sampling.distribution.*
|
||||
import scientifik.kmath.chains.BlockingIntChain
|
||||
import scientifik.kmath.chains.BlockingRealChain
|
||||
import scientifik.kmath.chains.Chain
|
||||
|
||||
abstract class ContinuousSamplerDistribution : Distribution<Double> {
|
||||
|
||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
|
||||
private val sampler = buildCMSampler(generator)
|
||||
|
||||
override fun nextDouble(): Double = sampler.sample()
|
||||
|
||||
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
|
||||
}
|
||||
|
||||
protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler
|
||||
|
||||
override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
|
||||
}
|
||||
|
||||
abstract class DiscreteSamplerDistribution : Distribution<Int> {
|
||||
|
||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
|
||||
private val sampler = buildSampler(generator)
|
||||
|
||||
override fun nextInt(): Int = sampler.sample()
|
||||
|
||||
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
|
||||
}
|
||||
|
||||
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
|
||||
|
||||
override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
|
||||
}
|
||||
|
||||
enum class NormalSamplerMethod {
|
||||
BoxMuller,
|
||||
Marsaglia,
|
||||
Ziggurat
|
||||
}
|
||||
|
||||
fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomProvider): NormalizedGaussianSampler =
|
||||
when (method) {
|
||||
NormalSamplerMethod.BoxMuller -> BoxMullerNormalizedGaussianSampler(
|
||||
provider
|
||||
)
|
||||
NormalSamplerMethod.Marsaglia -> MarsagliaNormalizedGaussianSampler(provider)
|
||||
NormalSamplerMethod.Ziggurat -> ZigguratNormalizedGaussianSampler(provider)
|
||||
}
|
||||
|
@ -0,0 +1,28 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
|
||||
|
||||
inline class RandomGeneratorProvider(val generator: RandomGenerator) : UniformRandomProvider {
|
||||
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||
|
||||
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||
|
||||
override fun nextBytes(bytes: ByteArray) {
|
||||
generator.fillBytes(bytes)
|
||||
}
|
||||
|
||||
override fun nextBytes(bytes: ByteArray, start: Int, len: Int) {
|
||||
generator.fillBytes(bytes, start, start + len)
|
||||
}
|
||||
|
||||
override fun nextInt(): Int = generator.nextInt()
|
||||
|
||||
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
||||
|
||||
override fun nextDouble(): Double = generator.nextDouble()
|
||||
|
||||
override fun nextLong(): Long = generator.nextLong()
|
||||
|
||||
override fun nextLong(n: Long): Long = generator.nextLong(n)
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import scientifik.commons.rng.sampling.distribution.ContinuousSampler
|
||||
import scientifik.commons.rng.sampling.distribution.DiscreteSampler
|
||||
import scientifik.commons.rng.sampling.distribution.GaussianSampler
|
||||
import scientifik.commons.rng.sampling.distribution.PoissonSampler
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.exp
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
fun Distribution.Companion.normal(
|
||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||
val provider = generator.asUniformRandomProvider()
|
||||
return normalSampler(method, provider)
|
||||
}
|
||||
|
||||
override fun probability(arg: Double): Double {
|
||||
return exp(-arg.pow(2) / 2) / sqrt(PI * 2)
|
||||
}
|
||||
}
|
||||
|
||||
fun Distribution.Companion.normal(
|
||||
mean: Double,
|
||||
sigma: Double,
|
||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||
): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() {
|
||||
private val sigma2 = sigma.pow(2)
|
||||
private val norm = sigma * sqrt(PI * 2)
|
||||
|
||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||
val provider = generator.asUniformRandomProvider()
|
||||
val normalizedSampler = normalSampler(method, provider)
|
||||
return GaussianSampler(normalizedSampler, mean, sigma)
|
||||
}
|
||||
|
||||
override fun probability(arg: Double): Double {
|
||||
return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
|
||||
}
|
||||
}
|
||||
|
||||
fun Distribution.Companion.poisson(
|
||||
lambda: Double
|
||||
): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() {
|
||||
|
||||
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
|
||||
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
||||
}
|
||||
|
||||
private val computedProb: HashMap<Int, Double> = hashMapOf(0 to exp(-lambda))
|
||||
|
||||
override fun probability(arg: Int): Double {
|
||||
require(arg >= 0) { "The argument must be >= 0" }
|
||||
|
||||
return if (arg > 40)
|
||||
exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda)
|
||||
else
|
||||
computedProb.getOrPut(arg) { probability(arg - 1) * lambda / arg }
|
||||
}
|
||||
}
|
@ -1,20 +1,18 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import org.apache.commons.rng.UniformRandomProvider
|
||||
import org.apache.commons.rng.simple.RandomSource
|
||||
import scientifik.commons.rng.UniformRandomProvider
|
||||
|
||||
class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator {
|
||||
internal val random: UniformRandomProvider = seed?.let {
|
||||
class RandomSourceGenerator(val source: RandomSource, seed: Long?) :
|
||||
RandomGenerator {
|
||||
internal val random = seed?.let {
|
||||
RandomSource.create(source, seed)
|
||||
} ?: RandomSource.create(source)
|
||||
|
||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||
|
||||
override fun nextDouble(): Double = random.nextDouble()
|
||||
|
||||
override fun nextInt(): Int = random.nextInt()
|
||||
override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||
|
||||
override fun nextLong(): Long = random.nextLong()
|
||||
override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||
|
||||
@ -26,39 +24,23 @@ class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGener
|
||||
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
|
||||
}
|
||||
|
||||
inline class RandomGeneratorProvider(val generator: RandomGenerator) : UniformRandomProvider {
|
||||
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||
|
||||
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||
|
||||
override fun nextBytes(bytes: ByteArray) {
|
||||
generator.fillBytes(bytes)
|
||||
}
|
||||
|
||||
override fun nextBytes(bytes: ByteArray, start: Int, len: Int) {
|
||||
generator.fillBytes(bytes, start, start + len)
|
||||
}
|
||||
|
||||
override fun nextInt(): Int = generator.nextInt()
|
||||
|
||||
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
||||
|
||||
override fun nextDouble(): Double = generator.nextDouble()
|
||||
|
||||
override fun nextLong(): Long = generator.nextLong()
|
||||
|
||||
override fun nextLong(n: Long): Long = generator.nextLong(n)
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent this [RandomGenerator] as commons-rng [UniformRandomProvider] preserving and mirroring its current state.
|
||||
* Getting new value from one of those changes the state of another.
|
||||
*/
|
||||
fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this is RandomSourceGenerator) {
|
||||
random
|
||||
} else {
|
||||
RandomGeneratorProvider(this)
|
||||
object : UniformRandomProvider {
|
||||
override fun nextBytes(bytes: ByteArray) = random.nextBytes(bytes)
|
||||
override fun nextBytes(bytes: ByteArray, start: Int, len: Int) = random.nextBytes(bytes, start, len)
|
||||
override fun nextInt(): Int = random.nextInt()
|
||||
override fun nextInt(n: Int): Int = random.nextInt(n)
|
||||
override fun nextLong(): Long = random.nextLong()
|
||||
override fun nextLong(n: Long): Long = random.nextLong(n)
|
||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||
override fun nextFloat(): Float = random.nextFloat()
|
||||
override fun nextDouble(): Double = random.nextDouble()
|
||||
}
|
||||
} else RandomGeneratorProvider(this)
|
||||
|
||||
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
||||
RandomSourceGenerator(source, seed)
|
||||
|
@ -1,109 +0,0 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import org.apache.commons.rng.UniformRandomProvider
|
||||
import org.apache.commons.rng.sampling.distribution.*
|
||||
import scientifik.kmath.chains.BlockingIntChain
|
||||
import scientifik.kmath.chains.BlockingRealChain
|
||||
import scientifik.kmath.chains.Chain
|
||||
import java.util.*
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.exp
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
abstract class ContinuousSamplerDistribution : Distribution<Double> {
|
||||
|
||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
|
||||
private val sampler = buildCMSampler(generator)
|
||||
|
||||
override fun nextDouble(): Double = sampler.sample()
|
||||
|
||||
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
|
||||
}
|
||||
|
||||
protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler
|
||||
|
||||
override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
|
||||
}
|
||||
|
||||
abstract class DiscreteSamplerDistribution : Distribution<Int> {
|
||||
|
||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
|
||||
private val sampler = buildSampler(generator)
|
||||
|
||||
override fun nextInt(): Int = sampler.sample()
|
||||
|
||||
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
|
||||
}
|
||||
|
||||
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
|
||||
|
||||
override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
|
||||
}
|
||||
|
||||
enum class NormalSamplerMethod {
|
||||
BoxMuller,
|
||||
Marsaglia,
|
||||
Ziggurat
|
||||
}
|
||||
|
||||
private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomProvider): NormalizedGaussianSampler =
|
||||
when (method) {
|
||||
NormalSamplerMethod.BoxMuller -> BoxMullerNormalizedGaussianSampler(provider)
|
||||
NormalSamplerMethod.Marsaglia -> MarsagliaNormalizedGaussianSampler(provider)
|
||||
NormalSamplerMethod.Ziggurat -> ZigguratNormalizedGaussianSampler(provider)
|
||||
}
|
||||
|
||||
fun Distribution.Companion.normal(
|
||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||
return normalSampler(method, provider)
|
||||
}
|
||||
|
||||
override fun probability(arg: Double): Double {
|
||||
return exp(-arg.pow(2) / 2) / sqrt(PI * 2)
|
||||
}
|
||||
}
|
||||
|
||||
fun Distribution.Companion.normal(
|
||||
mean: Double,
|
||||
sigma: Double,
|
||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||
): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() {
|
||||
private val sigma2 = sigma.pow(2)
|
||||
private val norm = sigma * sqrt(PI * 2)
|
||||
|
||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||
val normalizedSampler = normalSampler(method, provider)
|
||||
return GaussianSampler(normalizedSampler, mean, sigma)
|
||||
}
|
||||
|
||||
override fun probability(arg: Double): Double {
|
||||
return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
|
||||
}
|
||||
}
|
||||
|
||||
fun Distribution.Companion.poisson(
|
||||
lambda: Double
|
||||
): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() {
|
||||
|
||||
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
|
||||
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
||||
}
|
||||
|
||||
private val computedProb: HashMap<Int, Double> = hashMapOf(0 to exp(-lambda))
|
||||
|
||||
override fun probability(arg: Int): Double {
|
||||
require(arg >= 0) { "The argument must be >= 0" }
|
||||
return if (arg > 40) {
|
||||
exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda)
|
||||
} else {
|
||||
computedProb.getOrPut(arg) {
|
||||
probability(arg - 1) * lambda / arg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -35,6 +35,7 @@ include(
|
||||
":kmath-functions",
|
||||
// ":kmath-io",
|
||||
":kmath-coroutines",
|
||||
"kmath-commons-rng-part",
|
||||
":kmath-histograms",
|
||||
":kmath-commons",
|
||||
":kmath-viktor",
|
||||
|
Loading…
Reference in New Issue
Block a user