From e7b1203c2ddca46225c9714d86856365d906b029 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Mon, 8 Jun 2020 17:29:57 +0700 Subject: [PATCH] Move init blocks checks to factory method of and make all the samplers' constructor private --- .../AhrensDieterExponentialSampler.kt | 16 ++++----- .../BoxMullerNormalizedGaussianSampler.kt | 6 +++- .../kmath/prob/samplers/GaussianSampler.kt | 25 +++++++------- .../samplers/KempSmallMeanPoissonSampler.kt | 6 ++-- .../prob/samplers/LargeMeanPoissonSampler.kt | 33 +++++++------------ .../MarsagliaNormalizedGaussianSampler.kt | 6 +++- .../kmath/prob/samplers/PoissonSampler.kt | 14 ++------ .../prob/samplers/SmallMeanPoissonSampler.kt | 24 ++++++-------- .../ZigguratNormalizedGaussianSampler.kt | 6 ++-- 9 files changed, 60 insertions(+), 76 deletions(-) diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt index 5dec2b641..794c1010d 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/AhrensDieterExponentialSampler.kt @@ -7,11 +7,7 @@ import scientifik.kmath.prob.chain import kotlin.math.ln import kotlin.math.pow -class AhrensDieterExponentialSampler(val mean: Double) : Sampler { - init { - require(mean > 0) { "mean is not strictly positive: $mean" } - } - +class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler { override fun sample(generator: RandomGenerator): Chain = generator.chain { // Step 1: var a = 0.0 @@ -46,10 +42,7 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler { override fun toString(): String = "Ahrens-Dieter Exponential deviate" companion object { - private val EXPONENTIAL_SA_QI = DoubleArray(16) - - fun of(mean: Double): Sampler = - AhrensDieterExponentialSampler(mean) + private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) } init { /** @@ -64,5 +57,10 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler { EXPONENTIAL_SA_QI[i] = qi } } + + fun of(mean: Double): AhrensDieterExponentialSampler { + require(mean > 0) { "mean is not strictly positive: $mean" } + return AhrensDieterExponentialSampler(mean) + } } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt index 3235b06f5..57c38d6b7 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/BoxMullerNormalizedGaussianSampler.kt @@ -6,7 +6,7 @@ import scientifik.kmath.prob.Sampler import scientifik.kmath.prob.chain import kotlin.math.* -class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler { +class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler { private var nextGaussian: Double = Double.NaN override fun sample(generator: RandomGenerator): Chain = generator.chain { @@ -34,4 +34,8 @@ class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler { - - init { - require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" } - } - override fun sample(generator: RandomGenerator): Chain = normalized.sample(generator).map { standardDeviation * it + mean } @@ -22,13 +17,17 @@ class GaussianSampler( companion object { fun of( - normalized: NormalizedGaussianSampler, mean: Double, - standardDeviation: Double - ) = GaussianSampler( - mean, - standardDeviation, - normalized - ) + standardDeviation: Double, + normalized: NormalizedGaussianSampler + ): GaussianSampler { + require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" } + + return GaussianSampler( + mean, + standardDeviation, + normalized + ) + } } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/KempSmallMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/KempSmallMeanPoissonSampler.kt index cdd4c4cd1..65e762b68 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/KempSmallMeanPoissonSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/KempSmallMeanPoissonSampler.kt @@ -41,10 +41,10 @@ class KempSmallMeanPoissonSampler private constructor( companion object { fun of(mean: Double): KempSmallMeanPoissonSampler { require(mean > 0) { "Mean is not strictly positive: $mean" } - val p0: Double = exp(-mean) + val p0 = exp(-mean) // Probability must be positive. As mean increases then p(0) decreases. - if (p0 > 0) return KempSmallMeanPoissonSampler(p0, mean) - throw IllegalArgumentException("No probability for mean: $mean") + require(p0 > 0) { "No probability for mean: $mean" } + return KempSmallMeanPoissonSampler(p0, mean) } } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt index 33b1b3f8a..39a3dc168 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/LargeMeanPoissonSampler.kt @@ -1,6 +1,5 @@ package scientifik.kmath.prob.samplers -import kotlinx.coroutines.flow.first import scientifik.kmath.chains.Chain import scientifik.kmath.chains.ConstantChain import scientifik.kmath.chains.SimpleChain @@ -9,12 +8,10 @@ import scientifik.kmath.prob.Sampler import scientifik.kmath.prob.next import kotlin.math.* -class LargeMeanPoissonSampler(val mean: Double) : Sampler { - private val exponential: Sampler = - AhrensDieterExponentialSampler.of(1.0) - private val gaussian: Sampler = - ZigguratNormalizedGaussianSampler() - private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG!! +class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler { + private val exponential: Sampler = AhrensDieterExponentialSampler.of(1.0) + private val gaussian: Sampler = ZigguratNormalizedGaussianSampler() + private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG private val lambda: Double = floor(mean) private val logLambda: Double = ln(lambda) private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt()) @@ -33,12 +30,6 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler { else // Not used. KempSmallMeanPoissonSampler.of(mean - lambda) - init { - require(mean >= 1) { "mean is not >= 1: $mean" } - // The algorithm is not valid if Math.floor(mean) is not an integer. - require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" } - } - override fun sample(generator: RandomGenerator): Chain = SimpleChain { // This will never be null. It may be a no-op delegate that returns zero. val y2 = smallMeanPoissonSampler.next(generator) @@ -113,20 +104,18 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler { override fun toString(): String = "Large Mean Poisson deviate" companion object { - private const val MAX_MEAN = 0.5 * Int.MAX_VALUE - private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null + private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE + private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create() private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler { override fun sample(generator: RandomGenerator): Chain = ConstantChain(0) } - fun of(mean: Double): LargeMeanPoissonSampler = - LargeMeanPoissonSampler(mean) - - init { - // Create without a cache. - NO_CACHE_FACTORIAL_LOG = - InternalUtils.FactorialLog.create() + fun of(mean: Double): LargeMeanPoissonSampler { + require(mean >= 1) { "mean is not >= 1: $mean" } + // The algorithm is not valid if Math.floor(mean) is not an integer. + require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" } + return LargeMeanPoissonSampler(mean) } } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/MarsagliaNormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/MarsagliaNormalizedGaussianSampler.kt index c44943056..dd8f82a2c 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/MarsagliaNormalizedGaussianSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/MarsagliaNormalizedGaussianSampler.kt @@ -7,7 +7,7 @@ import scientifik.kmath.prob.chain import kotlin.math.ln import kotlin.math.sqrt -class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler { +class MarsagliaNormalizedGaussianSampler private constructor(): NormalizedGaussianSampler, Sampler { private var nextGaussian = Double.NaN override fun sample(generator: RandomGenerator): Chain = generator.chain { @@ -49,4 +49,8 @@ class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler { - private val poissonSamplerDelegate: Sampler - - init { - // Delegate all work to specialised samplers. - poissonSamplerDelegate = of(mean) - } - + private val poissonSamplerDelegate: Sampler = of(mean) override fun sample(generator: RandomGenerator): Chain = poissonSamplerDelegate.sample(generator) override fun toString(): String = poissonSamplerDelegate.toString() @@ -22,8 +16,6 @@ class PoissonSampler( private const val PIVOT = 40.0 fun of(mean: Double) =// Each sampler should check the input arguments. - if (mean < PIVOT) SmallMeanPoissonSampler.of( - mean - ) else LargeMeanPoissonSampler.of(mean) + if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean) } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt index 6509563dd..7022b0f3a 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/SmallMeanPoissonSampler.kt @@ -8,19 +8,13 @@ import scientifik.kmath.prob.chain import kotlin.math.ceil import kotlin.math.exp -class SmallMeanPoissonSampler(mean: Double) : Sampler { - private val p0: Double - private val limit: Int +class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler { + private val p0: Double = exp(-mean) - init { - require(mean > 0) { "mean is not strictly positive: $mean" } - p0 = exp(-mean) - - limit = (if (p0 > 0) - ceil(1000 * mean) - else - throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt() - } + private val limit: Int = (if (p0 > 0) + ceil(1000 * mean) + else + throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt() override fun sample(generator: RandomGenerator): Chain = generator.chain { var n = 0 @@ -37,7 +31,9 @@ class SmallMeanPoissonSampler(mean: Double) : Sampler { override fun toString(): String = "Small Mean Poisson deviate" companion object { - fun of(mean: Double): SmallMeanPoissonSampler = - SmallMeanPoissonSampler(mean) + fun of(mean: Double): SmallMeanPoissonSampler { + require(mean > 0) { "mean is not strictly positive: $mean" } + return SmallMeanPoissonSampler(mean) + } } } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/ZigguratNormalizedGaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/ZigguratNormalizedGaussianSampler.kt index 57e7fc47e..3d22c38a5 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/ZigguratNormalizedGaussianSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/samplers/ZigguratNormalizedGaussianSampler.kt @@ -7,7 +7,7 @@ import scientifik.kmath.prob.Sampler import scientifik.kmath.prob.chain import kotlin.math.* -class ZigguratNormalizedGaussianSampler() : +class ZigguratNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler { private fun sampleOne(generator: RandomGenerator): Double { @@ -63,7 +63,6 @@ class ZigguratNormalizedGaussianSampler() : private val K: LongArray = LongArray(LEN) private val W: DoubleArray = DoubleArray(LEN) private val F: DoubleArray = DoubleArray(LEN) - private fun gauss(x: Double): Double = exp(-0.5 * x * x) init { // Filling the tables. @@ -91,5 +90,8 @@ class ZigguratNormalizedGaussianSampler() : W[i] = d * ONE_OVER_MAX } } + + fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler() + private fun gauss(x: Double): Double = exp(-0.5 * x * x) } }