From 0c6fff3878d71c23f1254cb52e5a3139ac0f4697 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 15 Oct 2020 23:52:50 +0700 Subject: [PATCH] Code refactoring, implement NormalDistribution --- .../commons/expressions/DiffExpression.kt | 6 +- .../kscience/kmath/prob/Distribution.kt | 12 +-- .../prob/distributions/NormalDistribution.kt | 30 ++++++++ .../AhrensDieterExponentialSampler.kt | 1 + .../samplers/AliasMethodDiscreteSampler.kt | 1 + .../kmath/prob/samplers/GaussianSampler.kt | 7 +- .../kmath/prob/samplers/InternalGamma.kt | 41 ---------- .../kmath/prob/samplers/InternalUtils.kt | 76 ------------------- .../prob/samplers/LargeMeanPoissonSampler.kt | 8 +- .../kmath/prob/RandomSourceGenerator.kt | 5 +- 10 files changed, 45 insertions(+), 142 deletions(-) create mode 100644 kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/distributions/NormalDistribution.kt diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt index c39f0d04c..601675167 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt @@ -33,10 +33,7 @@ public class DerivativeStructureField( variables[name] ?: default ?: error("A variable with name $name does not exist") public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble()) - - public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double { - return deriv(mapOf(parName to order)) - } + public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double = deriv(mapOf(parName to order)) public fun DerivativeStructure.deriv(orders: Map): Double { return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray()) @@ -75,7 +72,6 @@ public class DerivativeStructureField( public fun power(arg: 
DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() - public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt index 965635e09..bbb1de1e3 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Distribution.kt @@ -6,7 +6,7 @@ import kscience.kmath.chains.collect import kscience.kmath.structures.Buffer import kscience.kmath.structures.BufferFactory -public interface Sampler { +public fun interface Sampler { public fun sample(generator: RandomGenerator): Chain } @@ -20,11 +20,7 @@ public interface Distribution : Sampler { */ public fun probability(arg: T): Double - /** - * Create a chain of samples from this distribution. - * The chain is not guaranteed to be stateless, but different sample chains should be independent. - */ - override fun sample(generator: RandomGenerator): Chain + public override fun sample(generator: RandomGenerator): Chain /** * An empty companion. 
Distribution factories should be written as its extensions @@ -63,9 +59,7 @@ public fun Sampler.sampleBuffer( //clear list from previous run tmp.clear() //Fill list - repeat(size) { - tmp.add(chain.next()) - } + repeat(size) { tmp.add(chain.next()) } //return new buffer with elements from tmp bufferFactory(size) { tmp[it] } } diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/distributions/NormalDistribution.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/distributions/NormalDistribution.kt new file mode 100644 index 000000000..095ac5ea9 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/distributions/NormalDistribution.kt @@ -0,0 +1,30 @@ +package kscience.kmath.prob.distributions + +import kscience.kmath.chains.Chain +import kscience.kmath.prob.RandomGenerator +import kscience.kmath.prob.UnivariateDistribution +import kscience.kmath.prob.internal.InternalErf +import kscience.kmath.prob.samplers.GaussianSampler +import kotlin.math.* + +public inline class NormalDistribution(public val sampler: GaussianSampler) : UnivariateDistribution { + public override fun probability(arg: Double): Double { + val x1 = (arg - sampler.mean) / sampler.standardDeviation + return exp(-0.5 * x1 * x1 - (ln(sampler.standardDeviation) + 0.5 * ln(2 * PI))) + } + + public override fun sample(generator: RandomGenerator): Chain = sampler.sample(generator) + + public override fun cumulative(arg: Double): Double { + val dev = arg - sampler.mean + + return when { + abs(dev) > 40 * sampler.standardDeviation -> if (dev < 0) 0.0 else 1.0 + else -> 0.5 * InternalErf.erfc(-dev / (sampler.standardDeviation * SQRT2)) + } + } + + private companion object { + private val SQRT2 = sqrt(2.0) + } +} diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AhrensDieterExponentialSampler.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AhrensDieterExponentialSampler.kt index dc388a3ea..d4b443e9c 100644 --- 
a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AhrensDieterExponentialSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AhrensDieterExponentialSampler.kt @@ -4,6 +4,7 @@ import kscience.kmath.chains.Chain import kscience.kmath.prob.RandomGenerator import kscience.kmath.prob.Sampler import kscience.kmath.prob.chain +import kscience.kmath.prob.internal.InternalUtils import kotlin.math.ln import kotlin.math.pow diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AliasMethodDiscreteSampler.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AliasMethodDiscreteSampler.kt index c5eed2990..f04b5bcb0 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AliasMethodDiscreteSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/AliasMethodDiscreteSampler.kt @@ -4,6 +4,7 @@ import kscience.kmath.chains.Chain import kscience.kmath.prob.RandomGenerator import kscience.kmath.prob.Sampler import kscience.kmath.prob.chain +import kscience.kmath.prob.internal.InternalUtils import kotlin.math.ceil import kotlin.math.max import kotlin.math.min diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/GaussianSampler.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/GaussianSampler.kt index 1a0ccac90..1a5e4cfdd 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/GaussianSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/GaussianSampler.kt @@ -10,10 +10,13 @@ import kscience.kmath.prob.Sampler * * Based on Commons RNG implementation. * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html + * + * @property mean the mean of the distribution. + * @property standardDeviation the standard deviation of the distribution. 
*/ public class GaussianSampler private constructor( - private val mean: Double, - private val standardDeviation: Double, + public val mean: Double, + public val standardDeviation: Double, private val normalized: NormalizedGaussianSampler ) : Sampler { public override fun sample(generator: RandomGenerator): Chain = normalized diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalGamma.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalGamma.kt index 16a5c96e0..d970d1447 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalGamma.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalGamma.kt @@ -1,43 +1,2 @@ package kscience.kmath.prob.samplers -import kotlin.math.PI -import kotlin.math.ln - -internal object InternalGamma { - private const val LANCZOS_G = 607.0 / 128.0 - - private val LANCZOS_COEFFICIENTS = doubleArrayOf( - 0.99999999999999709182, - 57.156235665862923517, - -59.597960355475491248, - 14.136097974741747174, - -0.49191381609762019978, - .33994649984811888699e-4, - .46523628927048575665e-4, - -.98374475304879564677e-4, - .15808870322491248884e-3, - -.21026444172410488319e-3, - .21743961811521264320e-3, - -.16431810653676389022e-3, - .84418223983852743293e-4, - -.26190838401581408670e-4, - .36899182659531622704e-5 - ) - - private val HALF_LOG_2_PI: Double = 0.5 * ln(2.0 * PI) - - fun logGamma(x: Double): Double { - // Stripped-down version of the same method defined in "Commons Math": - // Unused "if" branches (for when x < 8) have been removed here since - // this method is only used (by class "InternalUtils") in order to - // compute log(n!) for x > 20. 
- val sum = lanczos(x) - val tmp = x + LANCZOS_G + 0.5 - return (x + 0.5) * ln(tmp) - tmp + HALF_LOG_2_PI + ln(sum / x) - } - - private fun lanczos(x: Double): Double { - val sum = (LANCZOS_COEFFICIENTS.size - 1 downTo 1).sumByDouble { LANCZOS_COEFFICIENTS[it] / (x + it) } - return sum + LANCZOS_COEFFICIENTS[0] - } -} diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalUtils.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalUtils.kt index 08a321b75..d970d1447 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalUtils.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/InternalUtils.kt @@ -1,78 +1,2 @@ package kscience.kmath.prob.samplers -import kotlin.math.ln -import kotlin.math.min - -internal object InternalUtils { - private val FACTORIALS = longArrayOf( - 1L, 1L, 2L, - 6L, 24L, 120L, - 720L, 5040L, 40320L, - 362880L, 3628800L, 39916800L, - 479001600L, 6227020800L, 87178291200L, - 1307674368000L, 20922789888000L, 355687428096000L, - 6402373705728000L, 121645100408832000L, 2432902008176640000L - ) - - private const val BEGIN_LOG_FACTORIALS = 2 - - fun factorial(n: Int): Long = FACTORIALS[n] - - fun validateProbabilities(probabilities: DoubleArray?): Double { - require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." } - var sumProb = 0.0 - - probabilities.forEach { prob -> - validateProbability(prob) - sumProb += prob - } - - require(!(sumProb.isInfinite() || sumProb <= 0)) { "Invalid sum of probabilities: $sumProb" } - return sumProb - } - - private fun validateProbability(probability: Double): Unit = - require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { "Invalid probability: $probability" } - - class FactorialLog private constructor( - numValues: Int, - cache: DoubleArray? 
- ) { - private val logFactorials: DoubleArray = DoubleArray(numValues) - - init { - val endCopy: Int - - if (cache != null && cache.size > BEGIN_LOG_FACTORIALS) { - // Copy available values. - endCopy = min(cache.size, numValues) - cache.copyInto( - logFactorials, - BEGIN_LOG_FACTORIALS, - BEGIN_LOG_FACTORIALS, endCopy - ) - } - // All values to be computed - else endCopy = BEGIN_LOG_FACTORIALS - - // Compute remaining values. - (endCopy until numValues).forEach { i -> - if (i < FACTORIALS.size) - logFactorials[i] = ln(FACTORIALS[i].toDouble()) - else - logFactorials[i] = logFactorials[i - 1] + ln(i.toDouble()) - } - } - - fun value(n: Int): Double { - if (n < logFactorials.size) - return logFactorials[n] - - return if (n < FACTORIALS.size) ln(FACTORIALS[n].toDouble()) else InternalGamma.logGamma(n + 1.0) - } - - companion object { - fun create(): FactorialLog = FactorialLog(0, null) - } - } -} diff --git a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/LargeMeanPoissonSampler.kt b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/LargeMeanPoissonSampler.kt index dba2550cb..8a54fabaa 100644 --- a/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/LargeMeanPoissonSampler.kt +++ b/kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/samplers/LargeMeanPoissonSampler.kt @@ -5,6 +5,7 @@ import kscience.kmath.chains.ConstantChain import kscience.kmath.prob.RandomGenerator import kscience.kmath.prob.Sampler import kscience.kmath.prob.chain +import kscience.kmath.prob.internal.InternalUtils import kscience.kmath.prob.next import kotlin.math.* @@ -111,16 +112,13 @@ public class LargeMeanPoissonSampler private constructor(public val mean: Double } private fun getFactorialLog(n: Int): Double = factorialLog.value(n) - - override fun toString(): String = "Large Mean Poisson deviate" + public override fun toString(): String = "Large Mean Poisson deviate" public companion object { private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE 
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create() - private val NO_SMALL_MEAN_POISSON_SAMPLER: Sampler = object : Sampler { - override fun sample(generator: RandomGenerator): Chain = ConstantChain(0) - } + private val NO_SMALL_MEAN_POISSON_SAMPLER: Sampler = Sampler { ConstantChain(0) } public fun of(mean: Double): LargeMeanPoissonSampler { require(mean >= 1) { "mean is not >= 1: $mean" } diff --git a/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/RandomSourceGenerator.kt b/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/RandomSourceGenerator.kt index 18be6f019..9cdde6b4b 100644 --- a/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/RandomSourceGenerator.kt +++ b/kmath-prob/src/jvmMain/kotlin/kscience/kmath/prob/RandomSourceGenerator.kt @@ -26,10 +26,7 @@ public class RandomSourceGenerator(public val source: RandomSource, seed: Long?) public inline class RandomGeneratorProvider(public val generator: RandomGenerator) : UniformRandomProvider { public override fun nextBoolean(): Boolean = generator.nextBoolean() public override fun nextFloat(): Float = generator.nextDouble().toFloat() - - public override fun nextBytes(bytes: ByteArray) { - generator.fillBytes(bytes) - } + public override fun nextBytes(bytes: ByteArray): Unit = generator.fillBytes(bytes) public override fun nextBytes(bytes: ByteArray, start: Int, len: Int) { generator.fillBytes(bytes, start, start + len)