Move init blocks checks to factory method of and make all the samplers' constructor private

This commit is contained in:
Iaroslav 2020-06-08 17:29:57 +07:00
parent 28062cb096
commit e7b1203c2d
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
9 changed files with 60 additions and 76 deletions

View File

@ -7,11 +7,7 @@ import scientifik.kmath.prob.chain
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.pow import kotlin.math.pow
class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> { class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler<Double> {
init {
require(mean > 0) { "mean is not strictly positive: $mean" }
}
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
// Step 1: // Step 1:
var a = 0.0 var a = 0.0
@ -46,10 +42,7 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
override fun toString(): String = "Ahrens-Dieter Exponential deviate" override fun toString(): String = "Ahrens-Dieter Exponential deviate"
companion object { companion object {
private val EXPONENTIAL_SA_QI = DoubleArray(16) private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
fun of(mean: Double): Sampler<Double> =
AhrensDieterExponentialSampler(mean)
init { init {
/** /**
@ -64,5 +57,10 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
EXPONENTIAL_SA_QI[i] = qi EXPONENTIAL_SA_QI[i] = qi
} }
} }
fun of(mean: Double): AhrensDieterExponentialSampler {
require(mean > 0) { "mean is not strictly positive: $mean" }
return AhrensDieterExponentialSampler(mean)
}
} }
} }

View File

@ -6,7 +6,7 @@ import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain import scientifik.kmath.prob.chain
import kotlin.math.* import kotlin.math.*
class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> { class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
private var nextGaussian: Double = Double.NaN private var nextGaussian: Double = Double.NaN
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
@ -34,4 +34,8 @@ class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
} }
override fun toString(): String = "Box-Muller normalized Gaussian deviate" override fun toString(): String = "Box-Muller normalized Gaussian deviate"
companion object {
fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
}
} }

View File

@ -5,16 +5,11 @@ import scientifik.kmath.chains.map
import scientifik.kmath.prob.RandomGenerator import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import scientifik.kmath.prob.Sampler
class GaussianSampler( class GaussianSampler private constructor(
private val mean: Double, private val mean: Double,
private val standardDeviation: Double, private val standardDeviation: Double,
private val normalized: NormalizedGaussianSampler private val normalized: NormalizedGaussianSampler
) : Sampler<Double> { ) : Sampler<Double> {
init {
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
}
override fun sample(generator: RandomGenerator): Chain<Double> = override fun sample(generator: RandomGenerator): Chain<Double> =
normalized.sample(generator).map { standardDeviation * it + mean } normalized.sample(generator).map { standardDeviation * it + mean }
@ -22,13 +17,17 @@ class GaussianSampler(
companion object { companion object {
fun of( fun of(
normalized: NormalizedGaussianSampler,
mean: Double, mean: Double,
standardDeviation: Double standardDeviation: Double,
) = GaussianSampler( normalized: NormalizedGaussianSampler
): GaussianSampler {
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
return GaussianSampler(
mean, mean,
standardDeviation, standardDeviation,
normalized normalized
) )
} }
}
} }

View File

@ -41,10 +41,10 @@ class KempSmallMeanPoissonSampler private constructor(
companion object { companion object {
fun of(mean: Double): KempSmallMeanPoissonSampler { fun of(mean: Double): KempSmallMeanPoissonSampler {
require(mean > 0) { "Mean is not strictly positive: $mean" } require(mean > 0) { "Mean is not strictly positive: $mean" }
val p0: Double = exp(-mean) val p0 = exp(-mean)
// Probability must be positive. As mean increases then p(0) decreases. // Probability must be positive. As mean increases then p(0) decreases.
if (p0 > 0) return KempSmallMeanPoissonSampler(p0, mean) require(p0 > 0) { "No probability for mean: $mean" }
throw IllegalArgumentException("No probability for mean: $mean") return KempSmallMeanPoissonSampler(p0, mean)
} }
} }
} }

View File

@ -1,6 +1,5 @@
package scientifik.kmath.prob.samplers package scientifik.kmath.prob.samplers
import kotlinx.coroutines.flow.first
import scientifik.kmath.chains.Chain import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.ConstantChain import scientifik.kmath.chains.ConstantChain
import scientifik.kmath.chains.SimpleChain import scientifik.kmath.chains.SimpleChain
@ -9,12 +8,10 @@ import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.next import scientifik.kmath.prob.next
import kotlin.math.* import kotlin.math.*
class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> { class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<Int> {
private val exponential: Sampler<Double> = private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
AhrensDieterExponentialSampler.of(1.0) private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler()
private val gaussian: Sampler<Double> = private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
ZigguratNormalizedGaussianSampler()
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG!!
private val lambda: Double = floor(mean) private val lambda: Double = floor(mean)
private val logLambda: Double = ln(lambda) private val logLambda: Double = ln(lambda)
private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt()) private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt())
@ -33,12 +30,6 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
else // Not used. else // Not used.
KempSmallMeanPoissonSampler.of(mean - lambda) KempSmallMeanPoissonSampler.of(mean - lambda)
init {
require(mean >= 1) { "mean is not >= 1: $mean" }
// The algorithm is not valid if Math.floor(mean) is not an integer.
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
}
override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain { override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain {
// This will never be null. It may be a no-op delegate that returns zero. // This will never be null. It may be a no-op delegate that returns zero.
val y2 = smallMeanPoissonSampler.next(generator) val y2 = smallMeanPoissonSampler.next(generator)
@ -113,20 +104,18 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
override fun toString(): String = "Large Mean Poisson deviate" override fun toString(): String = "Large Mean Poisson deviate"
companion object { companion object {
private const val MAX_MEAN = 0.5 * Int.MAX_VALUE private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler<Int> { private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler<Int> {
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0) override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
} }
fun of(mean: Double): LargeMeanPoissonSampler = fun of(mean: Double): LargeMeanPoissonSampler {
LargeMeanPoissonSampler(mean) require(mean >= 1) { "mean is not >= 1: $mean" }
// The algorithm is not valid if Math.floor(mean) is not an integer.
init { require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
// Create without a cache. return LargeMeanPoissonSampler(mean)
NO_CACHE_FACTORIAL_LOG =
InternalUtils.FactorialLog.create()
} }
} }
} }

View File

@ -7,7 +7,7 @@ import scientifik.kmath.prob.chain
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.sqrt import kotlin.math.sqrt
class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> { class MarsagliaNormalizedGaussianSampler private constructor(): NormalizedGaussianSampler, Sampler<Double> {
private var nextGaussian = Double.NaN private var nextGaussian = Double.NaN
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
@ -49,4 +49,8 @@ class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
} }
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate" override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
companion object {
fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
}
} }

View File

@ -5,16 +5,10 @@ import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import scientifik.kmath.prob.Sampler
class PoissonSampler( class PoissonSampler private constructor(
mean: Double mean: Double
) : Sampler<Int> { ) : Sampler<Int> {
private val poissonSamplerDelegate: Sampler<Int> private val poissonSamplerDelegate: Sampler<Int> = of(mean)
init {
// Delegate all work to specialised samplers.
poissonSamplerDelegate = of(mean)
}
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator) override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
override fun toString(): String = poissonSamplerDelegate.toString() override fun toString(): String = poissonSamplerDelegate.toString()
@ -22,8 +16,6 @@ class PoissonSampler(
private const val PIVOT = 40.0 private const val PIVOT = 40.0
fun of(mean: Double) =// Each sampler should check the input arguments. fun of(mean: Double) =// Each sampler should check the input arguments.
if (mean < PIVOT) SmallMeanPoissonSampler.of( if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
mean
) else LargeMeanPoissonSampler.of(mean)
} }
} }

View File

@ -8,19 +8,13 @@ import scientifik.kmath.prob.chain
import kotlin.math.ceil import kotlin.math.ceil
import kotlin.math.exp import kotlin.math.exp
class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> { class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
private val p0: Double private val p0: Double = exp(-mean)
private val limit: Int
init { private val limit: Int = (if (p0 > 0)
require(mean > 0) { "mean is not strictly positive: $mean" }
p0 = exp(-mean)
limit = (if (p0 > 0)
ceil(1000 * mean) ceil(1000 * mean)
else else
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt() throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
}
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain { override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
var n = 0 var n = 0
@ -37,7 +31,9 @@ class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
override fun toString(): String = "Small Mean Poisson deviate" override fun toString(): String = "Small Mean Poisson deviate"
companion object { companion object {
fun of(mean: Double): SmallMeanPoissonSampler = fun of(mean: Double): SmallMeanPoissonSampler {
SmallMeanPoissonSampler(mean) require(mean > 0) { "mean is not strictly positive: $mean" }
return SmallMeanPoissonSampler(mean)
}
} }
} }

View File

@ -7,7 +7,7 @@ import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain import scientifik.kmath.prob.chain
import kotlin.math.* import kotlin.math.*
class ZigguratNormalizedGaussianSampler() : class ZigguratNormalizedGaussianSampler private constructor() :
NormalizedGaussianSampler, Sampler<Double> { NormalizedGaussianSampler, Sampler<Double> {
private fun sampleOne(generator: RandomGenerator): Double { private fun sampleOne(generator: RandomGenerator): Double {
@ -63,7 +63,6 @@ class ZigguratNormalizedGaussianSampler() :
private val K: LongArray = LongArray(LEN) private val K: LongArray = LongArray(LEN)
private val W: DoubleArray = DoubleArray(LEN) private val W: DoubleArray = DoubleArray(LEN)
private val F: DoubleArray = DoubleArray(LEN) private val F: DoubleArray = DoubleArray(LEN)
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
init { init {
// Filling the tables. // Filling the tables.
@ -91,5 +90,8 @@ class ZigguratNormalizedGaussianSampler() :
W[i] = d * ONE_OVER_MAX W[i] = d * ONE_OVER_MAX
} }
} }
fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
} }
} }