Move init blocks checks to factory method of and make all the samplers' constructor private

This commit is contained in:
Iaroslav 2020-06-08 17:29:57 +07:00
parent 28062cb096
commit e7b1203c2d
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
9 changed files with 60 additions and 76 deletions

View File

@ -7,11 +7,7 @@ import scientifik.kmath.prob.chain
import kotlin.math.ln
import kotlin.math.pow
class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
init {
require(mean > 0) { "mean is not strictly positive: $mean" }
}
class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler<Double> {
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
// Step 1:
var a = 0.0
@ -46,10 +42,7 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
companion object {
private val EXPONENTIAL_SA_QI = DoubleArray(16)
fun of(mean: Double): Sampler<Double> =
AhrensDieterExponentialSampler(mean)
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
init {
/**
@ -64,5 +57,10 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
EXPONENTIAL_SA_QI[i] = qi
}
}
fun of(mean: Double): AhrensDieterExponentialSampler {
require(mean > 0) { "mean is not strictly positive: $mean" }
return AhrensDieterExponentialSampler(mean)
}
}
}

View File

@ -6,7 +6,7 @@ import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain
import kotlin.math.*
class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
private var nextGaussian: Double = Double.NaN
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
@ -34,4 +34,8 @@ class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
}
override fun toString(): String = "Box-Muller normalized Gaussian deviate"
companion object {
fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
}
}

View File

@ -5,16 +5,11 @@ import scientifik.kmath.chains.map
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler
class GaussianSampler(
class GaussianSampler private constructor(
private val mean: Double,
private val standardDeviation: Double,
private val normalized: NormalizedGaussianSampler
) : Sampler<Double> {
init {
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
}
override fun sample(generator: RandomGenerator): Chain<Double> =
normalized.sample(generator).map { standardDeviation * it + mean }
@ -22,13 +17,17 @@ class GaussianSampler(
companion object {
fun of(
normalized: NormalizedGaussianSampler,
mean: Double,
standardDeviation: Double
) = GaussianSampler(
mean,
standardDeviation,
normalized
)
standardDeviation: Double,
normalized: NormalizedGaussianSampler
): GaussianSampler {
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
return GaussianSampler(
mean,
standardDeviation,
normalized
)
}
}
}

View File

@ -41,10 +41,10 @@ class KempSmallMeanPoissonSampler private constructor(
companion object {
fun of(mean: Double): KempSmallMeanPoissonSampler {
require(mean > 0) { "Mean is not strictly positive: $mean" }
val p0: Double = exp(-mean)
val p0 = exp(-mean)
// Probability must be positive. As mean increases then p(0) decreases.
if (p0 > 0) return KempSmallMeanPoissonSampler(p0, mean)
throw IllegalArgumentException("No probability for mean: $mean")
require(p0 > 0) { "No probability for mean: $mean" }
return KempSmallMeanPoissonSampler(p0, mean)
}
}
}

View File

@ -1,6 +1,5 @@
package scientifik.kmath.prob.samplers
import kotlinx.coroutines.flow.first
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.ConstantChain
import scientifik.kmath.chains.SimpleChain
@ -9,12 +8,10 @@ import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.next
import kotlin.math.*
class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
private val exponential: Sampler<Double> =
AhrensDieterExponentialSampler.of(1.0)
private val gaussian: Sampler<Double> =
ZigguratNormalizedGaussianSampler()
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG!!
class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<Int> {
private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler()
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
private val lambda: Double = floor(mean)
private val logLambda: Double = ln(lambda)
private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt())
@ -33,12 +30,6 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
else // Not used.
KempSmallMeanPoissonSampler.of(mean - lambda)
init {
require(mean >= 1) { "mean is not >= 1: $mean" }
// The algorithm is not valid if Math.floor(mean) is not an integer.
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
}
override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain {
// This will never be null. It may be a no-op delegate that returns zero.
val y2 = smallMeanPoissonSampler.next(generator)
@ -113,20 +104,18 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
override fun toString(): String = "Large Mean Poisson deviate"
companion object {
private const val MAX_MEAN = 0.5 * Int.MAX_VALUE
private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler<Int> {
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
}
fun of(mean: Double): LargeMeanPoissonSampler =
LargeMeanPoissonSampler(mean)
init {
// Create without a cache.
NO_CACHE_FACTORIAL_LOG =
InternalUtils.FactorialLog.create()
fun of(mean: Double): LargeMeanPoissonSampler {
require(mean >= 1) { "mean is not >= 1: $mean" }
// The algorithm is not valid if Math.floor(mean) is not an integer.
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
return LargeMeanPoissonSampler(mean)
}
}
}

View File

@ -7,7 +7,7 @@ import scientifik.kmath.prob.chain
import kotlin.math.ln
import kotlin.math.sqrt
class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
class MarsagliaNormalizedGaussianSampler private constructor(): NormalizedGaussianSampler, Sampler<Double> {
private var nextGaussian = Double.NaN
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
@ -49,4 +49,8 @@ class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
}
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
companion object {
fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
}
}

View File

@ -5,16 +5,10 @@ import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler
class PoissonSampler(
class PoissonSampler private constructor(
mean: Double
) : Sampler<Int> {
private val poissonSamplerDelegate: Sampler<Int>
init {
// Delegate all work to specialised samplers.
poissonSamplerDelegate = of(mean)
}
private val poissonSamplerDelegate: Sampler<Int> = of(mean)
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
override fun toString(): String = poissonSamplerDelegate.toString()
@ -22,8 +16,6 @@ class PoissonSampler(
private const val PIVOT = 40.0
fun of(mean: Double) =// Each sampler should check the input arguments.
if (mean < PIVOT) SmallMeanPoissonSampler.of(
mean
) else LargeMeanPoissonSampler.of(mean)
if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
}
}

View File

@ -8,19 +8,13 @@ import scientifik.kmath.prob.chain
import kotlin.math.ceil
import kotlin.math.exp
class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
private val p0: Double
private val limit: Int
class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
private val p0: Double = exp(-mean)
init {
require(mean > 0) { "mean is not strictly positive: $mean" }
p0 = exp(-mean)
limit = (if (p0 > 0)
ceil(1000 * mean)
else
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
}
private val limit: Int = (if (p0 > 0)
ceil(1000 * mean)
else
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
var n = 0
@ -37,7 +31,9 @@ class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
override fun toString(): String = "Small Mean Poisson deviate"
companion object {
fun of(mean: Double): SmallMeanPoissonSampler =
SmallMeanPoissonSampler(mean)
fun of(mean: Double): SmallMeanPoissonSampler {
require(mean > 0) { "mean is not strictly positive: $mean" }
return SmallMeanPoissonSampler(mean)
}
}
}

View File

@ -7,7 +7,7 @@ import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain
import kotlin.math.*
class ZigguratNormalizedGaussianSampler() :
class ZigguratNormalizedGaussianSampler private constructor() :
NormalizedGaussianSampler, Sampler<Double> {
private fun sampleOne(generator: RandomGenerator): Double {
@ -63,7 +63,6 @@ class ZigguratNormalizedGaussianSampler() :
private val K: LongArray = LongArray(LEN)
private val W: DoubleArray = DoubleArray(LEN)
private val F: DoubleArray = DoubleArray(LEN)
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
init {
// Filling the tables.
@ -91,5 +90,8 @@ class ZigguratNormalizedGaussianSampler() :
W[i] = d * ONE_OVER_MAX
}
}
fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
}
}