Code refactoring, implement NormalDistribution
This commit is contained in:
parent
26d81bddb5
commit
0c6fff3878
@ -33,10 +33,7 @@ public class DerivativeStructureField(
|
|||||||
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
||||||
|
|
||||||
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
|
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
|
||||||
|
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double = deriv(mapOf(parName to order))
|
||||||
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
|
||||||
return deriv(mapOf(parName to order))
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
|
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
|
||||||
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
|
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||||
@ -75,7 +72,6 @@ public class DerivativeStructureField(
|
|||||||
public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
|
public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
|
||||||
public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
|
public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
|
||||||
public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
||||||
|
|
||||||
public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
||||||
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
|
@ -6,7 +6,7 @@ import kscience.kmath.chains.collect
|
|||||||
import kscience.kmath.structures.Buffer
|
import kscience.kmath.structures.Buffer
|
||||||
import kscience.kmath.structures.BufferFactory
|
import kscience.kmath.structures.BufferFactory
|
||||||
|
|
||||||
public interface Sampler<T : Any> {
|
public fun interface Sampler<T : Any> {
|
||||||
public fun sample(generator: RandomGenerator): Chain<T>
|
public fun sample(generator: RandomGenerator): Chain<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20,11 +20,7 @@ public interface Distribution<T : Any> : Sampler<T> {
|
|||||||
*/
|
*/
|
||||||
public fun probability(arg: T): Double
|
public fun probability(arg: T): Double
|
||||||
|
|
||||||
/**
|
public override fun sample(generator: RandomGenerator): Chain<T>
|
||||||
* Create a chain of samples from this distribution.
|
|
||||||
* The chain is not guaranteed to be stateless, but different sample chains should be independent.
|
|
||||||
*/
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<T>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An empty companion. Distribution factories should be written as its extensions
|
* An empty companion. Distribution factories should be written as its extensions
|
||||||
@ -63,9 +59,7 @@ public fun <T : Any> Sampler<T>.sampleBuffer(
|
|||||||
//clear list from previous run
|
//clear list from previous run
|
||||||
tmp.clear()
|
tmp.clear()
|
||||||
//Fill list
|
//Fill list
|
||||||
repeat(size) {
|
repeat(size) { tmp.add(chain.next()) }
|
||||||
tmp.add(chain.next())
|
|
||||||
}
|
|
||||||
//return new buffer with elements from tmp
|
//return new buffer with elements from tmp
|
||||||
bufferFactory(size) { tmp[it] }
|
bufferFactory(size) { tmp[it] }
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
package kscience.kmath.prob.distributions
|
||||||
|
|
||||||
|
import kscience.kmath.chains.Chain
|
||||||
|
import kscience.kmath.prob.RandomGenerator
|
||||||
|
import kscience.kmath.prob.UnivariateDistribution
|
||||||
|
import kscience.kmath.prob.internal.InternalErf
|
||||||
|
import kscience.kmath.prob.samplers.GaussianSampler
|
||||||
|
import kotlin.math.*
|
||||||
|
|
||||||
|
public inline class NormalDistribution(public val sampler: GaussianSampler) : UnivariateDistribution<Double> {
|
||||||
|
public override fun probability(arg: Double): Double {
|
||||||
|
val x1 = (arg - sampler.mean) / sampler.standardDeviation
|
||||||
|
return exp(-0.5 * x1 * x1 - (ln(sampler.standardDeviation) + 0.5 * ln(2 * PI)))
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun sample(generator: RandomGenerator): Chain<Double> = sampler.sample(generator)
|
||||||
|
|
||||||
|
public override fun cumulative(arg: Double): Double {
|
||||||
|
val dev = arg - sampler.mean
|
||||||
|
|
||||||
|
return when {
|
||||||
|
abs(dev) > 40 * sampler.standardDeviation -> if (dev < 0) 0.0 else 1.0
|
||||||
|
else -> 0.5 * InternalErf.erfc(-dev / (sampler.standardDeviation * SQRT2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private companion object {
|
||||||
|
private val SQRT2 = sqrt(2.0)
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ import kscience.kmath.chains.Chain
|
|||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import kscience.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import kscience.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
|
import kscience.kmath.prob.internal.InternalUtils
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import kscience.kmath.chains.Chain
|
|||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import kscience.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import kscience.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
|
import kscience.kmath.prob.internal.InternalUtils
|
||||||
import kotlin.math.ceil
|
import kotlin.math.ceil
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
@ -10,10 +10,13 @@ import kscience.kmath.prob.Sampler
|
|||||||
*
|
*
|
||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html
|
||||||
|
*
|
||||||
|
* @property mean the mean of the distribution.
|
||||||
|
* @property standardDeviation the variance of the distribution.
|
||||||
*/
|
*/
|
||||||
public class GaussianSampler private constructor(
|
public class GaussianSampler private constructor(
|
||||||
private val mean: Double,
|
public val mean: Double,
|
||||||
private val standardDeviation: Double,
|
public val standardDeviation: Double,
|
||||||
private val normalized: NormalizedGaussianSampler
|
private val normalized: NormalizedGaussianSampler
|
||||||
) : Sampler<Double> {
|
) : Sampler<Double> {
|
||||||
public override fun sample(generator: RandomGenerator): Chain<Double> = normalized
|
public override fun sample(generator: RandomGenerator): Chain<Double> = normalized
|
||||||
|
@ -1,43 +1,2 @@
|
|||||||
package kscience.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import kotlin.math.PI
|
|
||||||
import kotlin.math.ln
|
|
||||||
|
|
||||||
internal object InternalGamma {
|
|
||||||
private const val LANCZOS_G = 607.0 / 128.0
|
|
||||||
|
|
||||||
private val LANCZOS_COEFFICIENTS = doubleArrayOf(
|
|
||||||
0.99999999999999709182,
|
|
||||||
57.156235665862923517,
|
|
||||||
-59.597960355475491248,
|
|
||||||
14.136097974741747174,
|
|
||||||
-0.49191381609762019978,
|
|
||||||
.33994649984811888699e-4,
|
|
||||||
.46523628927048575665e-4,
|
|
||||||
-.98374475304879564677e-4,
|
|
||||||
.15808870322491248884e-3,
|
|
||||||
-.21026444172410488319e-3,
|
|
||||||
.21743961811521264320e-3,
|
|
||||||
-.16431810653676389022e-3,
|
|
||||||
.84418223983852743293e-4,
|
|
||||||
-.26190838401581408670e-4,
|
|
||||||
.36899182659531622704e-5
|
|
||||||
)
|
|
||||||
|
|
||||||
private val HALF_LOG_2_PI: Double = 0.5 * ln(2.0 * PI)
|
|
||||||
|
|
||||||
fun logGamma(x: Double): Double {
|
|
||||||
// Stripped-down version of the same method defined in "Commons Math":
|
|
||||||
// Unused "if" branches (for when x < 8) have been removed here since
|
|
||||||
// this method is only used (by class "InternalUtils") in order to
|
|
||||||
// compute log(n!) for x > 20.
|
|
||||||
val sum = lanczos(x)
|
|
||||||
val tmp = x + LANCZOS_G + 0.5
|
|
||||||
return (x + 0.5) * ln(tmp) - tmp + HALF_LOG_2_PI + ln(sum / x)
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun lanczos(x: Double): Double {
|
|
||||||
val sum = (LANCZOS_COEFFICIENTS.size - 1 downTo 1).sumByDouble { LANCZOS_COEFFICIENTS[it] / (x + it) }
|
|
||||||
return sum + LANCZOS_COEFFICIENTS[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,78 +1,2 @@
|
|||||||
package kscience.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import kotlin.math.ln
|
|
||||||
import kotlin.math.min
|
|
||||||
|
|
||||||
internal object InternalUtils {
|
|
||||||
private val FACTORIALS = longArrayOf(
|
|
||||||
1L, 1L, 2L,
|
|
||||||
6L, 24L, 120L,
|
|
||||||
720L, 5040L, 40320L,
|
|
||||||
362880L, 3628800L, 39916800L,
|
|
||||||
479001600L, 6227020800L, 87178291200L,
|
|
||||||
1307674368000L, 20922789888000L, 355687428096000L,
|
|
||||||
6402373705728000L, 121645100408832000L, 2432902008176640000L
|
|
||||||
)
|
|
||||||
|
|
||||||
private const val BEGIN_LOG_FACTORIALS = 2
|
|
||||||
|
|
||||||
fun factorial(n: Int): Long = FACTORIALS[n]
|
|
||||||
|
|
||||||
fun validateProbabilities(probabilities: DoubleArray?): Double {
|
|
||||||
require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." }
|
|
||||||
var sumProb = 0.0
|
|
||||||
|
|
||||||
probabilities.forEach { prob ->
|
|
||||||
validateProbability(prob)
|
|
||||||
sumProb += prob
|
|
||||||
}
|
|
||||||
|
|
||||||
require(!(sumProb.isInfinite() || sumProb <= 0)) { "Invalid sum of probabilities: $sumProb" }
|
|
||||||
return sumProb
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun validateProbability(probability: Double): Unit =
|
|
||||||
require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { "Invalid probability: $probability" }
|
|
||||||
|
|
||||||
class FactorialLog private constructor(
|
|
||||||
numValues: Int,
|
|
||||||
cache: DoubleArray?
|
|
||||||
) {
|
|
||||||
private val logFactorials: DoubleArray = DoubleArray(numValues)
|
|
||||||
|
|
||||||
init {
|
|
||||||
val endCopy: Int
|
|
||||||
|
|
||||||
if (cache != null && cache.size > BEGIN_LOG_FACTORIALS) {
|
|
||||||
// Copy available values.
|
|
||||||
endCopy = min(cache.size, numValues)
|
|
||||||
cache.copyInto(
|
|
||||||
logFactorials,
|
|
||||||
BEGIN_LOG_FACTORIALS,
|
|
||||||
BEGIN_LOG_FACTORIALS, endCopy
|
|
||||||
)
|
|
||||||
}
|
|
||||||
// All values to be computed
|
|
||||||
else endCopy = BEGIN_LOG_FACTORIALS
|
|
||||||
|
|
||||||
// Compute remaining values.
|
|
||||||
(endCopy until numValues).forEach { i ->
|
|
||||||
if (i < FACTORIALS.size)
|
|
||||||
logFactorials[i] = ln(FACTORIALS[i].toDouble())
|
|
||||||
else
|
|
||||||
logFactorials[i] = logFactorials[i - 1] + ln(i.toDouble())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun value(n: Int): Double {
|
|
||||||
if (n < logFactorials.size)
|
|
||||||
return logFactorials[n]
|
|
||||||
|
|
||||||
return if (n < FACTORIALS.size) ln(FACTORIALS[n].toDouble()) else InternalGamma.logGamma(n + 1.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
fun create(): FactorialLog = FactorialLog(0, null)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -5,6 +5,7 @@ import kscience.kmath.chains.ConstantChain
|
|||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import kscience.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import kscience.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
|
import kscience.kmath.prob.internal.InternalUtils
|
||||||
import kscience.kmath.prob.next
|
import kscience.kmath.prob.next
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
@ -111,16 +112,13 @@ public class LargeMeanPoissonSampler private constructor(public val mean: Double
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun getFactorialLog(n: Int): Double = factorialLog.value(n)
|
private fun getFactorialLog(n: Int): Double = factorialLog.value(n)
|
||||||
|
public override fun toString(): String = "Large Mean Poisson deviate"
|
||||||
override fun toString(): String = "Large Mean Poisson deviate"
|
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
|
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
|
||||||
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
|
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
|
||||||
|
|
||||||
private val NO_SMALL_MEAN_POISSON_SAMPLER: Sampler<Int> = object : Sampler<Int> {
|
private val NO_SMALL_MEAN_POISSON_SAMPLER: Sampler<Int> = Sampler { ConstantChain(0) }
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun of(mean: Double): LargeMeanPoissonSampler {
|
public fun of(mean: Double): LargeMeanPoissonSampler {
|
||||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||||
|
@ -26,10 +26,7 @@ public class RandomSourceGenerator(public val source: RandomSource, seed: Long?)
|
|||||||
public inline class RandomGeneratorProvider(public val generator: RandomGenerator) : UniformRandomProvider {
|
public inline class RandomGeneratorProvider(public val generator: RandomGenerator) : UniformRandomProvider {
|
||||||
public override fun nextBoolean(): Boolean = generator.nextBoolean()
|
public override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||||
public override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
public override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||||
|
public override fun nextBytes(bytes: ByteArray): Unit = generator.fillBytes(bytes)
|
||||||
public override fun nextBytes(bytes: ByteArray) {
|
|
||||||
generator.fillBytes(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
public override fun nextBytes(bytes: ByteArray, start: Int, len: Int) {
|
public override fun nextBytes(bytes: ByteArray, start: Int, len: Int) {
|
||||||
generator.fillBytes(bytes, start, start + len)
|
generator.fillBytes(bytes, start, start + len)
|
||||||
|
Loading…
Reference in New Issue
Block a user