forked from kscience/kmath
Perform 1.4 and explicit API migrations, refactor blocking chain, make tests work
This commit is contained in:
parent
e35a364aa8
commit
53ebec2e01
@ -3,10 +3,7 @@ package kscience.kmath.chains
|
||||
/**
|
||||
* Performance optimized chain for integer values
|
||||
*/
|
||||
public abstract class BlockingIntChain : Chain<Int> {
|
||||
public abstract fun nextInt(): Int
|
||||
|
||||
override suspend fun next(): Int = nextInt()
|
||||
|
||||
public fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
|
||||
public interface BlockingIntChain : Chain<Int> {
|
||||
public override suspend fun next(): Int
|
||||
public suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() }
|
||||
}
|
||||
|
@ -3,10 +3,7 @@ package kscience.kmath.chains
|
||||
/**
|
||||
* Performance optimized chain for real values
|
||||
*/
|
||||
public abstract class BlockingRealChain : Chain<Double> {
|
||||
public abstract fun nextDouble(): Double
|
||||
|
||||
override suspend fun next(): Double = nextDouble()
|
||||
|
||||
public fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
|
||||
public interface BlockingRealChain : Chain<Double> {
|
||||
public override suspend fun next(): Double
|
||||
public suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() }
|
||||
}
|
||||
|
@ -1,8 +0,0 @@
|
||||
package scientifik.kmath.chains
|
||||
|
||||
/**
|
||||
* Performance optimized chain for integer values
|
||||
*/
|
||||
abstract class BlockingIntChain : Chain<Int> {
|
||||
suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() }
|
||||
}
|
@ -1,8 +0,0 @@
|
||||
package scientifik.kmath.chains
|
||||
|
||||
/**
|
||||
* Performance optimized chain for real values
|
||||
*/
|
||||
abstract class BlockingRealChain : Chain<Double> {
|
||||
suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() }
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.prob
|
||||
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.chains.collect
|
||||
import kscience.kmath.structures.Buffer
|
||||
@ -70,7 +71,7 @@ public fun <T : Any> Sampler<T>.sampleBuffer(
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator) = sample(generator).first()
|
||||
public suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator): T = sample(generator).first()
|
||||
|
||||
/**
|
||||
* Generate a bunch of samples from real distributions
|
||||
|
@ -14,7 +14,7 @@ public interface NamedDistribution<T> : Distribution<Map<String, T>>
|
||||
public class FactorizedDistribution<T>(public val distributions: Collection<NamedDistribution<T>>) :
|
||||
NamedDistribution<T> {
|
||||
override fun probability(arg: Map<String, T>): Double =
|
||||
distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
||||
distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) }
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
||||
val chains = distributions.map { it.sample(generator) }
|
||||
|
@ -1,17 +1,22 @@
|
||||
package kscience.kmath.prob
|
||||
|
||||
import kscience.kmath.chains.BlockingIntChain
|
||||
import kscience.kmath.chains.BlockingRealChain
|
||||
import kscience.kmath.chains.Chain
|
||||
|
||||
/**
|
||||
* A possibly stateful chain producing random values.
|
||||
*
|
||||
* @property generator the underlying [RandomGenerator] instance.
|
||||
*/
|
||||
public class RandomChain<out R>(
|
||||
public val generator: RandomGenerator,
|
||||
private val gen: suspend RandomGenerator.() -> R
|
||||
) : Chain<R> {
|
||||
override suspend fun next(): R = generator.gen()
|
||||
|
||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||
}
|
||||
|
||||
public fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
||||
public fun Chain<Double>.blocking(): BlockingRealChain = object : Chain<Double> by this, BlockingRealChain {}
|
||||
public fun Chain<Int>.blocking(): BlockingIntChain = object : Chain<Int> by this, BlockingIntChain {}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.pow
|
||||
|
||||
@ -13,8 +13,8 @@ import kotlin.math.pow
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html
|
||||
*/
|
||||
class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler<Double> {
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
public class AhrensDieterExponentialSampler private constructor(public val mean: Double) : Sampler<Double> {
|
||||
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
// Step 1:
|
||||
var a = 0.0
|
||||
var u = nextDouble()
|
||||
@ -47,7 +47,7 @@ class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sam
|
||||
|
||||
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
|
||||
|
||||
companion object {
|
||||
public companion object {
|
||||
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
|
||||
|
||||
init {
|
||||
@ -64,7 +64,7 @@ class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sam
|
||||
}
|
||||
}
|
||||
|
||||
fun of(mean: Double): AhrensDieterExponentialSampler {
|
||||
public fun of(mean: Double): AhrensDieterExponentialSampler {
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
return AhrensDieterExponentialSampler(mean)
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import scientifik.kmath.prob.next
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kscience.kmath.prob.next
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
@ -17,15 +17,12 @@ import kotlin.math.*
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.html
|
||||
*/
|
||||
class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
||||
public class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
||||
alpha: Double,
|
||||
theta: Double
|
||||
) : Sampler<Double> {
|
||||
private val delegate: BaseGammaSampler =
|
||||
if (alpha < 1) AhrensDieterGammaSampler(alpha, theta) else MarsagliaTsangGammaSampler(
|
||||
alpha,
|
||||
theta
|
||||
)
|
||||
if (alpha < 1) AhrensDieterGammaSampler(alpha, theta) else MarsagliaTsangGammaSampler(alpha, theta)
|
||||
|
||||
private abstract class BaseGammaSampler internal constructor(
|
||||
protected val alpha: Double,
|
||||
@ -39,7 +36,7 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
||||
override fun toString(): String = "Ahrens-Dieter-Marsaglia-Tsang Gamma deviate"
|
||||
}
|
||||
|
||||
private class AhrensDieterGammaSampler internal constructor(alpha: Double, theta: Double) :
|
||||
private class AhrensDieterGammaSampler(alpha: Double, theta: Double) :
|
||||
BaseGammaSampler(alpha, theta) {
|
||||
private val oneOverAlpha: Double = 1.0 / alpha
|
||||
private val bGSOptim: Double = 1.0 + alpha / E
|
||||
@ -75,7 +72,7 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
||||
}
|
||||
}
|
||||
|
||||
private class MarsagliaTsangGammaSampler internal constructor(alpha: Double, theta: Double) :
|
||||
private class MarsagliaTsangGammaSampler(alpha: Double, theta: Double) :
|
||||
BaseGammaSampler(alpha, theta) {
|
||||
private val dOptim: Double
|
||||
private val cOptim: Double
|
||||
@ -110,11 +107,11 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
||||
}
|
||||
}
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = delegate.sample(generator)
|
||||
override fun toString(): String = delegate.toString()
|
||||
public override fun sample(generator: RandomGenerator): Chain<Double> = delegate.sample(generator)
|
||||
public override fun toString(): String = delegate.toString()
|
||||
|
||||
companion object {
|
||||
fun of(
|
||||
public companion object {
|
||||
public fun of(
|
||||
alpha: Double,
|
||||
theta: Double
|
||||
): Sampler<Double> = AhrensDieterMarsagliaTsangGammaSampler(alpha, theta)
|
||||
|
@ -1,10 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kotlin.jvm.JvmOverloads
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kotlin.math.ceil
|
||||
import kotlin.math.max
|
||||
import kotlin.math.min
|
||||
@ -36,7 +35,7 @@ import kotlin.math.min
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.html
|
||||
*/
|
||||
open class AliasMethodDiscreteSampler private constructor(
|
||||
public open class AliasMethodDiscreteSampler private constructor(
|
||||
// Deliberate direct storage of input arrays
|
||||
protected val probability: LongArray,
|
||||
protected val alias: IntArray
|
||||
@ -102,17 +101,16 @@ open class AliasMethodDiscreteSampler private constructor(
|
||||
if (generator.nextLong() ushr 11 < probability[j]) j else alias[j]
|
||||
}
|
||||
|
||||
override fun toString(): String = "Alias method"
|
||||
public override fun toString(): String = "Alias method"
|
||||
|
||||
companion object {
|
||||
public companion object {
|
||||
private const val DEFAULT_ALPHA = 0
|
||||
private const val ZERO = 0.0
|
||||
private const val ONE_AS_NUMERATOR = 1L shl 53
|
||||
private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble()
|
||||
private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11
|
||||
|
||||
@JvmOverloads
|
||||
fun of(
|
||||
public fun of(
|
||||
probabilities: DoubleArray,
|
||||
alpha: Int = DEFAULT_ALPHA
|
||||
): Sampler<Int> {
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
@ -13,10 +13,10 @@ import kotlin.math.*
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html
|
||||
*/
|
||||
class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||
public class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||
private var nextGaussian: Double = Double.NaN
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
val random: Double
|
||||
|
||||
if (nextGaussian.isNaN()) {
|
||||
@ -40,9 +40,9 @@ class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGauss
|
||||
random
|
||||
}
|
||||
|
||||
override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
||||
public override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
||||
|
||||
companion object {
|
||||
fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
|
||||
public companion object {
|
||||
public fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.map
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.chains.map
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
|
||||
/**
|
||||
* Sampling from a Gaussian distribution with given mean and standard deviation.
|
||||
@ -11,24 +11,24 @@ import scientifik.kmath.prob.Sampler
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html
|
||||
*/
|
||||
class GaussianSampler private constructor(
|
||||
public class GaussianSampler private constructor(
|
||||
private val mean: Double,
|
||||
private val standardDeviation: Double,
|
||||
private val normalized: NormalizedGaussianSampler
|
||||
) : Sampler<Double> {
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = normalized
|
||||
public override fun sample(generator: RandomGenerator): Chain<Double> = normalized
|
||||
.sample(generator)
|
||||
.map { standardDeviation * it + mean }
|
||||
|
||||
override fun toString(): String = "Gaussian deviate [$normalized]"
|
||||
|
||||
companion object {
|
||||
fun of(
|
||||
public companion object {
|
||||
public fun of(
|
||||
mean: Double,
|
||||
standardDeviation: Double,
|
||||
normalized: NormalizedGaussianSampler
|
||||
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler.of()
|
||||
): GaussianSampler {
|
||||
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||
require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||
|
||||
return GaussianSampler(
|
||||
mean,
|
||||
|
@ -1,4 +1,4 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.ln
|
||||
|
@ -1,4 +1,4 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.min
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kotlin.math.exp
|
||||
|
||||
/**
|
||||
@ -18,11 +18,11 @@ import kotlin.math.exp
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html
|
||||
*/
|
||||
class KempSmallMeanPoissonSampler private constructor(
|
||||
public class KempSmallMeanPoissonSampler private constructor(
|
||||
private val p0: Double,
|
||||
private val mean: Double
|
||||
) : Sampler<Int> {
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||
public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||
// Note on the algorithm:
|
||||
// - X is the unknown sample deviate (the output of the algorithm)
|
||||
// - x is the current value from the distribution
|
||||
@ -48,10 +48,10 @@ class KempSmallMeanPoissonSampler private constructor(
|
||||
x
|
||||
}
|
||||
|
||||
override fun toString(): String = "Kemp Small Mean Poisson deviate"
|
||||
public override fun toString(): String = "Kemp Small Mean Poisson deviate"
|
||||
|
||||
companion object {
|
||||
fun of(mean: Double): KempSmallMeanPoissonSampler {
|
||||
public companion object {
|
||||
public fun of(mean: Double): KempSmallMeanPoissonSampler {
|
||||
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
||||
val p0 = exp(-mean)
|
||||
// Probability must be positive. As mean increases then p(0) decreases.
|
||||
|
@ -1,11 +1,11 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.ConstantChain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import scientifik.kmath.prob.next
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.chains.ConstantChain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kscience.kmath.prob.next
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
@ -19,7 +19,7 @@ import kotlin.math.*
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html
|
||||
*/
|
||||
class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<Int> {
|
||||
public class LargeMeanPoissonSampler private constructor(public val mean: Double) : Sampler<Int> {
|
||||
private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
|
||||
private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler.of()
|
||||
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
|
||||
@ -41,7 +41,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
|
||||
else // Not used.
|
||||
KempSmallMeanPoissonSampler.of(mean - lambda)
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||
public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||
// This will never be null. It may be a no-op delegate that returns zero.
|
||||
val y2 = smallMeanPoissonSampler.next(generator)
|
||||
var x: Double
|
||||
@ -114,7 +114,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
|
||||
|
||||
override fun toString(): String = "Large Mean Poisson deviate"
|
||||
|
||||
companion object {
|
||||
public companion object {
|
||||
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
|
||||
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
|
||||
|
||||
@ -122,7 +122,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
||||
}
|
||||
|
||||
fun of(mean: Double): LargeMeanPoissonSampler {
|
||||
public fun of(mean: Double): LargeMeanPoissonSampler {
|
||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
||||
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kotlin.math.ln
|
||||
import kotlin.math.sqrt
|
||||
|
||||
@ -15,10 +15,10 @@ import kotlin.math.sqrt
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html
|
||||
*/
|
||||
class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||
public class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||
private var nextGaussian = Double.NaN
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||
if (nextGaussian.isNaN()) {
|
||||
val alpha: Double
|
||||
var x: Double
|
||||
@ -53,9 +53,9 @@ class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGauss
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
||||
public override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
||||
|
||||
companion object {
|
||||
fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
|
||||
public companion object {
|
||||
public fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.Sampler
|
||||
|
||||
/**
|
||||
* Marker interface for a sampler that generates values from an N(0,1)
|
||||
* [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution).
|
||||
*/
|
||||
interface NormalizedGaussianSampler : Sampler<Double>
|
||||
public interface NormalizedGaussianSampler : Sampler<Double>
|
||||
|
@ -1,8 +1,8 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
|
||||
/**
|
||||
* Sampler for the Poisson distribution.
|
||||
@ -16,17 +16,15 @@ import scientifik.kmath.prob.Sampler
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html
|
||||
*/
|
||||
class PoissonSampler private constructor(
|
||||
mean: Double
|
||||
) : Sampler<Int> {
|
||||
public class PoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||
private val poissonSamplerDelegate: Sampler<Int> = of(mean)
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
||||
override fun toString(): String = poissonSamplerDelegate.toString()
|
||||
public override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
||||
public override fun toString(): String = poissonSamplerDelegate.toString()
|
||||
|
||||
companion object {
|
||||
public companion object {
|
||||
private const val PIVOT = 40.0
|
||||
|
||||
fun of(mean: Double) =// Each sampler should check the input arguments.
|
||||
public fun of(mean: Double): Sampler<Int> =// Each sampler should check the input arguments.
|
||||
if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kotlin.math.ceil
|
||||
import kotlin.math.exp
|
||||
|
||||
@ -19,7 +19,7 @@ import kotlin.math.exp
|
||||
*
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html
|
||||
*/
|
||||
class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||
public class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||
private val p0: Double = exp(-mean)
|
||||
|
||||
private val limit: Int = (if (p0 > 0)
|
||||
@ -27,7 +27,7 @@ class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||
else
|
||||
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||
public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||
var n = 0
|
||||
var r = 1.0
|
||||
|
||||
@ -39,10 +39,10 @@ class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||
n
|
||||
}
|
||||
|
||||
override fun toString(): String = "Small Mean Poisson deviate"
|
||||
public override fun toString(): String = "Small Mean Poisson deviate"
|
||||
|
||||
companion object {
|
||||
fun of(mean: Double): SmallMeanPoissonSampler {
|
||||
public companion object {
|
||||
public fun of(mean: Double): SmallMeanPoissonSampler {
|
||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||
return SmallMeanPoissonSampler(mean)
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package scientifik.kmath.prob.samplers
|
||||
package kscience.kmath.prob.samplers
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.Sampler
|
||||
import scientifik.kmath.prob.chain
|
||||
import kscience.kmath.chains.Chain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.Sampler
|
||||
import kscience.kmath.prob.chain
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
@ -14,7 +14,7 @@ import kotlin.math.*
|
||||
* Based on Commons RNG implementation.
|
||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html
|
||||
*/
|
||||
class ZigguratNormalizedGaussianSampler private constructor() :
|
||||
public class ZigguratNormalizedGaussianSampler private constructor() :
|
||||
NormalizedGaussianSampler, Sampler<Double> {
|
||||
|
||||
private fun sampleOne(generator: RandomGenerator): Double {
|
||||
@ -23,9 +23,8 @@ class ZigguratNormalizedGaussianSampler private constructor() :
|
||||
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
|
||||
}
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) }
|
||||
|
||||
override fun toString(): String = "Ziggurat normalized Gaussian deviate"
|
||||
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) }
|
||||
public override fun toString(): String = "Ziggurat normalized Gaussian deviate"
|
||||
|
||||
private fun fix(
|
||||
generator: RandomGenerator,
|
||||
@ -59,7 +58,7 @@ class ZigguratNormalizedGaussianSampler private constructor() :
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
public companion object {
|
||||
private const val R: Double = 3.442619855899
|
||||
private const val ONE_OVER_R: Double = 1 / R
|
||||
private const val V: Double = 9.91256303526217e-3
|
||||
@ -94,7 +93,7 @@ class ZigguratNormalizedGaussianSampler private constructor() :
|
||||
}
|
||||
}
|
||||
|
||||
fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
|
||||
public fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
|
||||
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
||||
}
|
||||
}
|
||||
|
@ -1,46 +0,0 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.SimpleChain
|
||||
|
||||
/**
|
||||
* A multivariate distribution which takes a map of parameters
|
||||
*/
|
||||
interface NamedDistribution<T> : Distribution<Map<String, T>>
|
||||
|
||||
/**
|
||||
* A multivariate distribution that has independent distributions for separate axis
|
||||
*/
|
||||
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
|
||||
override fun probability(arg: Map<String, T>): Double {
|
||||
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
||||
}
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
||||
val chains = distributions.map { it.sample(generator) }
|
||||
return SimpleChain<Map<String, T>> {
|
||||
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class NamedDistributionWrapper<T : Any>(val name: String, val distribution: Distribution<T>) : NamedDistribution<T> {
|
||||
override fun probability(arg: Map<String, T>): Double = distribution.probability(
|
||||
arg[name] ?: error("Argument with name $name not found in input parameters")
|
||||
)
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
||||
val chain = distribution.sample(generator)
|
||||
return SimpleChain {
|
||||
mapOf(name to chain.next())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class DistributionBuilder<T: Any>{
|
||||
private val distributions = ArrayList<NamedDistribution<T>>()
|
||||
|
||||
infix fun String.to(distribution: Distribution<T>){
|
||||
distributions.add(NamedDistributionWrapper(this,distribution))
|
||||
}
|
||||
}
|
@ -1,30 +0,0 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import scientifik.kmath.chains.BlockingIntChain
|
||||
import scientifik.kmath.chains.BlockingRealChain
|
||||
import scientifik.kmath.chains.Chain
|
||||
|
||||
/**
|
||||
* A possibly stateful chain producing random values.
|
||||
*/
|
||||
class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain<R> {
|
||||
override suspend fun next(): R = generator.gen()
|
||||
|
||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||
}
|
||||
|
||||
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
||||
|
||||
fun RandomChain<Double>.blocking(): BlockingRealChain = let {
|
||||
object : BlockingRealChain() {
|
||||
override suspend fun next(): Double = it.next()
|
||||
override fun fork(): Chain<Double> = it.fork()
|
||||
}
|
||||
}
|
||||
|
||||
fun RandomChain<Int>.blocking(): BlockingIntChain = let {
|
||||
object : BlockingIntChain() {
|
||||
override suspend fun next(): Int = it.next()
|
||||
override fun fork(): Chain<Int> = it.fork()
|
||||
}
|
||||
}
|
@ -1,47 +0,0 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import kotlin.random.Random
|
||||
|
||||
/**
|
||||
* A basic generator
|
||||
*/
|
||||
interface RandomGenerator {
|
||||
fun nextBoolean(): Boolean
|
||||
fun nextDouble(): Double
|
||||
fun nextInt(): Int
|
||||
fun nextInt(until: Int): Int
|
||||
fun nextLong(): Long
|
||||
fun nextLong(until: Long): Long
|
||||
fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
|
||||
fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
|
||||
|
||||
/**
|
||||
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
|
||||
* and vise versa). The statistical properties of new generator should be the same as for this one.
|
||||
* For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run.
|
||||
*
|
||||
* The thread safety of this operation is not guaranteed since it could affect the state of the generator.
|
||||
*/
|
||||
fun fork(): RandomGenerator
|
||||
|
||||
companion object {
|
||||
val default by lazy { DefaultGenerator() }
|
||||
fun default(seed: Long) = DefaultGenerator(Random(seed))
|
||||
}
|
||||
}
|
||||
|
||||
inline class DefaultGenerator(private val random: Random = Random) : RandomGenerator {
|
||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||
override fun nextDouble(): Double = random.nextDouble()
|
||||
override fun nextInt(): Int = random.nextInt()
|
||||
override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||
override fun nextLong(): Long = random.nextLong()
|
||||
override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||
|
||||
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
||||
random.nextBytes(array, fromIndex, toIndex)
|
||||
}
|
||||
|
||||
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
||||
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
||||
}
|
@ -1,21 +1,22 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import org.apache.commons.rng.simple.RandomSource
|
||||
|
||||
class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
|
||||
public class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
|
||||
RandomGenerator {
|
||||
private val random = seed?.let {
|
||||
RandomSource.create(source, seed)
|
||||
} ?: RandomSource.create(source)
|
||||
|
||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||
override fun nextDouble(): Double = random.nextDouble()
|
||||
override fun nextInt(): Int = random.nextInt()
|
||||
override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||
override fun nextLong(): Long = random.nextLong()
|
||||
override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||
public override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||
public override fun nextDouble(): Double = random.nextDouble()
|
||||
public override fun nextInt(): Int = random.nextInt()
|
||||
public override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||
public override fun nextLong(): Long = random.nextLong()
|
||||
public override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||
|
||||
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
||||
public override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
||||
require(toIndex > fromIndex)
|
||||
random.nextBytes(array, fromIndex, toIndex - fromIndex)
|
||||
}
|
||||
@ -23,8 +24,8 @@ class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
|
||||
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
|
||||
}
|
||||
|
||||
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
||||
public fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
||||
RandomSourceGenerator(source, seed)
|
||||
|
||||
fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
|
||||
public fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
|
||||
fromSource(RandomSource.MT, seed)
|
||||
|
@ -3,25 +3,24 @@ package kscience.kmath.prob
|
||||
import kotlinx.coroutines.flow.take
|
||||
import kotlinx.coroutines.flow.toList
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kscience.kmath.prob.samplers.GaussianSampler
|
||||
import org.junit.jupiter.api.Assertions
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
internal class CommonsDistributionsTest {
|
||||
@Test
|
||||
fun testNormalDistributionSuspend() {
|
||||
val distribution = Distribution.normal(7.0, 2.0)
|
||||
val distribution = GaussianSampler.of(7.0, 2.0)
|
||||
val generator = RandomGenerator.default(1)
|
||||
val sample = runBlocking {
|
||||
distribution.sample(generator).take(1000).toList()
|
||||
}
|
||||
val sample = runBlocking { distribution.sample(generator).take(1000).toList() }
|
||||
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testNormalDistributionBlocking() {
|
||||
val distribution = Distribution.normal(7.0, 2.0)
|
||||
val distribution = GaussianSampler.of(7.0, 2.0)
|
||||
val generator = RandomGenerator.default(1)
|
||||
val sample = distribution.sample(generator).nextBlock(1000)
|
||||
val sample = runBlocking { distribution.sample(generator).blocking().nextBlock(1000) }
|
||||
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user