From 28062cb09637801b7628a6b31c1a02cb1674531d Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 8 Jun 2020 17:16:57 +0700 Subject: [PATCH] Minimal refactor of existing random API, move samplers implementations to samplers package, implement Sampler by all the Samplers --- .../kotlin/scientifik/kmath/chains/Chain.kt | 65 ++--- .../commons/rng/UniformRandomProvider.kt | 19 -- .../rng/sampling/SharedStateSampler.kt | 7 - .../AhrensDieterExponentialSampler.kt | 101 ------- .../BoxMullerNormalizedGaussianSampler.kt | 50 ---- .../distribution/ContinuousSampler.kt | 5 - .../sampling/distribution/DiscreteSampler.kt | 5 - .../sampling/distribution/GaussianSampler.kt | 55 ---- .../distribution/LargenMeanPoissonSampler.kt | 251 ------------------ .../distribution/NormalizedGaussianSampler.kt | 4 - .../sampling/distribution/PoissonSampler.kt | 40 --- .../rng/sampling/distribution/SamplerBase.kt | 12 - .../SharedStateContinuousSampler.kt | 7 - .../SharedStateDiscreteSampler.kt | 7 - .../distribution/SmallMeanPoissonSampler.kt | 60 ----- .../scientifik/kmath/prob/Distribution.kt | 3 + .../kmath/prob/FactorizedDistribution.kt | 1 - .../scientifik/kmath/prob/RandomGenerator.kt | 10 +- .../AhrensDieterExponentialSampler.kt | 68 +++++ .../BoxMullerNormalizedGaussianSampler.kt | 37 +++ .../kmath/prob/samplers/GaussianSampler.kt | 34 +++ .../samplers}/InternalGamma.kt | 4 +- .../samplers}/InternalUtils.kt | 50 +--- .../samplers}/KempSmallMeanPoissonSampler.kt | 35 +-- .../prob/samplers/LargeMeanPoissonSampler.kt | 132 +++++++++ .../MarsagliaNormalizedGaussianSampler.kt | 42 ++- .../samplers/NormalizedGaussianSampler.kt | 5 + .../kmath/prob/samplers/PoissonSampler.kt | 29 ++ .../prob/samplers/SmallMeanPoissonSampler.kt | 43 +++ .../ZigguratNormalizedGaussianSampler.kt | 118 ++++---- .../scientifik/kmath/prob/DistributionsJVM.kt | 66 ----- .../kmath/prob/RandomSourceGenerator.kt | 23 +- 32 files changed, 485 insertions(+), 903 deletions(-) delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/UniformRandomProvider.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/SharedStateSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/ContinuousSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/DiscreteSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/GaussianSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/LargenMeanPoissonSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/NormalizedGaussianSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/PoissonSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SamplerBase.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateContinuousSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateDiscreteSampler.kt delete mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SmallMeanPoissonSampler.kt create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/GaussianSampler.kt rename kmath-prob/src/commonMain/kotlin/scientifik/kmath/{commons/rng/sampling/distribution => prob/samplers}/InternalGamma.kt (93%) rename kmath-prob/src/commonMain/kotlin/scientifik/kmath/{commons/rng/sampling/distribution => prob/samplers}/InternalUtils.kt (55%) rename kmath-prob/src/commonMain/kotlin/scientifik/kmath/{commons/rng/sampling/distribution => prob/samplers}/KempSmallMeanPoissonSampler.kt (66%) create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt rename kmath-prob/src/commonMain/kotlin/scientifik/kmath/{commons/rng/sampling/distribution => prob/samplers}/MarsagliaNormalizedGaussianSampler.kt (53%) create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/NormalizedGaussianSampler.kt create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/PoissonSampler.kt create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt rename kmath-prob/src/commonMain/kotlin/scientifik/kmath/{commons/rng/sampling/distribution => prob/samplers}/ZigguratNormalizedGaussianSampler.kt (58%) delete mode 100644 kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/DistributionsJVM.kt diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 5635499e5..6457c74fd 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -19,6 +19,7 @@ package scientifik.kmath.chains import kotlinx.coroutines.InternalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -27,7 +28,7 @@ import kotlinx.coroutines.sync.withLock * A not-necessary-Markov chain of some type * @param R - the chain element type */ -interface Chain: Flow { +interface Chain : Flow { /** * Generate next value, changing state if needed */ @@ -39,13 +40,11 @@ interface Chain: Flow { fun fork(): Chain @OptIn(InternalCoroutinesApi::class) - override suspend fun collect(collector: FlowCollector) { - kotlinx.coroutines.flow.flow { - while (true){ - emit(next()) - } - }.collect(collector) - } + override suspend fun collect(collector: FlowCollector): Unit = flow { + while (true) + emit(next()) + + }.collect(collector) companion object } @@ -66,24 +65,18 @@ class SimpleChain(private val gen: suspend () -> R) : Chain { * A stateless Markov chain */ class MarkovChain(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain { - - private val mutex = Mutex() - + private val mutex: Mutex = Mutex() private var value: R? = null fun value() = value - override suspend fun next(): R { - mutex.withLock { - val newValue = gen(value ?: seed()) - value = newValue - return newValue - } + override suspend fun next(): R = mutex.withLock { + val newValue = gen(value ?: seed()) + value = newValue + return newValue } - override fun fork(): Chain { - return MarkovChain(seed = { value ?: seed() }, gen = gen) - } + override fun fork(): Chain = MarkovChain(seed = { value ?: seed() }, gen = gen) } /** @@ -97,24 +90,18 @@ class StatefulChain( private val forkState: ((S) -> S), private val gen: suspend S.(R) -> R ) : Chain { - private val mutex = Mutex() - private var value: R? = null fun value() = value - override suspend fun next(): R { - mutex.withLock { - val newValue = state.gen(value ?: state.seed()) - value = newValue - return newValue - } + override suspend fun next(): R = mutex.withLock { + val newValue = state.gen(value ?: state.seed()) + value = newValue + return newValue } - override fun fork(): Chain { - return StatefulChain(forkState(state), seed, forkState, gen) - } + override fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) } /** @@ -123,9 +110,7 @@ class StatefulChain( class ConstantChain(val value: T) : Chain { override suspend fun next(): T = value - override fun fork(): Chain { - return this - } + override fun fork(): Chain = this } /** @@ -143,9 +128,8 @@ fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain { override suspend fun next(): T { var next: T - do { - next = this@filter.next() - } while (!block(next)) + do next = this@filter.next() + while (!block(next)) return next } @@ -163,7 +147,9 @@ fun Chain.collect(mapper: suspend (Chain) -> R): Chain = object fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain) -> R): Chain = object : Chain { override suspend fun next(): R = state.mapper(this@collectWithState) - override fun fork(): Chain = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) + + override fun fork(): Chain = + this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) } /** @@ -171,6 +157,5 @@ fun Chain.collectWithState(state: S, stateFork: (S) -> S, mapper: s */ fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = object : Chain { override suspend fun next(): R = block(this@zip.next(), other.next()) - override fun fork(): Chain = this@zip.fork().zip(other.fork(), block) -} \ No newline at end of file +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/UniformRandomProvider.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/UniformRandomProvider.kt deleted file mode 100644 index bd6d30124..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/UniformRandomProvider.kt +++ /dev/null @@ -1,19 +0,0 @@ -package scientifik.kmath.commons.rng - -interface UniformRandomProvider { - fun nextBytes(bytes: ByteArray) - - fun nextBytes( - bytes: ByteArray, - start: Int, - len: Int - ) - - fun nextInt(): Int - fun nextInt(n: Int): Int - fun nextLong(): Long - fun nextLong(n: Long): Long - fun nextBoolean(): Boolean - fun nextFloat(): Float - fun nextDouble(): Double -} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/SharedStateSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/SharedStateSampler.kt deleted file mode 100644 index 30ccefb54..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/SharedStateSampler.kt +++ /dev/null @@ -1,7 +0,0 @@ -package scientifik.kmath.commons.rng.sampling - -import scientifik.kmath.commons.rng.UniformRandomProvider - -interface SharedStateSampler { - fun withUniformRandomProvider(rng: UniformRandomProvider): R -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.kt deleted file mode 100644 index 714cf2164..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.kt +++ /dev/null @@ -1,101 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.UniformRandomProvider -import kotlin.math.ln -import kotlin.math.pow - -class AhrensDieterExponentialSampler : SamplerBase, - SharedStateContinuousSampler { - private val mean: Double - private val rng: UniformRandomProvider - - constructor( - rng: UniformRandomProvider, - mean: Double - ) : super(null) { - require(mean > 0) { "mean is not strictly positive: $mean" } - this.rng = rng - this.mean = mean - } - - private constructor( - rng: UniformRandomProvider, - source: AhrensDieterExponentialSampler - ) : super(null) { - this.rng = rng - mean = source.mean - } - - override fun sample(): Double { - // Step 1: - var a = 0.0 - var u: Double = rng.nextDouble() - - // Step 2 and 3: - while (u < 0.5) { - a += EXPONENTIAL_SA_QI.get( - 0 - ) - u *= 2.0 - } - - // Step 4 (now u >= 0.5): - u += u - 1 - - // Step 5: - if (u <= EXPONENTIAL_SA_QI.get( - 0 - ) - ) { - return mean * (a + u) - } - - // Step 6: - var i = 0 // Should be 1, be we iterate before it in while using 0. - var u2: Double = rng.nextDouble() - var umin = u2 - - // Step 7 and 8: - do { - ++i - u2 = rng.nextDouble() - if (u2 < umin) umin = u2 - // Step 8: - } while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1. - return mean * (a + umin * EXPONENTIAL_SA_QI[0]) - } - - override fun toString(): String = "Ahrens-Dieter Exponential deviate [$rng]" - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler = - AhrensDieterExponentialSampler(rng, this) - - companion object { - private val EXPONENTIAL_SA_QI = DoubleArray(16) - - fun of( - rng: UniformRandomProvider, - mean: Double - ): SharedStateContinuousSampler = - AhrensDieterExponentialSampler( - rng, - mean - ) - - init { - /** - * Filling EXPONENTIAL_SA_QI table. - * Note that we don't want qi = 0 in the table. - */ - val ln2 = ln(2.0) - var qi = 0.0 - - EXPONENTIAL_SA_QI.indices.forEach { i -> - qi += ln2.pow(i + 1.0) / InternalUtils.factorial( - i + 1 - ) - EXPONENTIAL_SA_QI[i] = qi - } - } - } -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.kt deleted file mode 100644 index 0e927bb32..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.kt +++ /dev/null @@ -1,50 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.UniformRandomProvider -import kotlin.math.* - -class BoxMullerNormalizedGaussianSampler( - private val rng: UniformRandomProvider -) : - NormalizedGaussianSampler, - SharedStateContinuousSampler { - private var nextGaussian: Double = Double.NaN - - override fun sample(): Double { - val random: Double - - if (nextGaussian.isNaN()) { - // Generate a pair of Gaussian numbers. - val x = rng.nextDouble() - val y = rng.nextDouble() - val alpha = 2 * PI * x - val r = sqrt(-2 * ln(y)) - - // Return the first element of the generated pair. - random = r * cos(alpha) - - // Keep second element of the pair for next invocation. - nextGaussian = r * sin(alpha) - } else { - // Use the second element of the pair (generated at the - // previous invocation). - random = nextGaussian - - // Both elements of the pair have been used. - nextGaussian = Double.NaN - } - - return random - } - - override fun toString(): String = "Box-Muller normalized Gaussian deviate [$rng]" - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler = - BoxMullerNormalizedGaussianSampler(rng) - - companion object { - @Suppress("UNCHECKED_CAST") - fun of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? = - BoxMullerNormalizedGaussianSampler(rng) as S - } -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/ContinuousSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/ContinuousSampler.kt deleted file mode 100644 index aea81cf92..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/ContinuousSampler.kt +++ /dev/null @@ -1,5 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -interface ContinuousSampler { - fun sample(): Double -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/DiscreteSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/DiscreteSampler.kt deleted file mode 100644 index 4de780424..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/DiscreteSampler.kt +++ /dev/null @@ -1,5 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -interface DiscreteSampler { - fun sample(): Int -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/GaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/GaussianSampler.kt deleted file mode 100644 index 0b54a0ad1..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/GaussianSampler.kt +++ /dev/null @@ -1,55 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.UniformRandomProvider - -class GaussianSampler : - SharedStateContinuousSampler { - private val mean: Double - private val standardDeviation: Double - private val normalized: NormalizedGaussianSampler - - constructor( - normalized: NormalizedGaussianSampler, - mean: Double, - standardDeviation: Double - ) { - require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" } - this.normalized = normalized - this.mean = mean - this.standardDeviation = standardDeviation - } - - private constructor( - rng: UniformRandomProvider, - source: GaussianSampler - ) { - mean = source.mean - standardDeviation = source.standardDeviation - normalized = - InternalUtils.newNormalizedGaussianSampler( - source.normalized, - rng - ) - } - - override fun sample(): Double = standardDeviation * normalized.sample() + mean - - override fun toString(): String = "Gaussian deviate [$normalized]" - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler { - return GaussianSampler(rng, this) - } - - companion object { - fun of( - normalized: NormalizedGaussianSampler, - mean: Double, - standardDeviation: Double - ): SharedStateContinuousSampler = - GaussianSampler( - normalized, - mean, - standardDeviation - ) - } -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/LargenMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/LargenMeanPoissonSampler.kt deleted file mode 100644 index ae7d3f079..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/LargenMeanPoissonSampler.kt +++ /dev/null @@ -1,251 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.UniformRandomProvider -import kotlin.math.* - -class LargeMeanPoissonSampler : - SharedStateDiscreteSampler { - private val rng: UniformRandomProvider - private val exponential: SharedStateContinuousSampler - private val gaussian: SharedStateContinuousSampler - private val factorialLog: InternalUtils.FactorialLog - private val lambda: Double - private val logLambda: Double - private val logLambdaFactorial: Double - private val delta: Double - private val halfDelta: Double - private val twolpd: Double - private val p1: Double - private val p2: Double - private val c1: Double - private val smallMeanPoissonSampler: SharedStateDiscreteSampler - - constructor( - rng: UniformRandomProvider, - mean: Double - ) { - require(mean >= 1) { "mean is not >= 1: $mean" } - // The algorithm is not valid if Math.floor(mean) is not an integer. - require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" } - this.rng = rng - gaussian = - ZigguratNormalizedGaussianSampler(rng) - exponential = - AhrensDieterExponentialSampler.of( - rng, - 1.0 - ) - // Plain constructor uses the uncached function. - factorialLog = NO_CACHE_FACTORIAL_LOG!! - // Cache values used in the algorithm - lambda = floor(mean) - logLambda = ln(lambda) - logLambdaFactorial = getFactorialLog(lambda.toInt()) - delta = sqrt(lambda * ln(32 * lambda / PI + 1)) - halfDelta = delta / 2 - twolpd = 2 * lambda + delta - c1 = 1 / (8 * lambda) - val a1: Double = sqrt(PI * twolpd) * exp(c1) - val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd) - val aSum = a1 + a2 + 1 - p1 = a1 / aSum - p2 = a2 / aSum - - // The algorithm requires a Poisson sample from the remaining lambda fraction. - val lambdaFractional = mean - lambda - smallMeanPoissonSampler = - if (lambdaFractional < Double.MIN_VALUE) NO_SMALL_MEAN_POISSON_SAMPLER else // Not used. - KempSmallMeanPoissonSampler.of( - rng, - lambdaFractional - ) - } - - internal constructor( - rng: UniformRandomProvider, - state: LargeMeanPoissonSamplerState, - lambdaFractional: Double - ) { - require(!(lambdaFractional < 0 || lambdaFractional >= 1)) { "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): $lambdaFractional" } - this.rng = rng - gaussian = - ZigguratNormalizedGaussianSampler(rng) - exponential = - AhrensDieterExponentialSampler.of( - rng, - 1.0 - ) - // Plain constructor uses the uncached function. - factorialLog = NO_CACHE_FACTORIAL_LOG!! - // Use the state to initialise the algorithm - lambda = state.lambdaRaw - logLambda = state.logLambda - logLambdaFactorial = state.logLambdaFactorial - delta = state.delta - halfDelta = state.halfDelta - twolpd = state.twolpd - p1 = state.p1 - p2 = state.p2 - c1 = state.c1 - - // The algorithm requires a Poisson sample from the remaining lambda fraction. - smallMeanPoissonSampler = - if (lambdaFractional < Double.MIN_VALUE) - NO_SMALL_MEAN_POISSON_SAMPLER - else // Not used. - KempSmallMeanPoissonSampler.of( - rng, - lambdaFractional - ) - } - - /** - * @param rng Generator of uniformly distributed random numbers. - * @param source Source to copy. - */ - private constructor( - rng: UniformRandomProvider, - source: LargeMeanPoissonSampler - ) { - this.rng = rng - gaussian = source.gaussian.withUniformRandomProvider(rng)!! - exponential = source.exponential.withUniformRandomProvider(rng)!! - // Reuse the cache - factorialLog = source.factorialLog - lambda = source.lambda - logLambda = source.logLambda - logLambdaFactorial = source.logLambdaFactorial - delta = source.delta - halfDelta = source.halfDelta - twolpd = source.twolpd - p1 = source.p1 - p2 = source.p2 - c1 = source.c1 - - // Share the state of the small sampler - smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng)!! - } - - /** {@inheritDoc} */ - override fun sample(): Int { - // This will never be null. It may be a no-op delegate that returns zero. - val y2: Int = smallMeanPoissonSampler.sample() - var x: Double - var y: Double - var v: Double - var a: Int - var t: Double - var qr: Double - var qa: Double - while (true) { - // Step 1: - val u = rng.nextDouble() - - if (u <= p1) { - // Step 2: - val n = gaussian.sample() - x = n * sqrt(lambda + halfDelta) - 0.5 - if (x > delta || x < -lambda) continue - y = if (x < 0) floor(x) else ceil(x) - val e = exponential.sample() - v = -e - 0.5 * n * n + c1 - } else { - // Step 3: - if (u > p1 + p2) { - y = lambda - break - } - - x = delta + twolpd / delta * exponential.sample() - y = ceil(x) - v = -exponential.sample() - delta * (x + 1) / twolpd - } - // The Squeeze Principle - // Step 4.1: - a = if (x < 0) 1 else 0 - t = y * (y + 1) / (2 * lambda) - - // Step 4.2 - if (v < -t && a == 0) { - y += lambda - break - } - - // Step 4.3: - qr = t * ((2 * y + 1) / (6 * lambda) - 1) - qa = qr - t * t / (3 * (lambda + a * (y + 1))) - - // Step 4.4: - if (v < qa) { - y += lambda - break - } - - // Step 4.5: - if (v > qr) continue - - // Step 4.6: - if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) { - y += lambda - break - } - } - - return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt() - } - - - private fun getFactorialLog(n: Int): Double = factorialLog.value(n) - - override fun toString(): String = "Large Mean Poisson deviate [$rng]" - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler = - LargeMeanPoissonSampler(rng, this) - - val state: LargeMeanPoissonSamplerState - get() = LargeMeanPoissonSamplerState( - lambda, logLambda, logLambdaFactorial, - delta, halfDelta, twolpd, p1, p2, c1 - ) - - class LargeMeanPoissonSamplerState( - val lambdaRaw: Double, - val logLambda: Double, - val logLambdaFactorial: Double, - val delta: Double, - val halfDelta: Double, - val twolpd: Double, - val p1: Double, - val p2: Double, - val c1: Double - ) { - fun getLambda(): Int = lambdaRaw.toInt() - } - - companion object { - private const val MAX_MEAN = 0.5 * Int.MAX_VALUE - private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null - - private val NO_SMALL_MEAN_POISSON_SAMPLER: SharedStateDiscreteSampler = - object : SharedStateDiscreteSampler { - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler = -// No requirement for RNG - this - - override fun sample(): Int =// No Poisson sample - 0 - } - - fun of( - rng: UniformRandomProvider, - mean: Double - ): SharedStateDiscreteSampler = - LargeMeanPoissonSampler(rng, mean) - - init { - // Create without a cache. - NO_CACHE_FACTORIAL_LOG = - InternalUtils.FactorialLog.create() - } - } -} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/NormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/NormalizedGaussianSampler.kt deleted file mode 100644 index c47b9b4e9..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/NormalizedGaussianSampler.kt +++ /dev/null @@ -1,4 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -interface NormalizedGaussianSampler : - ContinuousSampler diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/PoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/PoissonSampler.kt deleted file mode 100644 index a440c3a7d..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/PoissonSampler.kt +++ /dev/null @@ -1,40 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.UniformRandomProvider - -class PoissonSampler( - rng: UniformRandomProvider, - mean: Double -) : SamplerBase(null), - SharedStateDiscreteSampler { - private val poissonSamplerDelegate: SharedStateDiscreteSampler - - override fun sample(): Int = poissonSamplerDelegate.sample() - override fun toString(): String = poissonSamplerDelegate.toString() - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler? = -// Direct return of the optimised sampler - poissonSamplerDelegate.withUniformRandomProvider(rng) - - companion object { - const val PIVOT = 40.0 - - fun of( - rng: UniformRandomProvider, - mean: Double - ): SharedStateDiscreteSampler =// Each sampler should check the input arguments. - if (mean < PIVOT) SmallMeanPoissonSampler.of( - rng, - mean - ) else LargeMeanPoissonSampler.of( - rng, - mean - ) - } - - init { - // Delegate all work to specialised samplers. - poissonSamplerDelegate = - of(rng, mean) - } -} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SamplerBase.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SamplerBase.kt deleted file mode 100644 index 4170e9db3..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SamplerBase.kt +++ /dev/null @@ -1,12 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.UniformRandomProvider - -@Deprecated("Since version 1.1. Class intended for internal use only.") -open class SamplerBase protected constructor(private val rng: UniformRandomProvider?) { - protected fun nextDouble(): Double = rng!!.nextDouble() - protected fun nextInt(): Int = rng!!.nextInt() - protected fun nextInt(max: Int): Int = rng!!.nextInt(max) - protected fun nextLong(): Long = rng!!.nextLong() - override fun toString(): String = "rng=$rng" -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateContinuousSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateContinuousSampler.kt deleted file mode 100644 index c833ce8de..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateContinuousSampler.kt +++ /dev/null @@ -1,7 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.sampling.SharedStateSampler - -interface SharedStateContinuousSampler : ContinuousSampler, - SharedStateSampler { -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateDiscreteSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateDiscreteSampler.kt deleted file mode 100644 index 7b17e4670..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SharedStateDiscreteSampler.kt +++ /dev/null @@ -1,7 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.sampling.SharedStateSampler - -interface SharedStateDiscreteSampler : DiscreteSampler, - SharedStateSampler { // Composite interface -} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SmallMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SmallMeanPoissonSampler.kt deleted file mode 100644 index a35dd277c..000000000 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/SmallMeanPoissonSampler.kt +++ /dev/null @@ -1,60 +0,0 @@ -package scientifik.kmath.commons.rng.sampling.distribution - -import scientifik.kmath.commons.rng.UniformRandomProvider -import kotlin.math.ceil -import kotlin.math.exp - -class SmallMeanPoissonSampler : - SharedStateDiscreteSampler { - private val p0: Double - private val limit: Int - private val rng: UniformRandomProvider - - constructor( - rng: UniformRandomProvider, - mean: Double - ) { - this.rng = rng - require(mean > 0) { "mean is not strictly positive: $mean" } - p0 = exp(-mean) - - limit = (if (p0 > 0) ceil(1000 * mean) else throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt() - // This excludes NaN values for the mean - // else - // The returned sample is bounded by 1000 * mean - } - - private constructor( - rng: UniformRandomProvider, - source: SmallMeanPoissonSampler - ) { - this.rng = rng - p0 = source.p0 - limit = source.limit - } - - override fun sample(): Int { - var n = 0 - var r = 1.0 - - while (n < limit) { - r *= rng.nextDouble() - if (r >= p0) n++ else break - } - - return n - } - - override fun toString(): String = "Small Mean Poisson deviate [$rng]" - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler = - SmallMeanPoissonSampler(rng, this) - - companion object { - fun of( - rng: UniformRandomProvider, - mean: Double - ): SharedStateDiscreteSampler = - SmallMeanPoissonSampler(rng, mean) - } -} \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt index 3b874adaa..a99052171 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt @@ -1,5 +1,6 @@ package scientifik.kmath.prob +import kotlinx.coroutines.flow.first import scientifik.kmath.chains.Chain import scientifik.kmath.chains.collect import scientifik.kmath.structures.Buffer @@ -69,6 +70,8 @@ fun Sampler.sampleBuffer( } } +suspend fun Sampler.next(generator: RandomGenerator) = sample(generator).first() + /** * Generate a bunch of samples from real distributions */ diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt index ea526c058..ae3f918ff 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt @@ -12,7 +12,6 @@ interface NamedDistribution : Distribution> * A multivariate distribution that has independent distributions for separate axis */ class FactorizedDistribution(val distributions: Collection>) : NamedDistribution { - override fun probability(arg: Map): Double { return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt index 2a225fe47..4b42db927 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt @@ -7,13 +7,11 @@ import kotlin.random.Random */ interface RandomGenerator { fun nextBoolean(): Boolean - fun nextDouble(): Double fun nextInt(): Int fun nextInt(until: Int): Int fun nextLong(): Long fun nextLong(until: Long): Long - fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size) fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) } @@ -28,21 +26,16 @@ interface RandomGenerator { companion object { val default by lazy { DefaultGenerator() } - fun default(seed: Long) = DefaultGenerator(Random(seed)) } } -inline class DefaultGenerator(val random: Random = Random) : RandomGenerator { +inline class DefaultGenerator(private val random: Random = Random) : RandomGenerator { override fun nextBoolean(): Boolean = random.nextBoolean() - override fun nextDouble(): Double = random.nextDouble() - override fun nextInt(): Int = random.nextInt() override fun nextInt(until: Int): Int = random.nextInt(until) - override fun nextLong(): Long = random.nextLong() - override fun nextLong(until: Long): Long = random.nextLong(until) override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { @@ -50,6 +43,5 @@ inline class DefaultGenerator(val random: Random = Random) : RandomGenerator { } override fun nextBytes(size: Int): ByteArray = random.nextBytes(size) - override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong()) } \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt new file mode 100644 index 000000000..5dec2b641 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt @@ -0,0 +1,68 @@ +package scientifik.kmath.prob.samplers + +import scientifik.kmath.chains.Chain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler +import scientifik.kmath.prob.chain +import kotlin.math.ln +import kotlin.math.pow + +class AhrensDieterExponentialSampler(val mean: Double) : Sampler { + init { + require(mean > 0) { "mean is not strictly positive: $mean" } + } + + override fun sample(generator: RandomGenerator): Chain = generator.chain { + // Step 1: + var a = 0.0 + var u = nextDouble() + + // Step 2 and 3: + while (u < 0.5) { + a += EXPONENTIAL_SA_QI[0] + u *= 2.0 + } + + // Step 4 (now u >= 0.5): + u += u - 1 + // Step 5: + if (u <= EXPONENTIAL_SA_QI[0]) return@chain mean * (a + u) + // Step 6: + var i = 0 // Should be 1, be we iterate before it in while using 0. + var u2 = nextDouble() + var umin = u2 + + // Step 7 and 8: + do { + ++i + u2 = nextDouble() + if (u2 < umin) umin = u2 + // Step 8: + } while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1. + + mean * (a + umin * EXPONENTIAL_SA_QI[0]) + } + + override fun toString(): String = "Ahrens-Dieter Exponential deviate" + + companion object { + private val EXPONENTIAL_SA_QI = DoubleArray(16) + + fun of(mean: Double): Sampler = + AhrensDieterExponentialSampler(mean) + + init { + /** + * Filling EXPONENTIAL_SA_QI table. + * Note that we don't want qi = 0 in the table. + */ + val ln2 = ln(2.0) + var qi = 0.0 + + EXPONENTIAL_SA_QI.indices.forEach { i -> + qi += ln2.pow(i + 1.0) / InternalUtils.factorial(i + 1) + EXPONENTIAL_SA_QI[i] = qi + } + } + } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt new file mode 100644 index 000000000..3235b06f5 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt @@ -0,0 +1,37 @@ +package scientifik.kmath.prob.samplers + +import scientifik.kmath.chains.Chain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler +import scientifik.kmath.prob.chain +import kotlin.math.* + +class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler { + private var nextGaussian: Double = Double.NaN + + override fun sample(generator: RandomGenerator): Chain = generator.chain { + val random: Double + + if (nextGaussian.isNaN()) { + // Generate a pair of Gaussian numbers. + val x = nextDouble() + val y = nextDouble() + val alpha = 2 * PI * x + val r = sqrt(-2 * ln(y)) + // Return the first element of the generated pair. + random = r * cos(alpha) + // Keep second element of the pair for next invocation. + nextGaussian = r * sin(alpha) + } else { + // Use the second element of the pair (generated at the + // previous invocation). + random = nextGaussian + // Both elements of the pair have been used. + nextGaussian = Double.NaN + } + + random + } + + override fun toString(): String = "Box-Muller normalized Gaussian deviate" +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/GaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/GaussianSampler.kt new file mode 100644 index 000000000..0c984ab6a --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/GaussianSampler.kt @@ -0,0 +1,34 @@ +package scientifik.kmath.prob.samplers + +import scientifik.kmath.chains.Chain +import scientifik.kmath.chains.map +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler + +class GaussianSampler( + private val mean: Double, + private val standardDeviation: Double, + private val normalized: NormalizedGaussianSampler +) : Sampler { + + init { + require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" } + } + + override fun sample(generator: RandomGenerator): Chain = + normalized.sample(generator).map { standardDeviation * it + mean } + + override fun toString(): String = "Gaussian deviate [$normalized]" + + companion object { + fun of( + normalized: NormalizedGaussianSampler, + mean: Double, + standardDeviation: Double + ) = GaussianSampler( + mean, + standardDeviation, + normalized + ) + } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/InternalGamma.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/InternalGamma.kt similarity index 93% rename from kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/InternalGamma.kt rename to kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/InternalGamma.kt index 79426d7ae..50c5f4ce0 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/InternalGamma.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/InternalGamma.kt @@ -1,10 +1,10 @@ -package scientifik.kmath.commons.rng.sampling.distribution +package scientifik.kmath.prob.samplers import kotlin.math.PI import kotlin.math.ln internal object InternalGamma { - const val LANCZOS_G = 607.0 / 128.0 + private const val LANCZOS_G = 607.0 / 128.0 private val LANCZOS_COEFFICIENTS = doubleArrayOf( 0.99999999999999709182, diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/InternalUtils.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/InternalUtils.kt similarity index 55% rename from kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/InternalUtils.kt rename to kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/InternalUtils.kt index 0f563b956..f2f1b3865 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/InternalUtils.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/InternalUtils.kt @@ -1,7 +1,5 @@ -package scientifik.kmath.commons.rng.sampling.distribution +package scientifik.kmath.prob.samplers -import scientifik.kmath.commons.rng.UniformRandomProvider -import scientifik.kmath.commons.rng.sampling.SharedStateSampler import kotlin.math.ln import kotlin.math.min @@ -20,8 +18,8 @@ internal object InternalUtils { fun factorial(n: Int): Long = FACTORIALS[n] - fun validateProbabilities(probabilities: DoubleArray?): Double { - require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." } + fun validateProbabilities(probabilities: DoubleArray): Double { + require(probabilities.isNotEmpty()) { "Probabilities must not be empty." } var sumProb = 0.0 probabilities.forEach { prob -> @@ -33,23 +31,10 @@ internal object InternalUtils { return sumProb } - fun validateProbability(probability: Double): Unit = - require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { "Invalid probability: $probability" } - - fun newNormalizedGaussianSampler( - sampler: NormalizedGaussianSampler, - rng: UniformRandomProvider - ): NormalizedGaussianSampler { - if (sampler !is SharedStateSampler<*>) throw UnsupportedOperationException("The underlying sampler cannot share state") - - val newSampler: Any = - (sampler as SharedStateSampler<*>).withUniformRandomProvider(rng) as? NormalizedGaussianSampler - ?: throw UnsupportedOperationException( - "The underlying sampler did not create a normalized Gaussian sampler" - ) - - return newSampler as NormalizedGaussianSampler - } + private fun validateProbability(probability: Double): Unit = + require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { + "Invalid probability: $probability" + } class FactorialLog private constructor( numValues: Int, @@ -68,23 +53,19 @@ internal object InternalUtils { BEGIN_LOG_FACTORIALS, endCopy) } // All values to be computed - else - endCopy = - BEGIN_LOG_FACTORIALS + else endCopy = BEGIN_LOG_FACTORIALS // Compute remaining values. (endCopy until numValues).forEach { i -> - if (i < FACTORIALS.size) logFactorials[i] = ln( - FACTORIALS[i].toDouble()) else logFactorials[i] = - logFactorials[i - 1] + ln(i.toDouble()) + if (i < FACTORIALS.size) + logFactorials[i] = ln(FACTORIALS[i].toDouble()) + else + logFactorials[i] = logFactorials[i - 1] + ln(i.toDouble()) } } fun withCache(cacheSize: Int): FactorialLog = - FactorialLog( - cacheSize, - logFactorials - ) + FactorialLog(cacheSize, logFactorials) fun value(n: Int): Double { if (n < logFactorials.size) @@ -98,10 +79,7 @@ internal object InternalUtils { companion object { fun create(): FactorialLog = - FactorialLog( - 0, - null - ) + FactorialLog(0, null) } } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/KempSmallMeanPoissonSampler.kt similarity index 66% rename from kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.kt rename to kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/KempSmallMeanPoissonSampler.kt index 3072cf020..cdd4c4cd1 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/KempSmallMeanPoissonSampler.kt @@ -1,14 +1,16 @@ -package scientifik.kmath.commons.rng.sampling.distribution +package scientifik.kmath.prob.samplers -import scientifik.kmath.commons.rng.UniformRandomProvider +import scientifik.kmath.chains.Chain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler +import scientifik.kmath.prob.chain import kotlin.math.exp class KempSmallMeanPoissonSampler private constructor( - private val rng: UniformRandomProvider, private val p0: Double, private val mean: Double -) : SharedStateDiscreteSampler { - override fun sample(): Int { +) : Sampler { + override fun sample(generator: RandomGenerator): Chain = generator.chain { // Note on the algorithm: // - X is the unknown sample deviate (the output of the algorithm) // - x is the current value from the distribution @@ -16,7 +18,7 @@ class KempSmallMeanPoissonSampler private constructor( // - u is effectively the cumulative probability that the sample X // is equal or above the current value x, p(X>=x) // So if p(X>=x) > p(X=x) the sample must be above x, otherwise it is x - var u = rng.nextDouble() + var u = nextDouble() var x = 0 var p = p0 @@ -28,31 +30,20 @@ class KempSmallMeanPoissonSampler private constructor( // The algorithm listed in Kemp (1981) does not check that the rolling probability // is positive. This check is added to ensure no errors when the limit of the summation // 1 - sum(p(x)) is above 0 due to cumulative error in floating point arithmetic. - if (p == 0.0) return x + if (p == 0.0) return@chain x } - return x + x } - override fun toString(): String = "Kemp Small Mean Poisson deviate [$rng]" - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateDiscreteSampler = - KempSmallMeanPoissonSampler(rng, p0, mean) + override fun toString(): String = "Kemp Small Mean Poisson deviate" companion object { - fun of( - rng: UniformRandomProvider, - mean: Double - ): SharedStateDiscreteSampler { + fun of(mean: Double): KempSmallMeanPoissonSampler { require(mean > 0) { "Mean is not strictly positive: $mean" } val p0: Double = exp(-mean) - // Probability must be positive. As mean increases then p(0) decreases. - if (p0 > 0) return KempSmallMeanPoissonSampler( - rng, - p0, - mean - ) + if (p0 > 0) return KempSmallMeanPoissonSampler(p0, mean) throw IllegalArgumentException("No probability for mean: $mean") } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt new file mode 100644 index 000000000..33b1b3f8a --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt @@ -0,0 +1,132 @@ +package scientifik.kmath.prob.samplers + +import kotlinx.coroutines.flow.first +import scientifik.kmath.chains.Chain +import scientifik.kmath.chains.ConstantChain +import scientifik.kmath.chains.SimpleChain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler +import scientifik.kmath.prob.next +import kotlin.math.* + +class LargeMeanPoissonSampler(val mean: Double) : Sampler { + private val exponential: Sampler = + AhrensDieterExponentialSampler.of(1.0) + private val gaussian: Sampler = + ZigguratNormalizedGaussianSampler() + private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG!! + private val lambda: Double = floor(mean) + private val logLambda: Double = ln(lambda) + private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt()) + private val delta: Double = sqrt(lambda * ln(32 * lambda / PI + 1)) + private val halfDelta: Double = delta / 2 + private val twolpd: Double = 2 * lambda + delta + private val c1: Double = 1 / (8 * lambda) + private val a1: Double = sqrt(PI * twolpd) * exp(c1) + private val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd) + private val aSum: Double = a1 + a2 + 1 + private val p1: Double = a1 / aSum + private val p2: Double = a2 / aSum + + private val smallMeanPoissonSampler: Sampler = if (mean - lambda < Double.MIN_VALUE) + NO_SMALL_MEAN_POISSON_SAMPLER + else // Not used. + KempSmallMeanPoissonSampler.of(mean - lambda) + + init { + require(mean >= 1) { "mean is not >= 1: $mean" } + // The algorithm is not valid if Math.floor(mean) is not an integer. + require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" } + } + + override fun sample(generator: RandomGenerator): Chain = SimpleChain { + // This will never be null. It may be a no-op delegate that returns zero. + val y2 = smallMeanPoissonSampler.next(generator) + var x: Double + var y: Double + var v: Double + var a: Int + var t: Double + var qr: Double + var qa: Double + + while (true) { + // Step 1: + val u = generator.nextDouble() + + if (u <= p1) { + // Step 2: + val n = gaussian.next(generator) + x = n * sqrt(lambda + halfDelta) - 0.5 + if (x > delta || x < -lambda) continue + y = if (x < 0) floor(x) else ceil(x) + val e = exponential.next(generator) + v = -e - 0.5 * n * n + c1 + } else { + // Step 3: + if (u > p1 + p2) { + y = lambda + break + } + + x = delta + twolpd / delta * exponential.next(generator) + y = ceil(x) + v = -exponential.next(generator) - delta * (x + 1) / twolpd + } + + // The Squeeze Principle + // Step 4.1: + a = if (x < 0) 1 else 0 + t = y * (y + 1) / (2 * lambda) + + // Step 4.2 + if (v < -t && a == 0) { + y += lambda + break + } + + // Step 4.3: + qr = t * ((2 * y + 1) / (6 * lambda) - 1) + qa = qr - t * t / (3 * (lambda + a * (y + 1))) + + // Step 4.4: + if (v < qa) { + y += lambda + break + } + + // Step 4.5: + if (v > qr) continue + + // Step 4.6: + if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) { + y += lambda + break + } + } + + min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt() + } + + private fun getFactorialLog(n: Int): Double = factorialLog.value(n) + + override fun toString(): String = "Large Mean Poisson deviate" + + companion object { + private const val MAX_MEAN = 0.5 * Int.MAX_VALUE + private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null + + private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler { + override fun sample(generator: RandomGenerator): Chain = ConstantChain(0) + } + + fun of(mean: Double): LargeMeanPoissonSampler = + LargeMeanPoissonSampler(mean) + + init { + // Create without a cache. + NO_CACHE_FACTORIAL_LOG = + InternalUtils.FactorialLog.create() + } + } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/MarsagliaNormalizedGaussianSampler.kt similarity index 53% rename from kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.kt rename to kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/MarsagliaNormalizedGaussianSampler.kt index ee346718e..c44943056 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/MarsagliaNormalizedGaussianSampler.kt @@ -1,35 +1,42 @@ -package scientifik.kmath.commons.rng.sampling.distribution +package scientifik.kmath.prob.samplers -import scientifik.kmath.commons.rng.UniformRandomProvider +import scientifik.kmath.chains.Chain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler +import scientifik.kmath.prob.chain import kotlin.math.ln import kotlin.math.sqrt -class MarsagliaNormalizedGaussianSampler(private val rng: UniformRandomProvider) : - NormalizedGaussianSampler, - SharedStateContinuousSampler { +class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler { private var nextGaussian = Double.NaN - override fun sample(): Double { + override fun sample(generator: RandomGenerator): Chain = generator.chain { if (nextGaussian.isNaN()) { + val alpha: Double + var x: Double + // Rejection scheme for selecting a pair that lies within the unit circle. while (true) { // Generate a pair of numbers within [-1 , 1). - val x = 2.0 * rng.nextDouble() - 1.0 - val y = 2.0 * rng.nextDouble() - 1.0 + x = 2.0 * generator.nextDouble() - 1.0 + val y = 2.0 * generator.nextDouble() - 1.0 val r2 = x * x + y * y + if (r2 < 1 && r2 > 0) { // Pair (x, y) is within unit circle. - val alpha = sqrt(-2 * ln(r2) / r2) + alpha = sqrt(-2 * ln(r2) / r2) // Keep second element of the pair for next invocation. nextGaussian = alpha * y // Return the first element of the generated pair. - return alpha * x + break } - // Pair is not within the unit circle: Generate another one. } + + // Return the first element of the generated pair. + alpha * x } else { // Use the second element of the pair (generated at the // previous invocation). @@ -37,18 +44,9 @@ class MarsagliaNormalizedGaussianSampler(private val rng: UniformRandomProvider) // Both elements of the pair have been used. nextGaussian = Double.NaN - return r + r } } - override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate [$rng]" - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler = - MarsagliaNormalizedGaussianSampler(rng) - - companion object { - @Suppress("UNCHECKED_CAST") - fun of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? = - MarsagliaNormalizedGaussianSampler(rng) as S - } + override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate" } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/NormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/NormalizedGaussianSampler.kt new file mode 100644 index 000000000..0e5d6db59 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/NormalizedGaussianSampler.kt @@ -0,0 +1,5 @@ +package scientifik.kmath.prob.samplers + +import scientifik.kmath.prob.Sampler + +interface NormalizedGaussianSampler : Sampler diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/PoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/PoissonSampler.kt new file mode 100644 index 000000000..d3c53fbb4 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/PoissonSampler.kt @@ -0,0 +1,29 @@ +package scientifik.kmath.prob.samplers + +import scientifik.kmath.chains.Chain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler + + +class PoissonSampler( + mean: Double +) : Sampler { + private val poissonSamplerDelegate: Sampler + + init { + // Delegate all work to specialised samplers. + poissonSamplerDelegate = of(mean) + } + + override fun sample(generator: RandomGenerator): Chain = poissonSamplerDelegate.sample(generator) + override fun toString(): String = poissonSamplerDelegate.toString() + + companion object { + private const val PIVOT = 40.0 + + fun of(mean: Double) =// Each sampler should check the input arguments. + if (mean < PIVOT) SmallMeanPoissonSampler.of( + mean + ) else LargeMeanPoissonSampler.of(mean) + } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt new file mode 100644 index 000000000..6509563dd --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt @@ -0,0 +1,43 @@ +package scientifik.kmath.prob.samplers + +import scientifik.kmath.chains.Chain +import scientifik.kmath.chains.SimpleChain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler +import scientifik.kmath.prob.chain +import kotlin.math.ceil +import kotlin.math.exp + +class SmallMeanPoissonSampler(mean: Double) : Sampler { + private val p0: Double + private val limit: Int + + init { + require(mean > 0) { "mean is not strictly positive: $mean" } + p0 = exp(-mean) + + limit = (if (p0 > 0) + ceil(1000 * mean) + else + throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt() + } + + override fun sample(generator: RandomGenerator): Chain = generator.chain { + var n = 0 + var r = 1.0 + + while (n < limit) { + r *= nextDouble() + if (r >= p0) n++ else break + } + + n + } + + override fun toString(): String = "Small Mean Poisson deviate" + + companion object { + fun of(mean: Double): SmallMeanPoissonSampler = + SmallMeanPoissonSampler(mean) + } +} diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/ZigguratNormalizedGaussianSampler.kt similarity index 58% rename from kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.kt rename to kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/ZigguratNormalizedGaussianSampler.kt index 26e089599..57e7fc47e 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/ZigguratNormalizedGaussianSampler.kt @@ -1,38 +1,76 @@ -package scientifik.kmath.commons.rng.sampling.distribution +package scientifik.kmath.prob.samplers -import scientifik.kmath.commons.rng.UniformRandomProvider +import scientifik.kmath.chains.Chain +import scientifik.kmath.chains.SimpleChain +import scientifik.kmath.prob.RandomGenerator +import scientifik.kmath.prob.Sampler +import scientifik.kmath.prob.chain import kotlin.math.* -class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider) : - NormalizedGaussianSampler, - SharedStateContinuousSampler { +class ZigguratNormalizedGaussianSampler() : + NormalizedGaussianSampler, Sampler { + + private fun sampleOne(generator: RandomGenerator): Double { + val j = generator.nextLong() + val i = (j and LAST.toLong()).toInt() + return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i) + } + + override fun sample(generator: RandomGenerator): Chain = generator.chain { sampleOne(this) } + + override fun toString(): String = "Ziggurat normalized Gaussian deviate" + + private fun fix( + generator: RandomGenerator, + hz: Long, + iz: Int + ): Double { + var x: Double + var y: Double + x = hz * W[iz] + + return if (iz == 0) { + // Base strip. + // This branch is called about 5.7624515E-4 times per sample. + do { + y = -ln(generator.nextDouble()) + x = -ln(generator.nextDouble()) * ONE_OVER_R + } while (y + y < x * x) + + val out = R + x + if (hz > 0) out else -out + } else { + // Wedge of other strips. + // This branch is called about 0.027323 times per sample. + // else + // Try again. + // This branch is called about 0.012362 times per sample. + if (F[iz] + generator.nextDouble() * (F[iz - 1] - F[iz]) < gauss( + x + ) + ) x else sampleOne(generator) + } + } companion object { - private const val R = 3.442619855899 + private const val R: Double = 3.442619855899 private const val ONE_OVER_R: Double = 1 / R private const val V = 9.91256303526217e-3 private val MAX: Double = 2.0.pow(63.0) private val ONE_OVER_MAX: Double = 1.0 / MAX - private const val LEN = 128 + private const val LEN: Int = 128 private const val LAST: Int = LEN - 1 - private val K = LongArray(LEN) - private val W = DoubleArray(LEN) - private val F = DoubleArray(LEN) + private val K: LongArray = LongArray(LEN) + private val W: DoubleArray = DoubleArray(LEN) + private val F: DoubleArray = DoubleArray(LEN) private fun gauss(x: Double): Double = exp(-0.5 * x * x) - @Suppress("UNCHECKED_CAST") - fun of(rng: UniformRandomProvider): S where S : NormalizedGaussianSampler?, S : SharedStateContinuousSampler? = - ZigguratNormalizedGaussianSampler(rng) as S - init { // Filling the tables. - var d = - R + var d = R var t = d var fd = - gauss( - d - ) + gauss(d) val q = V / fd K[0] = (d / q * MAX).toLong() K[1] = 0 @@ -54,46 +92,4 @@ class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider) } } } - - override fun sample(): Double { - val j = rng.nextLong() - val i = (j and LAST.toLong()).toInt() - return if (abs(j) < K[i]) j * W[i] else fix(j, i) - } - - override fun toString(): String = "Ziggurat normalized Gaussian deviate [$rng]" - - private fun fix( - hz: Long, - iz: Int - ): Double { - var x: Double - var y: Double - x = hz * W[iz] - - return if (iz == 0) { - // Base strip. - // This branch is called about 5.7624515E-4 times per sample. - do { - y = -ln(rng.nextDouble()) - x = -ln(rng.nextDouble()) * ONE_OVER_R - } while (y + y < x * x) - - val out = R + x - if (hz > 0) out else -out - } else { - // Wedge of other strips. - // This branch is called about 0.027323 times per sample. - // else - // Try again. - // This branch is called about 0.012362 times per sample. - if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss( - x - ) - ) x else sample() - } - } - - override fun withUniformRandomProvider(rng: UniformRandomProvider): SharedStateContinuousSampler = - ZigguratNormalizedGaussianSampler(rng) } diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/DistributionsJVM.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/DistributionsJVM.kt deleted file mode 100644 index 649cae961..000000000 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/DistributionsJVM.kt +++ /dev/null @@ -1,66 +0,0 @@ -package scientifik.kmath.prob - -import scientifik.kmath.commons.rng.sampling.distribution.ContinuousSampler -import scientifik.kmath.commons.rng.sampling.distribution.DiscreteSampler -import scientifik.kmath.commons.rng.sampling.distribution.GaussianSampler -import scientifik.kmath.commons.rng.sampling.distribution.PoissonSampler -import kotlin.math.PI -import kotlin.math.exp -import kotlin.math.pow -import kotlin.math.sqrt - -fun Distribution.Companion.normal( - method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat -): Distribution = object : ContinuousSamplerDistribution() { - override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { - val provider = generator.asUniformRandomProvider() - return normalSampler(method, provider) - } - - override fun probability(arg: Double): Double { - return exp(-arg.pow(2) / 2) / sqrt(PI * 2) - } -} - -fun Distribution.Companion.normal( - mean: Double, - sigma: Double, - method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat -): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() { - private val sigma2 = sigma.pow(2) - private val norm = sigma * sqrt(PI * 2) - - override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { - val provider = generator.asUniformRandomProvider() - val normalizedSampler = normalSampler(method, provider) - return GaussianSampler( - normalizedSampler, - mean, - sigma - ) - } - - override fun probability(arg: Double): Double { - return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm - } -} - -fun Distribution.Companion.poisson( - lambda: Double -): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() { - - override fun buildSampler(generator: RandomGenerator): DiscreteSampler { - return PoissonSampler.of(generator.asUniformRandomProvider(), lambda) - } - - private val computedProb: HashMap = hashMapOf(0 to exp(-lambda)) - - override fun probability(arg: Int): Double { - require(arg >= 0) { "The argument must be >= 0" } - - return if (arg > 40) - exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda) - else - computedProb.getOrPut(arg) { probability(arg - 1) * lambda / arg } - } -} diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt index 4c8d630ae..797c932ad 100644 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt @@ -1,11 +1,10 @@ package scientifik.kmath.prob import org.apache.commons.rng.simple.RandomSource -import scientifik.kmath.commons.rng.UniformRandomProvider -class RandomSourceGenerator(val source: RandomSource, seed: Long?) : +class RandomSourceGenerator(private val source: RandomSource, seed: Long?) : RandomGenerator { - internal val random = seed?.let { + private val random = seed?.let { RandomSource.create(source, seed) } ?: RandomSource.create(source) @@ -24,24 +23,6 @@ class RandomSourceGenerator(val source: RandomSource, seed: Long?) : override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) } -/** - * Represent this [RandomGenerator] as commons-rng [UniformRandomProvider] preserving and mirroring its current state. - * Getting new value from one of those changes the state of another. - */ -fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this is RandomSourceGenerator) { - object : UniformRandomProvider { - override fun nextBytes(bytes: ByteArray) = random.nextBytes(bytes) - override fun nextBytes(bytes: ByteArray, start: Int, len: Int) = random.nextBytes(bytes, start, len) - override fun nextInt(): Int = random.nextInt() - override fun nextInt(n: Int): Int = random.nextInt(n) - override fun nextLong(): Long = random.nextLong() - override fun nextLong(n: Long): Long = random.nextLong(n) - override fun nextBoolean(): Boolean = random.nextBoolean() - override fun nextFloat(): Float = random.nextFloat() - override fun nextDouble(): Double = random.nextDouble() - } -} else RandomGeneratorProvider(this) - fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator = RandomSourceGenerator(source, seed)