Add 2 more samplers, replace SimpleChain with generator.chain

This commit is contained in:
Commander Tvis 2020-06-12 01:13:15 +07:00
parent 8cdc596549
commit 46f6d57fd9
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
6 changed files with 396 additions and 12 deletions

View File

@ -0,0 +1,112 @@
package scientifik.kmath.prob.samplers
import scientifik.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain
import scientifik.kmath.prob.next
import kotlin.math.*
class AhrensDieterMarsagliaTsangGammaSampler private constructor(
alpha: Double,
theta: Double
) : Sampler<Double> {
private val delegate: BaseGammaSampler =
if (alpha < 1) AhrensDieterGammaSampler(alpha, theta) else MarsagliaTsangGammaSampler(
alpha,
theta
)
private abstract class BaseGammaSampler internal constructor(
protected val alpha: Double,
protected val theta: Double
) : Sampler<Double> {
init {
require(alpha > 0) { "alpha is not strictly positive: $alpha" }
require(theta > 0) { "theta is not strictly positive: $theta" }
}
override fun toString(): String = "Ahrens-Dieter-Marsaglia-Tsang Gamma deviate"
}
private class AhrensDieterGammaSampler internal constructor(alpha: Double, theta: Double) :
BaseGammaSampler(alpha, theta) {
private val oneOverAlpha: Double = 1.0 / alpha
private val bGSOptim: Double = 1.0 + alpha / E
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
var x: Double
// [1]: p. 228, Algorithm GS.
while (true) {
// Step 1:
val u = generator.nextDouble()
val p = bGSOptim * u
if (p <= 1) {
// Step 2:
x = p.pow(oneOverAlpha)
val u2 = generator.nextDouble()
if (u2 > exp(-x)) // Reject.
continue
break
}
// Step 3:
x = -ln((bGSOptim - p) * oneOverAlpha)
val u2: Double = generator.nextDouble()
if (u2 <= x.pow(alpha - 1.0)) break
// Reject and continue.
}
x * theta
}
}
private class MarsagliaTsangGammaSampler internal constructor(alpha: Double, theta: Double) :
BaseGammaSampler(alpha, theta) {
private val dOptim: Double
private val cOptim: Double
private val gaussian: NormalizedGaussianSampler
init {
gaussian = ZigguratNormalizedGaussianSampler.of()
dOptim = alpha - ONE_THIRD
cOptim = ONE_THIRD / sqrt(dOptim)
}
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
var v: Double
while (true) {
val x = gaussian.next(generator)
val oPcTx = 1 + cOptim * x
v = oPcTx * oPcTx * oPcTx
if (v <= 0) continue
val x2 = x * x
val u = generator.nextDouble()
// Squeeze.
if (u < 1 - 0.0331 * x2 * x2) break
if (ln(u) < 0.5 * x2 + dOptim * (1 - v + ln(v))) break
}
theta * dOptim * v
}
companion object {
private const val ONE_THIRD = 1.0 / 3.0
}
}
override fun sample(generator: RandomGenerator): Chain<Double> = delegate.sample(generator)
override fun toString(): String = delegate.toString()
companion object {
fun of(
alpha: Double,
theta: Double
): Sampler<Double> = AhrensDieterMarsagliaTsangGammaSampler(alpha, theta)
}
}

View File

@ -0,0 +1,260 @@
package scientifik.kmath.prob.samplers
import scientifik.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain
import kotlin.jvm.JvmOverloads
import kotlin.math.ceil
import kotlin.math.max
import kotlin.math.min
open class AliasMethodDiscreteSampler private constructor(
// Deliberate direct storage of input arrays
protected val probability: LongArray,
protected val alias: IntArray
) : Sampler<Int> {
private class SmallTableAliasMethodDiscreteSampler internal constructor(
probability: LongArray,
alias: IntArray
) : AliasMethodDiscreteSampler(probability, alias) {
// Assume the table size is a power of 2 and create the mask
private val mask: Int = alias.size - 1
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
val bits = generator.nextInt()
// Isolate lower bits
val j = bits and mask
// Optimisation for zero-padded input tables
if (j >= probability.size)
// No probability must use the alias
return@chain alias[j]
// Create a uniform random deviate as a long.
// This replicates functionality from the o.a.c.rng.core.utils.NumberFactory.makeLong
val longBits = generator.nextInt().toLong() shl 32 or (bits.toLong() and hex_ffffffff)
// Choose between the two. Use a 53-bit long for the probability.
if (longBits ushr 11 < probability[j]) j else alias[j]
}
private companion object {
private const val hex_ffffffff = 4294967295L
}
}
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
// This implements the algorithm as per Vose (1991):
// v = uniform() in [0, 1)
// j = uniform(n) in [0, n)
// if v < prob[j] then
// return j
// else
// return alias[j]
val j = generator.nextInt(alias.size)
// Optimisation for zero-padded input tables
// No probability must use the alias
if (j >= probability.size) return@chain alias[j]
// Note: We could check the probability before computing a deviate.
// p(j) == 0 => alias[j]
// p(j) == 1 => j
// However it is assumed these edge cases are rare:
//
// The probability table will be 1 for approximately 1/n samples, i.e. only the
// last unpaired probability. This is only worth checking for when the table size (n)
// is small. But in that case the user should zero-pad the table for performance.
//
// The probability table will be 0 when an input probability was zero. We
// will assume this is also rare if modelling a discrete distribution where
// all samples are possible. The edge case for zero-padded tables is handled above.
// Choose between the two. Use a 53-bit long for the probability.
if (generator.nextLong() ushr 11 < probability[j]) j else alias[j]
}
override fun toString(): String = "Alias method"
companion object {
private const val DEFAULT_ALPHA = 0
private const val ZERO = 0.0
private const val ONE_AS_NUMERATOR = 1L shl 53
private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble()
private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11
@JvmOverloads
fun of(
probabilities: DoubleArray,
alpha: Int = DEFAULT_ALPHA
): Sampler<Int> {
// The Alias method balances N categories with counts around the mean into N sections,
// each allocated 'mean' observations.
//
// Consider 4 categories with counts 6,3,2,1. The histogram can be balanced into a
// 2D array as 4 sections with a height of the mean:
//
// 6
// 6
// 6
// 63 => 6366 --
// 632 6326 |-- mean
// 6321 6321 --
//
// section abcd
//
// Each section is divided as:
// a: 6=1/1
// b: 3=1/1
// c: 2=2/3; 6=1/3 (6 is the alias)
// d: 1=1/3; 6=2/3 (6 is the alias)
//
// The sample is obtained by randomly selecting a section, then choosing which category
// from the pair based on a uniform random deviate.
val sumProb = InternalUtils.validateProbabilities(probabilities)
// Allow zero-padding
val n = computeSize(probabilities.size, alpha)
// Partition into small and large by splitting on the average.
val mean = sumProb / n
// The cardinality of smallSize + largeSize = n.
// So fill the same array from either end.
val indices = IntArray(n)
var large = n
var small = 0
probabilities.indices.forEach { i ->
if (probabilities[i] >= mean) indices[--large] = i else indices[small++] = i
}
small = fillRemainingIndices(probabilities.size, indices, small)
// This may be smaller than the input length if the probabilities were already padded.
val nonZeroIndex = findLastNonZeroIndex(probabilities)
// The probabilities are modified so use a copy.
// Note: probabilities are required only up to last nonZeroIndex
val remainingProbabilities = probabilities.copyOf(nonZeroIndex + 1)
// Allocate the final tables.
// Probability table may be truncated (when zero padded).
// The alias table is full length.
val probability = LongArray(remainingProbabilities.size)
val alias = IntArray(n)
// This loop uses each large in turn to fill the alias table for small probabilities that
// do not reach the requirement to fill an entire section alone (i.e. p < mean).
// Since the sum of the small should be less than the sum of the large it should use up
// all the small first. However floating point round-off can result in
// misclassification of items as small or large. The Vose algorithm handles this using
// a while loop conditioned on the size of both sets and a subsequent loop to use
// unpaired items.
while (large != n && small != 0) {
// Index of the small and the large probabilities.
val j = indices[--small]
val k = indices[large++]
// Optimisation for zero-padded input:
// p(j) = 0 above the last nonZeroIndex
if (j > nonZeroIndex)
// The entire amount for the section is taken from the alias.
remainingProbabilities[k] -= mean
else {
val pj = remainingProbabilities[j]
// Item j is a small probability that is below the mean.
// Compute the weight of the section for item j: pj / mean.
// This is scaled by 2^53 and the ceiling function used to round-up
// the probability to a numerator of a fraction in the range [1,2^53].
// Ceiling ensures non-zero values.
probability[j] = ceil(CONVERT_TO_NUMERATOR * (pj / mean)).toLong()
// The remaining amount for the section is taken from the alias.
// Effectively: probabilities[k] -= (mean - pj)
remainingProbabilities[k] += pj - mean
}
// If not j then the alias is k
alias[j] = k
// Add the remaining probability from large to the appropriate list.
if (remainingProbabilities[k] >= mean) indices[--large] = k else indices[small++] = k
}
// Final loop conditions to consume unpaired items.
// Note: The large set should never be non-empty but this can occur due to round-off
// error so consume from both.
fillTable(probability, alias, indices, 0, small)
fillTable(probability, alias, indices, large, n)
// Change the algorithm for small power of 2 sized tables
return if (isSmallPowerOf2(n))
SmallTableAliasMethodDiscreteSampler(probability, alias)
else
AliasMethodDiscreteSampler(probability, alias)
}
private fun fillRemainingIndices(length: Int, indices: IntArray, small: Int): Int {
var updatedSmall = small
(length until indices.size).forEach { i -> indices[updatedSmall++] = i }
return updatedSmall
}
private fun findLastNonZeroIndex(probabilities: DoubleArray): Int {
// No bounds check is performed when decrementing as the array contains at least one
// value above zero.
var nonZeroIndex = probabilities.size - 1
while (probabilities[nonZeroIndex] == ZERO) nonZeroIndex--
return nonZeroIndex
}
private fun computeSize(length: Int, alpha: Int): Int {
// If No padding
if (alpha < 0) return length
// Use the number of leading zeros function to find the next power of 2,
// i.e. ceil(log2(x))
var pow2 = 32 - numberOfLeadingZeros(length - 1)
// Increase by the alpha. Clip this to limit to a positive integer (2^30)
pow2 = min(30, pow2 + alpha)
// Use max to handle a length above the highest possible power of 2
return max(length, 1 shl pow2)
}
private fun fillTable(
probability: LongArray,
alias: IntArray,
indices: IntArray,
start: Int,
end: Int
) = (start until end).forEach { i ->
val index = indices[i]
probability[index] = ONE_AS_NUMERATOR
alias[index] = index
}
private fun isSmallPowerOf2(n: Int): Boolean = n <= MAX_SMALL_POWER_2_SIZE && n and n - 1 == 0
private fun numberOfLeadingZeros(i: Int): Int {
var mutI = i
if (mutI <= 0) return if (mutI == 0) 32 else 0
var n = 31
if (mutI >= 1 shl 16) {
n -= 16
mutI = mutI ushr 16
}
if (mutI >= 1 shl 8) {
n -= 8
mutI = mutI ushr 8
}
if (mutI >= 1 shl 4) {
n -= 4
mutI = mutI ushr 4
}
if (mutI >= 1 shl 2) {
n -= 2
mutI = mutI ushr 2
}
return n - (mutI ushr 1)
}
}
}

View File

@ -18,6 +18,22 @@ internal object InternalUtils {
fun factorial(n: Int): Long = FACTORIALS[n]
fun validateProbabilities(probabilities: DoubleArray?): Double {
require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." }
var sumProb = 0.0
probabilities.forEach { prob ->
validateProbability(prob)
sumProb += prob
}
require(!(sumProb.isInfinite() || sumProb <= 0)) { "Invalid sum of probabilities: $sumProb" }
return sumProb
}
private fun validateProbability(probability: Double): Unit =
require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { "Invalid probability: $probability" }
class FactorialLog private constructor(
numValues: Int,
cache: DoubleArray?
@ -30,9 +46,11 @@ internal object InternalUtils {
if (cache != null && cache.size > BEGIN_LOG_FACTORIALS) {
// Copy available values.
endCopy = min(cache.size, numValues)
cache.copyInto(logFactorials,
cache.copyInto(
logFactorials,
BEGIN_LOG_FACTORIALS,
BEGIN_LOG_FACTORIALS, endCopy)
BEGIN_LOG_FACTORIALS, endCopy
)
}
// All values to be computed
else endCopy = BEGIN_LOG_FACTORIALS
@ -50,15 +68,11 @@ internal object InternalUtils {
if (n < logFactorials.size)
return logFactorials[n]
return if (n < FACTORIALS.size) ln(
FACTORIALS[n].toDouble()) else InternalGamma.logGamma(
n + 1.0
)
return if (n < FACTORIALS.size) ln(FACTORIALS[n].toDouble()) else InternalGamma.logGamma(n + 1.0)
}
companion object {
fun create(): FactorialLog =
FactorialLog(0, null)
fun create(): FactorialLog = FactorialLog(0, null)
}
}
}

View File

@ -2,9 +2,9 @@ package scientifik.kmath.prob.samplers
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.ConstantChain
import scientifik.kmath.chains.SimpleChain
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain
import scientifik.kmath.prob.next
import kotlin.math.*
@ -35,7 +35,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
else // Not used.
KempSmallMeanPoissonSampler.of(mean - lambda)
override fun sample(generator: RandomGenerator): Chain<Int> = SimpleChain {
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
// This will never be null. It may be a no-op delegate that returns zero.
val y2 = smallMeanPoissonSampler.next(generator)
var x: Double

View File

@ -1,7 +1,6 @@
package scientifik.kmath.prob.samplers
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.SimpleChain
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain

View File

@ -1,7 +1,6 @@
package scientifik.kmath.prob.samplers
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.SimpleChain
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler
import scientifik.kmath.prob.chain