Move init blocks checks to factory method of and make all the samplers' constructor private
This commit is contained in:
parent
28062cb096
commit
e7b1203c2d
@ -7,11 +7,7 @@ import scientifik.kmath.prob.chain
|
|||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
|
class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler<Double> {
|
||||||
init {
|
|
||||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
// Step 1:
|
// Step 1:
|
||||||
var a = 0.0
|
var a = 0.0
|
||||||
@ -46,10 +42,7 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
|
|||||||
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
|
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private val EXPONENTIAL_SA_QI = DoubleArray(16)
|
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
|
||||||
|
|
||||||
fun of(mean: Double): Sampler<Double> =
|
|
||||||
AhrensDieterExponentialSampler(mean)
|
|
||||||
|
|
||||||
init {
|
init {
|
||||||
/**
|
/**
|
||||||
@ -64,5 +57,10 @@ class AhrensDieterExponentialSampler(val mean: Double) : Sampler<Double> {
|
|||||||
EXPONENTIAL_SA_QI[i] = qi
|
EXPONENTIAL_SA_QI[i] = qi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun of(mean: Double): AhrensDieterExponentialSampler {
|
||||||
|
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||||
|
return AhrensDieterExponentialSampler(mean)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import scientifik.kmath.prob.Sampler
|
|||||||
import scientifik.kmath.prob.chain
|
import scientifik.kmath.prob.chain
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
|
class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||||
private var nextGaussian: Double = Double.NaN
|
private var nextGaussian: Double = Double.NaN
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
@ -34,4 +34,8 @@ class BoxMullerNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,16 +5,11 @@ import scientifik.kmath.chains.map
|
|||||||
import scientifik.kmath.prob.RandomGenerator
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import scientifik.kmath.prob.Sampler
|
||||||
|
|
||||||
class GaussianSampler(
|
class GaussianSampler private constructor(
|
||||||
private val mean: Double,
|
private val mean: Double,
|
||||||
private val standardDeviation: Double,
|
private val standardDeviation: Double,
|
||||||
private val normalized: NormalizedGaussianSampler
|
private val normalized: NormalizedGaussianSampler
|
||||||
) : Sampler<Double> {
|
) : Sampler<Double> {
|
||||||
|
|
||||||
init {
|
|
||||||
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> =
|
override fun sample(generator: RandomGenerator): Chain<Double> =
|
||||||
normalized.sample(generator).map { standardDeviation * it + mean }
|
normalized.sample(generator).map { standardDeviation * it + mean }
|
||||||
|
|
||||||
@ -22,13 +17,17 @@ class GaussianSampler(
|
|||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun of(
|
fun of(
|
||||||
normalized: NormalizedGaussianSampler,
|
|
||||||
mean: Double,
|
mean: Double,
|
||||||
standardDeviation: Double
|
standardDeviation: Double,
|
||||||
) = GaussianSampler(
|
normalized: NormalizedGaussianSampler
|
||||||
mean,
|
): GaussianSampler {
|
||||||
standardDeviation,
|
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||||
normalized
|
|
||||||
)
|
return GaussianSampler(
|
||||||
|
mean,
|
||||||
|
standardDeviation,
|
||||||
|
normalized
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -41,10 +41,10 @@ class KempSmallMeanPoissonSampler private constructor(
|
|||||||
companion object {
|
companion object {
|
||||||
fun of(mean: Double): KempSmallMeanPoissonSampler {
|
fun of(mean: Double): KempSmallMeanPoissonSampler {
|
||||||
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
||||||
val p0: Double = exp(-mean)
|
val p0 = exp(-mean)
|
||||||
// Probability must be positive. As mean increases then p(0) decreases.
|
// Probability must be positive. As mean increases then p(0) decreases.
|
||||||
if (p0 > 0) return KempSmallMeanPoissonSampler(p0, mean)
|
require(p0 > 0) { "No probability for mean: $mean" }
|
||||||
throw IllegalArgumentException("No probability for mean: $mean")
|
return KempSmallMeanPoissonSampler(p0, mean)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package scientifik.kmath.prob.samplers
|
||||||
|
|
||||||
import kotlinx.coroutines.flow.first
|
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.ConstantChain
|
import scientifik.kmath.chains.ConstantChain
|
||||||
import scientifik.kmath.chains.SimpleChain
|
import scientifik.kmath.chains.SimpleChain
|
||||||
@ -9,12 +8,10 @@ import scientifik.kmath.prob.Sampler
|
|||||||
import scientifik.kmath.prob.next
|
import scientifik.kmath.prob.next
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
|
class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<Int> {
|
||||||
private val exponential: Sampler<Double> =
|
private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
|
||||||
AhrensDieterExponentialSampler.of(1.0)
|
private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler()
|
||||||
private val gaussian: Sampler<Double> =
|
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
|
||||||
ZigguratNormalizedGaussianSampler()
|
|
||||||
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG!!
|
|
||||||
private val lambda: Double = floor(mean)
|
private val lambda: Double = floor(mean)
|
||||||
private val logLambda: Double = ln(lambda)
|
private val logLambda: Double = ln(lambda)
|
||||||
private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt())
|
private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt())
|
||||||
@ -33,12 +30,6 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
|
|||||||
else // Not used.
|
else // Not used.
|
||||||
KempSmallMeanPoissonSampler.of(mean - lambda)
|
KempSmallMeanPoissonSampler.of(mean - lambda)
|
||||||
|
|
||||||
init {
|
|
||||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
|
||||||
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
|
||||||
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain {
|
override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain {
|
||||||
// This will never be null. It may be a no-op delegate that returns zero.
|
// This will never be null. It may be a no-op delegate that returns zero.
|
||||||
val y2 = smallMeanPoissonSampler.next(generator)
|
val y2 = smallMeanPoissonSampler.next(generator)
|
||||||
@ -113,20 +104,18 @@ class LargeMeanPoissonSampler(val mean: Double) : Sampler<Int> {
|
|||||||
override fun toString(): String = "Large Mean Poisson deviate"
|
override fun toString(): String = "Large Mean Poisson deviate"
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private const val MAX_MEAN = 0.5 * Int.MAX_VALUE
|
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
|
||||||
private var NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog? = null
|
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
|
||||||
|
|
||||||
private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler<Int> {
|
private val NO_SMALL_MEAN_POISSON_SAMPLER = object : Sampler<Int> {
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun of(mean: Double): LargeMeanPoissonSampler =
|
fun of(mean: Double): LargeMeanPoissonSampler {
|
||||||
LargeMeanPoissonSampler(mean)
|
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||||
|
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
||||||
init {
|
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
||||||
// Create without a cache.
|
return LargeMeanPoissonSampler(mean)
|
||||||
NO_CACHE_FACTORIAL_LOG =
|
|
||||||
InternalUtils.FactorialLog.create()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import scientifik.kmath.prob.chain
|
|||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Double> {
|
class MarsagliaNormalizedGaussianSampler private constructor(): NormalizedGaussianSampler, Sampler<Double> {
|
||||||
private var nextGaussian = Double.NaN
|
private var nextGaussian = Double.NaN
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
@ -49,4 +49,8 @@ class MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler, Sampler<Do
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,16 +5,10 @@ import scientifik.kmath.prob.RandomGenerator
|
|||||||
import scientifik.kmath.prob.Sampler
|
import scientifik.kmath.prob.Sampler
|
||||||
|
|
||||||
|
|
||||||
class PoissonSampler(
|
class PoissonSampler private constructor(
|
||||||
mean: Double
|
mean: Double
|
||||||
) : Sampler<Int> {
|
) : Sampler<Int> {
|
||||||
private val poissonSamplerDelegate: Sampler<Int>
|
private val poissonSamplerDelegate: Sampler<Int> = of(mean)
|
||||||
|
|
||||||
init {
|
|
||||||
// Delegate all work to specialised samplers.
|
|
||||||
poissonSamplerDelegate = of(mean)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
||||||
override fun toString(): String = poissonSamplerDelegate.toString()
|
override fun toString(): String = poissonSamplerDelegate.toString()
|
||||||
|
|
||||||
@ -22,8 +16,6 @@ class PoissonSampler(
|
|||||||
private const val PIVOT = 40.0
|
private const val PIVOT = 40.0
|
||||||
|
|
||||||
fun of(mean: Double) =// Each sampler should check the input arguments.
|
fun of(mean: Double) =// Each sampler should check the input arguments.
|
||||||
if (mean < PIVOT) SmallMeanPoissonSampler.of(
|
if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
|
||||||
mean
|
|
||||||
) else LargeMeanPoissonSampler.of(mean)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,19 +8,13 @@ import scientifik.kmath.prob.chain
|
|||||||
import kotlin.math.ceil
|
import kotlin.math.ceil
|
||||||
import kotlin.math.exp
|
import kotlin.math.exp
|
||||||
|
|
||||||
class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
|
class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||||
private val p0: Double
|
private val p0: Double = exp(-mean)
|
||||||
private val limit: Int
|
|
||||||
|
|
||||||
init {
|
private val limit: Int = (if (p0 > 0)
|
||||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
ceil(1000 * mean)
|
||||||
p0 = exp(-mean)
|
else
|
||||||
|
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
||||||
limit = (if (p0 > 0)
|
|
||||||
ceil(1000 * mean)
|
|
||||||
else
|
|
||||||
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||||
var n = 0
|
var n = 0
|
||||||
@ -37,7 +31,9 @@ class SmallMeanPoissonSampler(mean: Double) : Sampler<Int> {
|
|||||||
override fun toString(): String = "Small Mean Poisson deviate"
|
override fun toString(): String = "Small Mean Poisson deviate"
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun of(mean: Double): SmallMeanPoissonSampler =
|
fun of(mean: Double): SmallMeanPoissonSampler {
|
||||||
SmallMeanPoissonSampler(mean)
|
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||||
|
return SmallMeanPoissonSampler(mean)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import scientifik.kmath.prob.Sampler
|
|||||||
import scientifik.kmath.prob.chain
|
import scientifik.kmath.prob.chain
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
class ZigguratNormalizedGaussianSampler() :
|
class ZigguratNormalizedGaussianSampler private constructor() :
|
||||||
NormalizedGaussianSampler, Sampler<Double> {
|
NormalizedGaussianSampler, Sampler<Double> {
|
||||||
|
|
||||||
private fun sampleOne(generator: RandomGenerator): Double {
|
private fun sampleOne(generator: RandomGenerator): Double {
|
||||||
@ -63,7 +63,6 @@ class ZigguratNormalizedGaussianSampler() :
|
|||||||
private val K: LongArray = LongArray(LEN)
|
private val K: LongArray = LongArray(LEN)
|
||||||
private val W: DoubleArray = DoubleArray(LEN)
|
private val W: DoubleArray = DoubleArray(LEN)
|
||||||
private val F: DoubleArray = DoubleArray(LEN)
|
private val F: DoubleArray = DoubleArray(LEN)
|
||||||
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
|
||||||
|
|
||||||
init {
|
init {
|
||||||
// Filling the tables.
|
// Filling the tables.
|
||||||
@ -91,5 +90,8 @@ class ZigguratNormalizedGaussianSampler() :
|
|||||||
W[i] = d * ONE_OVER_MAX
|
W[i] = d * ONE_OVER_MAX
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
|
||||||
|
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user