Perform 1.4 and explicit API migrations, refactor blocking chain, make tests work

This commit is contained in:
Iaroslav Postovalov 2020-09-27 15:16:12 +07:00
parent e35a364aa8
commit 53ebec2e01
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
26 changed files with 149 additions and 296 deletions

View File

@ -3,10 +3,7 @@ package kscience.kmath.chains
/** /**
* Performance optimized chain for integer values * Performance optimized chain for integer values
*/ */
public abstract class BlockingIntChain : Chain<Int> { public interface BlockingIntChain : Chain<Int> {
public abstract fun nextInt(): Int public override suspend fun next(): Int
public suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() }
override suspend fun next(): Int = nextInt()
public fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
} }

View File

@ -3,10 +3,7 @@ package kscience.kmath.chains
/** /**
* Performance optimized chain for real values * Performance optimized chain for real values
*/ */
public abstract class BlockingRealChain : Chain<Double> { public interface BlockingRealChain : Chain<Double> {
public abstract fun nextDouble(): Double public override suspend fun next(): Double
public suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() }
override suspend fun next(): Double = nextDouble()
public fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
} }

View File

@ -1,8 +0,0 @@
package scientifik.kmath.chains
/**
* Performance optimized chain for integer values
*/
abstract class BlockingIntChain : Chain<Int> {
suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() }
}

View File

@ -1,8 +0,0 @@
package scientifik.kmath.chains
/**
* Performance optimized chain for real values
*/
abstract class BlockingRealChain : Chain<Double> {
suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() }
}

View File

@ -1,5 +1,6 @@
package kscience.kmath.prob package kscience.kmath.prob
import kotlinx.coroutines.flow.first
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain
import kscience.kmath.chains.collect import kscience.kmath.chains.collect
import kscience.kmath.structures.Buffer import kscience.kmath.structures.Buffer
@ -70,7 +71,7 @@ public fun <T : Any> Sampler<T>.sampleBuffer(
} }
} }
suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator) = sample(generator).first() public suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator): T = sample(generator).first()
/** /**
* Generate a bunch of samples from real distributions * Generate a bunch of samples from real distributions

View File

@ -14,7 +14,7 @@ public interface NamedDistribution<T> : Distribution<Map<String, T>>
public class FactorizedDistribution<T>(public val distributions: Collection<NamedDistribution<T>>) : public class FactorizedDistribution<T>(public val distributions: Collection<NamedDistribution<T>>) :
NamedDistribution<T> { NamedDistribution<T> {
override fun probability(arg: Map<String, T>): Double = override fun probability(arg: Map<String, T>): Double =
distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) } distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) }
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> { override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
val chains = distributions.map { it.sample(generator) } val chains = distributions.map { it.sample(generator) }

View File

@ -1,17 +1,22 @@
package kscience.kmath.prob package kscience.kmath.prob
import kscience.kmath.chains.BlockingIntChain
import kscience.kmath.chains.BlockingRealChain
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain
/** /**
* A possibly stateful chain producing random values. * A possibly stateful chain producing random values.
*
* @property generator the underlying [RandomGenerator] instance.
*/ */
public class RandomChain<out R>( public class RandomChain<out R>(
public val generator: RandomGenerator, public val generator: RandomGenerator,
private val gen: suspend RandomGenerator.() -> R private val gen: suspend RandomGenerator.() -> R
) : Chain<R> { ) : Chain<R> {
override suspend fun next(): R = generator.gen() override suspend fun next(): R = generator.gen()
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen) override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
} }
public fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen) public fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
public fun Chain<Double>.blocking(): BlockingRealChain = object : Chain<Double> by this, BlockingRealChain {}
public fun Chain<Int>.blocking(): BlockingIntChain = object : Chain<Int> by this, BlockingIntChain {}

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.pow import kotlin.math.pow
@ -13,8 +13,8 @@ import kotlin.math.pow
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html
*/ */
class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler<Double> { public class AhrensDieterExponentialSampler private constructor(public val mean: Double) : Sampler<Double> {
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
// Step 1: // Step 1:
var a = 0.0 var a = 0.0
var u = nextDouble() var u = nextDouble()
@ -47,7 +47,7 @@ class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sam
override fun toString(): String = "Ahrens-Dieter Exponential deviate" override fun toString(): String = "Ahrens-Dieter Exponential deviate"
companion object { public companion object {
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) } private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
init { init {
@ -64,7 +64,7 @@ class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sam
} }
} }
fun of(mean: Double): AhrensDieterExponentialSampler { public fun of(mean: Double): AhrensDieterExponentialSampler {
require(mean > 0) { "mean is not strictly positive: $mean" } require(mean > 0) { "mean is not strictly positive: $mean" }
return AhrensDieterExponentialSampler(mean) return AhrensDieterExponentialSampler(mean)
} }

View File

@ -1,10 +1,10 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import scientifik.kmath.prob.next import kscience.kmath.prob.next
import kotlin.math.* import kotlin.math.*
/** /**
@ -17,15 +17,12 @@ import kotlin.math.*
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.html
*/ */
class AhrensDieterMarsagliaTsangGammaSampler private constructor( public class AhrensDieterMarsagliaTsangGammaSampler private constructor(
alpha: Double, alpha: Double,
theta: Double theta: Double
) : Sampler<Double> { ) : Sampler<Double> {
private val delegate: BaseGammaSampler = private val delegate: BaseGammaSampler =
if (alpha < 1) AhrensDieterGammaSampler(alpha, theta) else MarsagliaTsangGammaSampler( if (alpha < 1) AhrensDieterGammaSampler(alpha, theta) else MarsagliaTsangGammaSampler(alpha, theta)
alpha,
theta
)
private abstract class BaseGammaSampler internal constructor( private abstract class BaseGammaSampler internal constructor(
protected val alpha: Double, protected val alpha: Double,
@ -39,7 +36,7 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
override fun toString(): String = "Ahrens-Dieter-Marsaglia-Tsang Gamma deviate" override fun toString(): String = "Ahrens-Dieter-Marsaglia-Tsang Gamma deviate"
} }
private class AhrensDieterGammaSampler internal constructor(alpha: Double, theta: Double) : private class AhrensDieterGammaSampler(alpha: Double, theta: Double) :
BaseGammaSampler(alpha, theta) { BaseGammaSampler(alpha, theta) {
private val oneOverAlpha: Double = 1.0 / alpha private val oneOverAlpha: Double = 1.0 / alpha
private val bGSOptim: Double = 1.0 + alpha / E private val bGSOptim: Double = 1.0 + alpha / E
@ -75,7 +72,7 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
} }
} }
private class MarsagliaTsangGammaSampler internal constructor(alpha: Double, theta: Double) : private class MarsagliaTsangGammaSampler(alpha: Double, theta: Double) :
BaseGammaSampler(alpha, theta) { BaseGammaSampler(alpha, theta) {
private val dOptim: Double private val dOptim: Double
private val cOptim: Double private val cOptim: Double
@ -110,11 +107,11 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
} }
} }
override fun sample(generator: RandomGenerator): Chain<Double> = delegate.sample(generator) public override fun sample(generator: RandomGenerator): Chain<Double> = delegate.sample(generator)
override fun toString(): String = delegate.toString() public override fun toString(): String = delegate.toString()
companion object { public companion object {
fun of( public fun of(
alpha: Double, alpha: Double,
theta: Double theta: Double
): Sampler<Double> = AhrensDieterMarsagliaTsangGammaSampler(alpha, theta) ): Sampler<Double> = AhrensDieterMarsagliaTsangGammaSampler(alpha, theta)

View File

@ -1,10 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import kotlin.jvm.JvmOverloads
import kotlin.math.ceil import kotlin.math.ceil
import kotlin.math.max import kotlin.math.max
import kotlin.math.min import kotlin.math.min
@ -36,7 +35,7 @@ import kotlin.math.min
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.html
*/ */
open class AliasMethodDiscreteSampler private constructor( public open class AliasMethodDiscreteSampler private constructor(
// Deliberate direct storage of input arrays // Deliberate direct storage of input arrays
protected val probability: LongArray, protected val probability: LongArray,
protected val alias: IntArray protected val alias: IntArray
@ -102,17 +101,16 @@ open class AliasMethodDiscreteSampler private constructor(
if (generator.nextLong() ushr 11 < probability[j]) j else alias[j] if (generator.nextLong() ushr 11 < probability[j]) j else alias[j]
} }
override fun toString(): String = "Alias method" public override fun toString(): String = "Alias method"
companion object { public companion object {
private const val DEFAULT_ALPHA = 0 private const val DEFAULT_ALPHA = 0
private const val ZERO = 0.0 private const val ZERO = 0.0
private const val ONE_AS_NUMERATOR = 1L shl 53 private const val ONE_AS_NUMERATOR = 1L shl 53
private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble() private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble()
private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11 private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11
@JvmOverloads public fun of(
fun of(
probabilities: DoubleArray, probabilities: DoubleArray,
alpha: Int = DEFAULT_ALPHA alpha: Int = DEFAULT_ALPHA
): Sampler<Int> { ): Sampler<Int> {

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import kotlin.math.* import kotlin.math.*
/** /**
@ -13,10 +13,10 @@ import kotlin.math.*
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html
*/ */
class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> { public class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
private var nextGaussian: Double = Double.NaN private var nextGaussian: Double = Double.NaN
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
val random: Double val random: Double
if (nextGaussian.isNaN()) { if (nextGaussian.isNaN()) {
@ -40,9 +40,9 @@ class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGauss
random random
} }
override fun toString(): String = "Box-Muller normalized Gaussian deviate" public override fun toString(): String = "Box-Muller normalized Gaussian deviate"
companion object { public companion object {
fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler() public fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
} }
} }

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.chains.map import kscience.kmath.chains.map
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
/** /**
* Sampling from a Gaussian distribution with given mean and standard deviation. * Sampling from a Gaussian distribution with given mean and standard deviation.
@ -11,24 +11,24 @@ import scientifik.kmath.prob.Sampler
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html
*/ */
class GaussianSampler private constructor( public class GaussianSampler private constructor(
private val mean: Double, private val mean: Double,
private val standardDeviation: Double, private val standardDeviation: Double,
private val normalized: NormalizedGaussianSampler private val normalized: NormalizedGaussianSampler
) : Sampler<Double> { ) : Sampler<Double> {
override fun sample(generator: RandomGenerator): Chain<Double> = normalized public override fun sample(generator: RandomGenerator): Chain<Double> = normalized
.sample(generator) .sample(generator)
.map { standardDeviation * it + mean } .map { standardDeviation * it + mean }
override fun toString(): String = "Gaussian deviate [$normalized]" override fun toString(): String = "Gaussian deviate [$normalized]"
companion object { public companion object {
fun of( public fun of(
mean: Double, mean: Double,
standardDeviation: Double, standardDeviation: Double,
normalized: NormalizedGaussianSampler normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler.of()
): GaussianSampler { ): GaussianSampler {
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" } require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" }
return GaussianSampler( return GaussianSampler(
mean, mean,

View File

@ -1,4 +1,4 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import kotlin.math.PI import kotlin.math.PI
import kotlin.math.ln import kotlin.math.ln

View File

@ -1,4 +1,4 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.min import kotlin.math.min

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import kotlin.math.exp import kotlin.math.exp
/** /**
@ -18,11 +18,11 @@ import kotlin.math.exp
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html
*/ */
class KempSmallMeanPoissonSampler private constructor( public class KempSmallMeanPoissonSampler private constructor(
private val p0: Double, private val p0: Double,
private val mean: Double private val mean: Double
) : Sampler<Int> { ) : Sampler<Int> {
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain { public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
// Note on the algorithm: // Note on the algorithm:
// - X is the unknown sample deviate (the output of the algorithm) // - X is the unknown sample deviate (the output of the algorithm)
// - x is the current value from the distribution // - x is the current value from the distribution
@ -48,10 +48,10 @@ class KempSmallMeanPoissonSampler private constructor(
x x
} }
override fun toString(): String = "Kemp Small Mean Poisson deviate" public override fun toString(): String = "Kemp Small Mean Poisson deviate"
companion object { public companion object {
fun of(mean: Double): KempSmallMeanPoissonSampler { public fun of(mean: Double): KempSmallMeanPoissonSampler {
require(mean > 0) { "Mean is not strictly positive: $mean" } require(mean > 0) { "Mean is not strictly positive: $mean" }
val p0 = exp(-mean) val p0 = exp(-mean)
// Probability must be positive. As mean increases then p(0) decreases. // Probability must be positive. As mean increases then p(0) decreases.

View File

@ -1,11 +1,11 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.chains.ConstantChain import kscience.kmath.chains.ConstantChain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import scientifik.kmath.prob.next import kscience.kmath.prob.next
import kotlin.math.* import kotlin.math.*
/** /**
@ -19,7 +19,7 @@ import kotlin.math.*
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html
*/ */
class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<Int> { public class LargeMeanPoissonSampler private constructor(public val mean: Double) : Sampler<Int> {
private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0) private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler.of() private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler.of()
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
@ -41,7 +41,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
else // Not used. else // Not used.
KempSmallMeanPoissonSampler.of(mean - lambda) KempSmallMeanPoissonSampler.of(mean - lambda)
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain { public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
// This will never be null. It may be a no-op delegate that returns zero. // This will never be null. It may be a no-op delegate that returns zero.
val y2 = smallMeanPoissonSampler.next(generator) val y2 = smallMeanPoissonSampler.next(generator)
var x: Double var x: Double
@ -114,7 +114,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
override fun toString(): String = "Large Mean Poisson deviate" override fun toString(): String = "Large Mean Poisson deviate"
companion object { public companion object {
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create() private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
@ -122,7 +122,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0) override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
} }
fun of(mean: Double): LargeMeanPoissonSampler { public fun of(mean: Double): LargeMeanPoissonSampler {
require(mean >= 1) { "mean is not >= 1: $mean" } require(mean >= 1) { "mean is not >= 1: $mean" }
// The algorithm is not valid if Math.floor(mean) is not an integer. // The algorithm is not valid if Math.floor(mean) is not an integer.
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" } require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.sqrt import kotlin.math.sqrt
@ -15,10 +15,10 @@ import kotlin.math.sqrt
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html
*/ */
class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> { public class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
private var nextGaussian = Double.NaN private var nextGaussian = Double.NaN
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
if (nextGaussian.isNaN()) { if (nextGaussian.isNaN()) {
val alpha: Double val alpha: Double
var x: Double var x: Double
@ -53,9 +53,9 @@ class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGauss
} }
} }
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate" public override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
companion object { public companion object {
fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler() public fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
} }
} }

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
/** /**
* Marker interface for a sampler that generates values from an N(0,1) * Marker interface for a sampler that generates values from an N(0,1)
* [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution). * [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution).
*/ */
interface NormalizedGaussianSampler : Sampler<Double> public interface NormalizedGaussianSampler : Sampler<Double>

View File

@ -1,8 +1,8 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
/** /**
* Sampler for the Poisson distribution. * Sampler for the Poisson distribution.
@ -16,17 +16,15 @@ import scientifik.kmath.prob.Sampler
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html
*/ */
class PoissonSampler private constructor( public class PoissonSampler private constructor(mean: Double) : Sampler<Int> {
mean: Double
) : Sampler<Int> {
private val poissonSamplerDelegate: Sampler<Int> = of(mean) private val poissonSamplerDelegate: Sampler<Int> = of(mean)
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator) public override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
override fun toString(): String = poissonSamplerDelegate.toString() public override fun toString(): String = poissonSamplerDelegate.toString()
companion object { public companion object {
private const val PIVOT = 40.0 private const val PIVOT = 40.0
fun of(mean: Double) =// Each sampler should check the input arguments. public fun of(mean: Double): Sampler<Int> =// Each sampler should check the input arguments.
if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean) if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
} }
} }

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import kotlin.math.ceil import kotlin.math.ceil
import kotlin.math.exp import kotlin.math.exp
@ -19,7 +19,7 @@ import kotlin.math.exp
* *
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html
*/ */
class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> { public class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
private val p0: Double = exp(-mean) private val p0: Double = exp(-mean)
private val limit: Int = (if (p0 > 0) private val limit: Int = (if (p0 > 0)
@ -27,7 +27,7 @@ class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
else else
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt() throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain { public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
var n = 0 var n = 0
var r = 1.0 var r = 1.0
@ -39,10 +39,10 @@ class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
n n
} }
override fun toString(): String = "Small Mean Poisson deviate" public override fun toString(): String = "Small Mean Poisson deviate"
companion object { public companion object {
fun of(mean: Double): SmallMeanPoissonSampler { public fun of(mean: Double): SmallMeanPoissonSampler {
require(mean > 0) { "mean is not strictly positive: $mean" } require(mean > 0) { "mean is not strictly positive: $mean" }
return SmallMeanPoissonSampler(mean) return SmallMeanPoissonSampler(mean)
} }

View File

@ -1,9 +1,9 @@
package scientifik.kmath.prob.samplers package kscience.kmath.prob.samplers
import scientifik.kmath.chains.Chain import kscience.kmath.chains.Chain
import scientifik.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import scientifik.kmath.prob.Sampler import kscience.kmath.prob.Sampler
import scientifik.kmath.prob.chain import kscience.kmath.prob.chain
import kotlin.math.* import kotlin.math.*
/** /**
@ -14,7 +14,7 @@ import kotlin.math.*
* Based on Commons RNG implementation. * Based on Commons RNG implementation.
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html * See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html
*/ */
class ZigguratNormalizedGaussianSampler private constructor() : public class ZigguratNormalizedGaussianSampler private constructor() :
NormalizedGaussianSampler, Sampler<Double> { NormalizedGaussianSampler, Sampler<Double> {
private fun sampleOne(generator: RandomGenerator): Double { private fun sampleOne(generator: RandomGenerator): Double {
@ -23,9 +23,8 @@ class ZigguratNormalizedGaussianSampler private constructor() :
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i) return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
} }
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) } public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) }
public override fun toString(): String = "Ziggurat normalized Gaussian deviate"
override fun toString(): String = "Ziggurat normalized Gaussian deviate"
private fun fix( private fun fix(
generator: RandomGenerator, generator: RandomGenerator,
@ -59,7 +58,7 @@ class ZigguratNormalizedGaussianSampler private constructor() :
} }
} }
companion object { public companion object {
private const val R: Double = 3.442619855899 private const val R: Double = 3.442619855899
private const val ONE_OVER_R: Double = 1 / R private const val ONE_OVER_R: Double = 1 / R
private const val V: Double = 9.91256303526217e-3 private const val V: Double = 9.91256303526217e-3
@ -94,7 +93,7 @@ class ZigguratNormalizedGaussianSampler private constructor() :
} }
} }
fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler() public fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
private fun gauss(x: Double): Double = exp(-0.5 * x * x) private fun gauss(x: Double): Double = exp(-0.5 * x * x)
} }
} }

View File

@ -1,46 +0,0 @@
package scientifik.kmath.prob
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.SimpleChain
/**
* A multivariate distribution which takes a map of parameters
*/
interface NamedDistribution<T> : Distribution<Map<String, T>>
/**
* A multivariate distribution that has independent distributions for separate axis
*/
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
override fun probability(arg: Map<String, T>): Double {
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
}
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
val chains = distributions.map { it.sample(generator) }
return SimpleChain<Map<String, T>> {
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
}
}
}
class NamedDistributionWrapper<T : Any>(val name: String, val distribution: Distribution<T>) : NamedDistribution<T> {
override fun probability(arg: Map<String, T>): Double = distribution.probability(
arg[name] ?: error("Argument with name $name not found in input parameters")
)
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
val chain = distribution.sample(generator)
return SimpleChain {
mapOf(name to chain.next())
}
}
}
class DistributionBuilder<T: Any>{
private val distributions = ArrayList<NamedDistribution<T>>()
infix fun String.to(distribution: Distribution<T>){
distributions.add(NamedDistributionWrapper(this,distribution))
}
}

View File

@ -1,30 +0,0 @@
package scientifik.kmath.prob
import scientifik.kmath.chains.BlockingIntChain
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.chains.Chain
/**
* A possibly stateful chain producing random values.
*/
class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain<R> {
override suspend fun next(): R = generator.gen()
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
}
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
fun RandomChain<Double>.blocking(): BlockingRealChain = let {
object : BlockingRealChain() {
override suspend fun next(): Double = it.next()
override fun fork(): Chain<Double> = it.fork()
}
}
fun RandomChain<Int>.blocking(): BlockingIntChain = let {
object : BlockingIntChain() {
override suspend fun next(): Int = it.next()
override fun fork(): Chain<Int> = it.fork()
}
}

View File

@ -1,47 +0,0 @@
package scientifik.kmath.prob
import kotlin.random.Random
/**
* A basic generator
*/
interface RandomGenerator {
fun nextBoolean(): Boolean
fun nextDouble(): Double
fun nextInt(): Int
fun nextInt(until: Int): Int
fun nextLong(): Long
fun nextLong(until: Long): Long
fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
/**
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
* and vise versa). The statistical properties of new generator should be the same as for this one.
* For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run.
*
* The thread safety of this operation is not guaranteed since it could affect the state of the generator.
*/
fun fork(): RandomGenerator
companion object {
val default by lazy { DefaultGenerator() }
fun default(seed: Long) = DefaultGenerator(Random(seed))
}
}
inline class DefaultGenerator(private val random: Random = Random) : RandomGenerator {
override fun nextBoolean(): Boolean = random.nextBoolean()
override fun nextDouble(): Double = random.nextDouble()
override fun nextInt(): Int = random.nextInt()
override fun nextInt(until: Int): Int = random.nextInt(until)
override fun nextLong(): Long = random.nextLong()
override fun nextLong(until: Long): Long = random.nextLong(until)
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
random.nextBytes(array, fromIndex, toIndex)
}
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
}

View File

@ -1,21 +1,22 @@
package scientifik.kmath.prob package scientifik.kmath.prob
import kscience.kmath.prob.RandomGenerator
import org.apache.commons.rng.simple.RandomSource import org.apache.commons.rng.simple.RandomSource
class RandomSourceGenerator(private val source: RandomSource, seed: Long?) : public class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
RandomGenerator { RandomGenerator {
private val random = seed?.let { private val random = seed?.let {
RandomSource.create(source, seed) RandomSource.create(source, seed)
} ?: RandomSource.create(source) } ?: RandomSource.create(source)
override fun nextBoolean(): Boolean = random.nextBoolean() public override fun nextBoolean(): Boolean = random.nextBoolean()
override fun nextDouble(): Double = random.nextDouble() public override fun nextDouble(): Double = random.nextDouble()
override fun nextInt(): Int = random.nextInt() public override fun nextInt(): Int = random.nextInt()
override fun nextInt(until: Int): Int = random.nextInt(until) public override fun nextInt(until: Int): Int = random.nextInt(until)
override fun nextLong(): Long = random.nextLong() public override fun nextLong(): Long = random.nextLong()
override fun nextLong(until: Long): Long = random.nextLong(until) public override fun nextLong(until: Long): Long = random.nextLong(until)
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) { public override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
require(toIndex > fromIndex) require(toIndex > fromIndex)
random.nextBytes(array, fromIndex, toIndex - fromIndex) random.nextBytes(array, fromIndex, toIndex - fromIndex)
} }
@ -23,8 +24,8 @@ class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
} }
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator = public fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
RandomSourceGenerator(source, seed) RandomSourceGenerator(source, seed)
fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator = public fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
fromSource(RandomSource.MT, seed) fromSource(RandomSource.MT, seed)

View File

@ -3,25 +3,24 @@ package kscience.kmath.prob
import kotlinx.coroutines.flow.take import kotlinx.coroutines.flow.take
import kotlinx.coroutines.flow.toList import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kscience.kmath.prob.samplers.GaussianSampler
import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
internal class CommonsDistributionsTest { internal class CommonsDistributionsTest {
@Test @Test
fun testNormalDistributionSuspend() { fun testNormalDistributionSuspend() {
val distribution = Distribution.normal(7.0, 2.0) val distribution = GaussianSampler.of(7.0, 2.0)
val generator = RandomGenerator.default(1) val generator = RandomGenerator.default(1)
val sample = runBlocking { val sample = runBlocking { distribution.sample(generator).take(1000).toList() }
distribution.sample(generator).take(1000).toList()
}
Assertions.assertEquals(7.0, sample.average(), 0.1) Assertions.assertEquals(7.0, sample.average(), 0.1)
} }
@Test @Test
fun testNormalDistributionBlocking() { fun testNormalDistributionBlocking() {
val distribution = Distribution.normal(7.0, 2.0) val distribution = GaussianSampler.of(7.0, 2.0)
val generator = RandomGenerator.default(1) val generator = RandomGenerator.default(1)
val sample = distribution.sample(generator).nextBlock(1000) val sample = runBlocking { distribution.sample(generator).blocking().nextBlock(1000) }
Assertions.assertEquals(7.0, sample.average(), 0.1) Assertions.assertEquals(7.0, sample.average(), 0.1)
} }
} }