Move init blocks checks to factory method of and make all the samplers' constructor private
This commit is contained in:
parent
28062cb096
commit
e7b1203c2d
@ -7,11 +7,7 @@ import scientifik.kmath.prob.chain
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.pow
|
||||
|
||||
class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
|
||||
init {
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
}
|
||||
|
||||
class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler<Double> {
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
// Step 1:
|
||||
var a = 0.0
|
||||
@ -46,10 +42,7 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
|
||||
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
|
||||
|
||||
companion object {
|
||||
private val EXPONENTIAL_SA_QI = DoubleArray(16)
|
||||
|
||||
fun of(mean: Double): Sampler<Double> =
|
||||
AhrensDieterExponentialSampler(mean)
|
||||
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
|
||||
|
||||
init {
|
||||
/**
|
||||
@ -64,5 +57,10 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
|
||||
EXPONENTIAL_SA_QI[i] = qi
|
||||
}
|
||||
}
|
||||
|
||||
fun of(mean: Double): AhrensDieterExponentialSampler {
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
return AhrensDieterExponentialSampler(mean)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kotlin.math.*
|
||||
|
||||
class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
|
||||
class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||
private var nextGaussian: Double = Double.NaN
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
@ -34,4 +34,8 @@ class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
|
||||
}
|
||||
|
||||
override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
||||
|
||||
companion object {
|
||||
fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
|
||||
}
|
||||
}
|
||||
|
@ -5,16 +5,11 @@ import scientifik.kmath.chains.map
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
|
||||
class GaussianSampler(
|
||||
class GaussianSampler private constructor(
|
||||
private val mean: Double,
|
||||
private val standardDeviation: Double,
|
||||
private val normalized: NormalizedGaussianSampler
|
||||
) : Sampler<Double> {
|
||||
|
||||
init {
|
||||
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||
}
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> =
|
||||
normalized.sample(generator).map { standardDeviation * it + mean }
|
||||
|
||||
@ -22,13 +17,17 @@ class GaussianSampler(
|
||||
|
||||
companion object {
|
||||
fun of(
|
||||
normalized: NormalizedGaussianSampler,
|
||||
mean: Double,
|
||||
standardDeviation: Double
|
||||
) = GaussianSampler(
|
||||
standardDeviation: Double,
|
||||
normalized: NormalizedGaussianSampler
|
||||
): GaussianSampler {
|
||||
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||
|
||||
return GaussianSampler(
|
||||
mean,
|
||||
standardDeviation,
|
||||
normalized
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -41,10 +41,10 @@ class KempSmallMeanPoissonSampler private constructor(
|
||||
companion object {
|
||||
fun of(mean: Double): KempSmallMeanPoissonSampler {
|
||||
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
||||
val p0: Double = exp(-mean)
|
||||
val p0 = exp(-mean)
|
||||
// Probability must be positive. As mean increases then p(0) decreases.
|
||||
if (p0 > 0) return KempSmallMeanPoissonSampler(p0, mean)
|
||||
throw IllegalArgumentException("No probability for mean: $mean")
|
||||
require(p0 > 0) { "No probability for mean: $mean" }
|
||||
return KempSmallMeanPoissonSampler(p0, mean)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
|
||||
import kotlinx.coroutines.flow.first
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.ConstantChain
|
||||
import scientifik.kmath.chains.SimpleChain
|
||||
@ -9,12 +8,10 @@ import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.next
|
||||
import kotlin.math.*
|
||||
|
||||
class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
|
||||
private val exponential: Sampler<Double> =
|
||||
AhrensDieterExponentialSampler.of(1.0)
|
||||
private val gaussian: Sampler<Double> =
|
||||
ZigguratNormalizedGaussianSampler()
|
||||
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG!!
|
||||
class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<Int> {
|
||||
private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
|
||||
private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler()
|
||||
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
|
||||
private val lambda: Double = floor(mean)
|
||||
private val logLambda: Double = ln(lambda)
|
||||
private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt())
|
||||
@ -33,12 +30,6 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
|
||||
else // Not used.
|
||||
KempSmallMeanPoissonSampler.of(mean - lambda)
|
||||
|
||||
init {
|
||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
||||
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
||||
}
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain {
|
||||
// This will never be null. It may be a no-op delegate that returns zero.
|
||||
val y2 = smallMeanPoissonSampler.next(generator)
|
||||
@ -113,20 +104,18 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
|
||||
override fun toString(): String = "Large Mean Poisson deviate"
|
||||
|
||||
companion object {
|
||||
private const val MAX_MEAN = 0.5 * Int.MAX_VALUE
|
||||
private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null
|
||||
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
|
||||
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
|
||||
|
||||
private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler<Int> {
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
||||
}
|
||||
|
||||
fun of(mean: Double): LargeMeanPoissonSampler =
|
||||
LargeMeanPoissonSampler(mean)
|
||||
|
||||
init {
|
||||
// Create without a cache.
|
||||
NO_CACHE_FACTORIAL_LOG =
|
||||
InternalUtils.FactorialLog.create()
|
||||
fun of(mean: Double): LargeMeanPoissonSampler {
|
||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
||||
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
||||
return LargeMeanPoissonSampler(mean)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import scientifik.kmath.prob.chain
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.sqrt
|
||||
|
||||
class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
|
||||
class MarsagliaNormalizedGaussianSampler private constructor(): NormalizedGaussianSampler, Sampler<Double> {
|
||||
private var nextGaussian = Double.NaN
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
@ -49,4 +49,8 @@ class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
|
||||
}
|
||||
|
||||
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
||||
|
||||
companion object {
|
||||
fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
|
||||
}
|
||||
}
|
||||
|
@ -5,16 +5,10 @@ import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
|
||||
|
||||
class PoissonSampler(
|
||||
class PoissonSampler private constructor(
|
||||
mean: Double
|
||||
) : Sampler<Int> {
|
||||
private val poissonSamplerDelegate: Sampler<Int>
|
||||
|
||||
init {
|
||||
// Delegate all work to specialised samplers.
|
||||
poissonSamplerDelegate = of(mean)
|
||||
}
|
||||
|
||||
private val poissonSamplerDelegate: Sampler<Int> = of(mean)
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
||||
override fun toString(): String = poissonSamplerDelegate.toString()
|
||||
|
||||
@ -22,8 +16,6 @@ class PoissonSampler(
|
||||
private const val PIVOT = 40.0
|
||||
|
||||
fun of(mean: Double) =// Each sampler should check the input arguments.
|
||||
if (mean < PIVOT) SmallMeanPoissonSampler.of(
|
||||
mean
|
||||
) else LargeMeanPoissonSampler.of(mean)
|
||||
if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
|
||||
}
|
||||
}
|
||||
|
@ -8,19 +8,13 @@ import scientifik.kmath.prob.chain
|
||||
import kotlin.math.ceil
|
||||
import kotlin.math.exp
|
||||
|
||||
class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
|
||||
private val p0: Double
|
||||
private val limit: Int
|
||||
class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||
private val p0: Double = exp(-mean)
|
||||
|
||||
init {
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
p0 = exp(-mean)
|
||||
|
||||
limit = (if (p0 > 0)
|
||||
private val limit: Int = (if (p0 > 0)
|
||||
ceil(1000 * mean)
|
||||
else
|
||||
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
||||
}
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||
var n = 0
|
||||
@ -37,7 +31,9 @@ class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
|
||||
override fun toString(): String = "Small Mean Poisson deviate"
|
||||
|
||||
companion object {
|
||||
fun of(mean: Double): SmallMeanPoissonSampler =
|
||||
SmallMeanPoissonSampler(mean)
|
||||
fun of(mean: Double): SmallMeanPoissonSampler {
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
return SmallMeanPoissonSampler(mean)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kotlin.math.*
|
||||
|
||||
class ZigguratNormalizedGaussianSampler() :
|
||||
class ZigguratNormalizedGaussianSampler private constructor() :
|
||||
NormalizedGaussianSampler, Sampler<Double> {
|
||||
|
||||
private fun sampleOne(generator: RandomGenerator): Double {
|
||||
@ -63,7 +63,6 @@ class ZigguratNormalizedGaussianSampler() :
|
||||
private val K: LongArray = LongArray(LEN)
|
||||
private val W: DoubleArray = DoubleArray(LEN)
|
||||
private val F: DoubleArray = DoubleArray(LEN)
|
||||
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
||||
|
||||
init {
|
||||
// Filling the tables.
|
||||
@ -91,5 +90,8 @@ class ZigguratNormalizedGaussianSampler() :
|
||||
W[i] = d * ONE_OVER_MAX
|
||||
}
|
||||
}
|
||||
|
||||
fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
|
||||
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user