Implement Commons RNG-like samplers in kmath-prob module for Multiplatform #164

Merged
CommanderTvis merged 44 commits from feature/mp-samplers into dev 2021-03-31 09:25:44 +03:00
31 changed files with 175 additions and 158 deletions
Showing only changes of commit bc59f8b287 - Show all commits

View File

@ -1,2 +0,0 @@
// Apply the "scientifik.mpp" Gradle plugin to this module.
plugins { id("scientifik.mpp") }
// Expose the :kmath-coroutines project as an `api` dependency of the commonMain source set.
kotlin.sourceSets { commonMain.get().dependencies { api(project(":kmath-coroutines")) } }

View File

@ -1,7 +0,0 @@
package scientifik.commons.rng.sampling
import scientifik.commons.rng.UniformRandomProvider
/**
 * Declares a single factory-style operation: given a [UniformRandomProvider], return a value of type [R].
 *
 * NOTE(review): the name suggests the returned object shares state with the receiver; no implementation is
 * visible from here, so confirm against the implementing samplers.
 *
 * @param R the type returned by [withUniformRandomProvider].
 */
interface SharedStateSampler<R> {
    /**
     * Returns an [R] for the supplied generator.
     *
     * @param rng the generator handed to the returned object.
     * @return a value of type [R].
     */
    fun withUniformRandomProvider(rng: UniformRandomProvider): R
}

View File

@ -1,5 +0,0 @@
package scientifik.commons.rng.sampling.distribution
/**
 * A source of [Double] samples.
 */
interface ContinuousSampler {
    /**
     * Draws one sample.
     *
     * @return the sampled value.
     */
    fun sample(): Double
}

View File

@ -1,5 +0,0 @@
package scientifik.commons.rng.sampling.distribution
/**
 * A source of [Int] samples.
 */
interface DiscreteSampler {
    /**
     * Draws one sample.
     *
     * @return the sampled value.
     */
    fun sample(): Int
}

View File

@ -1,3 +0,0 @@
package scientifik.commons.rng.sampling.distribution
/**
 * Marker interface: a [ContinuousSampler] that declares no members of its own.
 *
 * NOTE(review): by its name, implementations presumably sample a Gaussian with zero mean and unit standard
 * deviation — confirm against the implementing classes.
 */
interface NormalizedGaussianSampler : ContinuousSampler

View File

@ -1,7 +0,0 @@
package scientifik.commons.rng.sampling.distribution
import scientifik.commons.rng.sampling.SharedStateSampler
/**
 * Composite interface joining [ContinuousSampler] with [SharedStateSampler]; it adds no members of its own.
 *
 * The [SharedStateSampler] type argument is nullable, so `withUniformRandomProvider` is typed to return
 * `SharedStateContinuousSampler?`.
 */
interface SharedStateContinuousSampler :
    ContinuousSampler,
    SharedStateSampler<SharedStateContinuousSampler?>

View File

@ -3,17 +3,10 @@ plugins {
}
kotlin.sourceSets {
commonMain {
dependencies {
api(project(":kmath-coroutines"))
api(project(":kmath-commons-rng-part"))
}
}
commonMain.get().dependencies { api(project(":kmath-coroutines")) }
jvmMain {
dependencies {
api("org.apache.commons:commons-rng-sampling:1.3")
api("org.apache.commons:commons-rng-simple:1.3")
}
jvmMain.get().dependencies {
api("org.apache.commons:commons-rng-sampling:1.3")
api("org.apache.commons:commons-rng-simple:1.3")
}
}

View File

@ -1,4 +1,4 @@
package scientifik.commons.rng
package scientifik.kmath.commons.rng
interface UniformRandomProvider {
fun nextBytes(bytes: ByteArray)

View File

@ -0,0 +1,7 @@
package scientifik.kmath.commons.rng.sampling
import scientifik.kmath.commons.rng.UniformRandomProvider
/**
 * Declares one operation: produce a value of type [R] from a [UniformRandomProvider].
 *
 * NOTE(review): the name implies the result shares state with the receiver; implementations are not visible
 * here, so verify before relying on that.
 *
 * @param R the type returned by [withUniformRandomProvider].
 */
interface SharedStateSampler<R> {
    /**
     * Returns an [R] for the given generator.
     *
     * @param rng the generator passed to the returned object.
     * @return a value of type [R].
     */
    fun withUniformRandomProvider(rng: UniformRandomProvider): R
}

View File

@ -1,6 +1,6 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
import kotlin.math.ln
import kotlin.math.pow
@ -76,7 +76,11 @@ class AhrensDieterExponentialSampler : SamplerBase,
fun of(
rng: UniformRandomProvider,
mean: Double
): SharedStateContinuousSampler = AhrensDieterExponentialSampler(rng, mean)
): SharedStateContinuousSampler =
AhrensDieterExponentialSampler(
rng,
mean
)
init {
/**
@ -87,7 +91,9 @@ class AhrensDieterExponentialSampler : SamplerBase,
var qi = 0.0
EXPONENTIAL_SA_QI.indices.forEach { i ->
qi += ln2.pow(i + 1.0) / InternalUtils.factorial(i + 1)
qi += ln2.pow(i + 1.0) / InternalUtils.factorial(
i + 1
)
EXPONENTIAL_SA_QI[i] = qi
}
}

View File

@ -1,6 +1,6 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
import kotlin.math.*
class BoxMullerNormalizedGaussianSampler(

View File

@ -0,0 +1,5 @@
package scientifik.kmath.commons.rng.sampling.distribution
/**
 * A sampler whose values are of type [Double].
 */
interface ContinuousSampler {
    /**
     * Produces the next sample.
     *
     * @return the sampled value.
     */
    fun sample(): Double
}

View File

@ -0,0 +1,5 @@
package scientifik.kmath.commons.rng.sampling.distribution
/**
 * A sampler whose values are of type [Int].
 */
interface DiscreteSampler {
    /**
     * Produces the next sample.
     *
     * @return the sampled value.
     */
    fun sample(): Int
}

View File

@ -1,8 +1,9 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
class GaussianSampler : SharedStateContinuousSampler {
class GaussianSampler :
SharedStateContinuousSampler {
private val mean: Double
private val standardDeviation: Double
private val normalized: NormalizedGaussianSampler
@ -24,7 +25,11 @@ class GaussianSampler : SharedStateContinuousSampler {
) {
mean = source.mean
standardDeviation = source.standardDeviation
normalized = InternalUtils.newNormalizedGaussianSampler(source.normalized, rng)
normalized =
InternalUtils.newNormalizedGaussianSampler(
source.normalized,
rng
)
}
override fun sample(): Double = standardDeviation * normalized.sample() + mean
@ -40,6 +45,11 @@ class GaussianSampler : SharedStateContinuousSampler {
normalized: NormalizedGaussianSampler,
mean: Double,
standardDeviation: Double
): SharedStateContinuousSampler = GaussianSampler(normalized, mean, standardDeviation)
): SharedStateContinuousSampler =
GaussianSampler(
normalized,
mean,
standardDeviation
)
}
}

View File

@ -1,4 +1,4 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import kotlin.math.PI
import kotlin.math.ln

View File

@ -1,7 +1,7 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.commons.rng.sampling.SharedStateSampler
import scientifik.kmath.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.sampling.SharedStateSampler
import kotlin.math.ln
import kotlin.math.min
@ -63,30 +63,45 @@ internal object InternalUtils {
if (cache != null && cache.size > BEGIN_LOG_FACTORIALS) {
// Copy available values.
endCopy = min(cache.size, numValues)
cache.copyInto(logFactorials, BEGIN_LOG_FACTORIALS, BEGIN_LOG_FACTORIALS, endCopy)
cache.copyInto(logFactorials,
BEGIN_LOG_FACTORIALS,
BEGIN_LOG_FACTORIALS, endCopy)
}
// All values to be computed
else
endCopy = BEGIN_LOG_FACTORIALS
endCopy =
BEGIN_LOG_FACTORIALS
// Compute remaining values.
(endCopy until numValues).forEach { i ->
if (i < FACTORIALS.size) logFactorials[i] = ln(FACTORIALS[i].toDouble()) else logFactorials[i] =
if (i < FACTORIALS.size) logFactorials[i] = ln(
FACTORIALS[i].toDouble()) else logFactorials[i] =
logFactorials[i - 1] + ln(i.toDouble())
}
}
fun withCache(cacheSize: Int): FactorialLog = FactorialLog(cacheSize, logFactorials)
fun withCache(cacheSize: Int): FactorialLog =
FactorialLog(
cacheSize,
logFactorials
)
fun value(n: Int): Double {
if (n < logFactorials.size)
return logFactorials[n]
return if (n < FACTORIALS.size) ln(FACTORIALS[n].toDouble()) else InternalGamma.logGamma(n + 1.0)
return if (n < FACTORIALS.size) ln(
FACTORIALS[n].toDouble()) else InternalGamma.logGamma(
n + 1.0
)
}
companion object {
fun create(): FactorialLog = FactorialLog(0, null)
fun create(): FactorialLog =
FactorialLog(
0,
null
)
}
}
}

View File

@ -1,6 +1,6 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
import kotlin.math.exp
class KempSmallMeanPoissonSampler private constructor(
@ -48,7 +48,11 @@ class KempSmallMeanPoissonSampler private constructor(
val p0: Double = exp(-mean)
// Probability must be positive. As mean increases then p(0) decreases.
if (p0 > 0) return KempSmallMeanPoissonSampler(rng, p0, mean)
if (p0 > 0) return KempSmallMeanPoissonSampler(
rng,
p0,
mean
)
throw IllegalArgumentException("No probability for mean: $mean")
}
}

View File

@ -1,9 +1,10 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
import kotlin.math.*
class LargeMeanPoissonSampler : SharedStateDiscreteSampler {
class LargeMeanPoissonSampler :
SharedStateDiscreteSampler {
private val rng: UniformRandomProvider
private val exponential: SharedStateContinuousSampler
private val gaussian: SharedStateContinuousSampler
@ -27,8 +28,13 @@ class LargeMeanPoissonSampler : SharedStateDiscreteSampler {
// The algorithm is not valid if Math.floor(mean) is not an integer.
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
this.rng = rng
gaussian = ZigguratNormalizedGaussianSampler(rng)
exponential = AhrensDieterExponentialSampler.of(rng, 1.0)
gaussian =
ZigguratNormalizedGaussianSampler(rng)
exponential =
AhrensDieterExponentialSampler.of(
rng,
1.0
)
// Plain constructor uses the uncached function.
factorialLog = NO_CACHE_FACTORIAL_LOG!!
// Cache values used in the algorithm
@ -49,7 +55,10 @@ class LargeMeanPoissonSampler : SharedStateDiscreteSampler {
val lambdaFractional = mean - lambda
smallMeanPoissonSampler =
if (lambdaFractional < Double.MIN_VALUE) NO_SMALL_MEAN_POISSON_SAMPLER else // Not used.
KempSmallMeanPoissonSampler.of(rng, lambdaFractional)
KempSmallMeanPoissonSampler.of(
rng,
lambdaFractional
)
}
internal constructor(
@ -59,8 +68,13 @@ class LargeMeanPoissonSampler : SharedStateDiscreteSampler {
) {
require(!(lambdaFractional < 0 || lambdaFractional >= 1)) { "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): $lambdaFractional" }
this.rng = rng
gaussian = ZigguratNormalizedGaussianSampler(rng)
exponential = AhrensDieterExponentialSampler.of(rng, 1.0)
gaussian =
ZigguratNormalizedGaussianSampler(rng)
exponential =
AhrensDieterExponentialSampler.of(
rng,
1.0
)
// Plain constructor uses the uncached function.
factorialLog = NO_CACHE_FACTORIAL_LOG!!
// Use the state to initialise the algorithm
@ -79,7 +93,10 @@ class LargeMeanPoissonSampler : SharedStateDiscreteSampler {
if (lambdaFractional < Double.MIN_VALUE)
NO_SMALL_MEAN_POISSON_SAMPLER
else // Not used.
KempSmallMeanPoissonSampler.of(rng, lambdaFractional)
KempSmallMeanPoissonSampler.of(
rng,
lambdaFractional
)
}
/**
@ -222,7 +239,8 @@ class LargeMeanPoissonSampler : SharedStateDiscreteSampler {
fun of(
rng: UniformRandomProvider,
mean: Double
): SharedStateDiscreteSampler = LargeMeanPoissonSampler(rng, mean)
): SharedStateDiscreteSampler =
LargeMeanPoissonSampler(rng, mean)
init {
// Create without a cache.

View File

@ -1,6 +1,6 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
import kotlin.math.ln
import kotlin.math.sqrt

View File

@ -0,0 +1,4 @@
package scientifik.kmath.commons.rng.sampling.distribution
/**
 * Marker interface extending [ContinuousSampler] without adding members.
 *
 * NOTE(review): the name indicates samples from a Gaussian with zero mean and unit standard deviation —
 * confirm against the implementing classes.
 */
interface NormalizedGaussianSampler : ContinuousSampler

View File

@ -1,11 +1,12 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
class PoissonSampler(
rng: UniformRandomProvider,
mean: Double
) : SamplerBase(null), SharedStateDiscreteSampler {
) : SamplerBase(null),
SharedStateDiscreteSampler {
private val poissonSamplerDelegate: SharedStateDiscreteSampler
override fun sample(): Int = poissonSamplerDelegate.sample()
@ -22,11 +23,18 @@ class PoissonSampler(
rng: UniformRandomProvider,
mean: Double
): SharedStateDiscreteSampler =// Each sampler should check the input arguments.
if (mean < PIVOT) SmallMeanPoissonSampler.of(rng, mean) else LargeMeanPoissonSampler.of(rng, mean)
if (mean < PIVOT) SmallMeanPoissonSampler.of(
rng,
mean
) else LargeMeanPoissonSampler.of(
rng,
mean
)
}
init {
// Delegate all work to specialised samplers.
poissonSamplerDelegate = of(rng, mean)
poissonSamplerDelegate =
of(rng, mean)
}
}

View File

@ -1,6 +1,6 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
@Deprecated("Since version 1.1. Class intended for internal use only.")
open class SamplerBase protected constructor(private val rng: UniformRandomProvider?) {

View File

@ -0,0 +1,7 @@
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.kmath.commons.rng.sampling.SharedStateSampler
/**
 * Composite of [ContinuousSampler] and [SharedStateSampler]; declares no additional members.
 *
 * Because the [SharedStateSampler] type argument is nullable, `withUniformRandomProvider` is typed to return
 * `SharedStateContinuousSampler?`.
 */
interface SharedStateContinuousSampler :
    ContinuousSampler,
    SharedStateSampler<SharedStateContinuousSampler?>

View File

@ -1,6 +1,6 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.sampling.SharedStateSampler
import scientifik.kmath.commons.rng.sampling.SharedStateSampler
interface SharedStateDiscreteSampler : DiscreteSampler,
SharedStateSampler<SharedStateDiscreteSampler?> { // Composite interface

View File

@ -1,10 +1,11 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
import kotlin.math.ceil
import kotlin.math.exp
class SmallMeanPoissonSampler : SharedStateDiscreteSampler {
class SmallMeanPoissonSampler :
SharedStateDiscreteSampler {
private val p0: Double
private val limit: Int
private val rng: UniformRandomProvider
@ -53,6 +54,7 @@ class SmallMeanPoissonSampler : SharedStateDiscreteSampler {
fun of(
rng: UniformRandomProvider,
mean: Double
): SharedStateDiscreteSampler = SmallMeanPoissonSampler(rng, mean)
): SharedStateDiscreteSampler =
SmallMeanPoissonSampler(rng, mean)
}
}

View File

@ -1,6 +1,6 @@
package scientifik.commons.rng.sampling.distribution
package scientifik.kmath.commons.rng.sampling.distribution
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
import kotlin.math.*
class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider) :
@ -26,9 +26,13 @@ class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider)
init {
// Filling the tables.
var d = R
var d =
R
var t = d
var fd = gauss(d)
var fd =
gauss(
d
)
val q = V / fd
K[0] = (d / q * MAX).toLong()
K[1] = 0
@ -39,7 +43,10 @@ class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider)
(LAST - 1 downTo 1).forEach { i ->
d = sqrt(-2 * ln(V / d + fd))
fd = gauss(d)
fd =
gauss(
d
)
K[i + 1] = (d / t * MAX).toLong()
t = d
F[i] = fd
@ -80,7 +87,10 @@ class ZigguratNormalizedGaussianSampler(private val rng: UniformRandomProvider)
// else
// Try again.
// This branch is called about 0.012362 times per sample.
if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x)) x else sample()
if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(
x
)
) x else sample()
}
}

View File

@ -1,53 +0,0 @@
package scientifik.kmath.prob
import scientifik.commons.rng.UniformRandomProvider
import scientifik.commons.rng.sampling.distribution.*
import scientifik.kmath.chains.BlockingIntChain
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.chains.Chain
/**
 * Base class for [Distribution]s of [Double] values whose samples come from a [ContinuousSampler].
 *
 * Subclasses provide the sampler through [buildCMSampler]; [sample] wraps it in a [BlockingRealChain].
 */
abstract class ContinuousSamplerDistribution : Distribution<Double> {
    /**
     * Creates the sampler that draws values using [generator].
     */
    protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler

    /**
     * Returns a chain whose values are drawn from a sampler built for [generator].
     */
    override fun sample(generator: RandomGenerator): BlockingRealChain {
        return ContinuousSamplerChain(generator)
    }

    /**
     * Chain that takes every value from a sampler built once, at construction, for [generator].
     */
    private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
        private val sampler = buildCMSampler(generator)

        override fun nextDouble(): Double {
            return sampler.sample()
        }

        // A fork is a fresh chain (and therefore a fresh sampler) over a forked generator.
        override fun fork(): Chain<Double> {
            return ContinuousSamplerChain(generator.fork())
        }
    }
}
/**
 * Base class for [Distribution]s of [Int] values whose samples come from a [DiscreteSampler].
 *
 * Subclasses provide the sampler through [buildSampler]; [sample] wraps it in a [BlockingIntChain].
 */
abstract class DiscreteSamplerDistribution : Distribution<Int> {
    /**
     * Chain that takes every value from a sampler built once, at construction, for [generator].
     *
     * Previously named `ContinuousSamplerChain` (copied from the continuous variant); renamed because it
     * produces [Int] values from a [DiscreteSampler]. The class is private, so the rename is not visible
     * to callers.
     */
    private inner class DiscreteSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
        private val sampler = buildSampler(generator)

        override fun nextInt(): Int = sampler.sample()

        // A fork is a fresh chain (and therefore a fresh sampler) over a forked generator.
        override fun fork(): Chain<Int> = DiscreteSamplerChain(generator.fork())
    }

    /**
     * Creates the sampler that draws values using [generator].
     */
    protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler

    /**
     * Returns a chain whose values are drawn from a sampler built for [generator].
     */
    override fun sample(generator: RandomGenerator): BlockingIntChain = DiscreteSamplerChain(generator)
}
/**
 * Identifies which normalized Gaussian sampler implementation to use.
 */
enum class NormalSamplerMethod {
    /** Selects `BoxMullerNormalizedGaussianSampler`. */
    BoxMuller,

    /** Selects `MarsagliaNormalizedGaussianSampler`. */
    Marsaglia,

    /** Selects `ZigguratNormalizedGaussianSampler`. */
    Ziggurat
}
/**
 * Instantiates the [NormalizedGaussianSampler] implementation that corresponds to [method].
 *
 * @param method which implementation to construct.
 * @param provider the generator passed to the sampler's constructor.
 * @return a newly constructed sampler.
 */
fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomProvider): NormalizedGaussianSampler {
    // No `else` branch: the `when` covers every enum entry, so adding a new entry fails compilation here.
    return when (method) {
        NormalSamplerMethod.BoxMuller -> BoxMullerNormalizedGaussianSampler(provider)
        NormalSamplerMethod.Marsaglia -> MarsagliaNormalizedGaussianSampler(provider)
        NormalSamplerMethod.Ziggurat -> ZigguratNormalizedGaussianSampler(provider)
    }
}

View File

@ -1,9 +1,10 @@
package scientifik.kmath.prob
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
inline class RandomGeneratorProvider(val generator: RandomGenerator) : UniformRandomProvider {
inline class RandomGeneratorProvider(val generator: RandomGenerator) :
UniformRandomProvider {
override fun nextBoolean(): Boolean = generator.nextBoolean()
override fun nextFloat(): Float = generator.nextDouble().toFloat()

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob
import scientifik.commons.rng.sampling.distribution.ContinuousSampler
import scientifik.commons.rng.sampling.distribution.DiscreteSampler
import scientifik.commons.rng.sampling.distribution.GaussianSampler
import scientifik.commons.rng.sampling.distribution.PoissonSampler
import scientifik.kmath.commons.rng.sampling.distribution.ContinuousSampler
import scientifik.kmath.commons.rng.sampling.distribution.DiscreteSampler
import scientifik.kmath.commons.rng.sampling.distribution.GaussianSampler
import scientifik.kmath.commons.rng.sampling.distribution.PoissonSampler
import kotlin.math.PI
import kotlin.math.exp
import kotlin.math.pow
@ -33,7 +33,11 @@ fun Distribution.Companion.normal(
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
val provider = generator.asUniformRandomProvider()
val normalizedSampler = normalSampler(method, provider)
return GaussianSampler(normalizedSampler, mean, sigma)
return GaussianSampler(
normalizedSampler,
mean,
sigma
)
}
override fun probability(arg: Double): Double {

View File

@ -1,7 +1,7 @@
package scientifik.kmath.prob
import org.apache.commons.rng.simple.RandomSource
import scientifik.commons.rng.UniformRandomProvider
import scientifik.kmath.commons.rng.UniformRandomProvider
class RandomSourceGenerator(val source: RandomSource, seed: Long?) :
RandomGenerator {

View File

@ -29,13 +29,13 @@ pluginManagement {
}
rootProject.name = "kmath"
include(
":kmath-memory",
":kmath-core",
":kmath-functions",
// ":kmath-io",
":kmath-coroutines",
"kmath-commons-rng-part",
":kmath-histograms",
":kmath-commons",
":kmath-viktor",