Perform 1.4 and explicit API migrations, refactor blocking chain, make tests work
This commit is contained in:
parent
e35a364aa8
commit
53ebec2e01
@ -3,10 +3,7 @@ package kscience.kmath.chains
|
|||||||
/**
|
/**
|
||||||
* Performance optimized chain for integer values
|
* Performance optimized chain for integer values
|
||||||
*/
|
*/
|
||||||
public abstract class BlockingIntChain : Chain<Int> {
|
public interface BlockingIntChain : Chain<Int> {
|
||||||
public abstract fun nextInt(): Int
|
public override suspend fun next(): Int
|
||||||
|
public suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() }
|
||||||
override suspend fun next(): Int = nextInt()
|
|
||||||
|
|
||||||
public fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
|
|
||||||
}
|
}
|
||||||
|
@ -3,10 +3,7 @@ package kscience.kmath.chains
|
|||||||
/**
|
/**
|
||||||
* Performance optimized chain for real values
|
* Performance optimized chain for real values
|
||||||
*/
|
*/
|
||||||
public abstract class BlockingRealChain : Chain<Double> {
|
public interface BlockingRealChain : Chain<Double> {
|
||||||
public abstract fun nextDouble(): Double
|
public override suspend fun next(): Double
|
||||||
|
public suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() }
|
||||||
override suspend fun next(): Double = nextDouble()
|
|
||||||
|
|
||||||
public fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
|
|
||||||
}
|
}
|
||||||
|
@ -1,8 +0,0 @@
|
|||||||
package scientifik.kmath.chains
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performance optimized chain for integer values
|
|
||||||
*/
|
|
||||||
abstract class BlockingIntChain : Chain<Int> {
|
|
||||||
suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() }
|
|
||||||
}
|
|
@ -1,8 +0,0 @@
|
|||||||
package scientifik.kmath.chains
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performance optimized chain for real values
|
|
||||||
*/
|
|
||||||
abstract class BlockingRealChain : Chain<Double> {
|
|
||||||
suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() }
|
|
||||||
}
|
|
@ -1,5 +1,6 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.flow.first
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import kscience.kmath.chains.collect
|
import kscience.kmath.chains.collect
|
||||||
import kscience.kmath.structures.Buffer
|
import kscience.kmath.structures.Buffer
|
||||||
@ -70,7 +71,7 @@ public fun <T : Any> Sampler<T>.sampleBuffer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator) = sample(generator).first()
|
public suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator): T = sample(generator).first()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a bunch of samples from real distributions
|
* Generate a bunch of samples from real distributions
|
||||||
|
@ -14,7 +14,7 @@ public interface NamedDistribution<T> : Distribution<Map<String, T>>
|
|||||||
public class FactorizedDistribution<T>(public val distributions: Collection<NamedDistribution<T>>) :
|
public class FactorizedDistribution<T>(public val distributions: Collection<NamedDistribution<T>>) :
|
||||||
NamedDistribution<T> {
|
NamedDistribution<T> {
|
||||||
override fun probability(arg: Map<String, T>): Double =
|
override fun probability(arg: Map<String, T>): Double =
|
||||||
distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) }
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
||||||
val chains = distributions.map { it.sample(generator) }
|
val chains = distributions.map { it.sample(generator) }
|
||||||
|
@ -1,17 +1,22 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.prob
|
||||||
|
|
||||||
|
import kscience.kmath.chains.BlockingIntChain
|
||||||
|
import kscience.kmath.chains.BlockingRealChain
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A possibly stateful chain producing random values.
|
* A possibly stateful chain producing random values.
|
||||||
|
*
|
||||||
|
* @property generator the underlying [RandomGenerator] instance.
|
||||||
*/
|
*/
|
||||||
public class RandomChain<out R>(
|
public class RandomChain<out R>(
|
||||||
public val generator: RandomGenerator,
|
public val generator: RandomGenerator,
|
||||||
private val gen: suspend RandomGenerator.() -> R
|
private val gen: suspend RandomGenerator.() -> R
|
||||||
) : Chain<R> {
|
) : Chain<R> {
|
||||||
override suspend fun next(): R = generator.gen()
|
override suspend fun next(): R = generator.gen()
|
||||||
|
|
||||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
public fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
||||||
|
public fun Chain<Double>.blocking(): BlockingRealChain = object : Chain<Double> by this, BlockingRealChain {}
|
||||||
|
public fun Chain<Int>.blocking(): BlockingIntChain = object : Chain<Int> by this, BlockingIntChain {}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
@ -13,8 +13,8 @@ import kotlin.math.pow
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html
|
||||||
*/
|
*/
|
||||||
class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sampler<Double> {
|
public class AhrensDieterExponentialSampler private constructor(public val mean: Double) : Sampler<Double> {
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
// Step 1:
|
// Step 1:
|
||||||
var a = 0.0
|
var a = 0.0
|
||||||
var u = nextDouble()
|
var u = nextDouble()
|
||||||
@ -47,7 +47,7 @@ class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sam
|
|||||||
|
|
||||||
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
|
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
|
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
|
||||||
|
|
||||||
init {
|
init {
|
||||||
@ -64,7 +64,7 @@ class AhrensDieterExponentialSampler private constructor(val mean: Double) : Sam
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun of(mean: Double): AhrensDieterExponentialSampler {
|
public fun of(mean: Double): AhrensDieterExponentialSampler {
|
||||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||||
return AhrensDieterExponentialSampler(mean)
|
return AhrensDieterExponentialSampler(mean)
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import scientifik.kmath.prob.next
|
import kscience.kmath.prob.next
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,15 +17,12 @@ import kotlin.math.*
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.html
|
||||||
*/
|
*/
|
||||||
class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
public class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
||||||
alpha: Double,
|
alpha: Double,
|
||||||
theta: Double
|
theta: Double
|
||||||
) : Sampler<Double> {
|
) : Sampler<Double> {
|
||||||
private val delegate: BaseGammaSampler =
|
private val delegate: BaseGammaSampler =
|
||||||
if (alpha < 1) AhrensDieterGammaSampler(alpha, theta) else MarsagliaTsangGammaSampler(
|
if (alpha < 1) AhrensDieterGammaSampler(alpha, theta) else MarsagliaTsangGammaSampler(alpha, theta)
|
||||||
alpha,
|
|
||||||
theta
|
|
||||||
)
|
|
||||||
|
|
||||||
private abstract class BaseGammaSampler internal constructor(
|
private abstract class BaseGammaSampler internal constructor(
|
||||||
protected val alpha: Double,
|
protected val alpha: Double,
|
||||||
@ -39,7 +36,7 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
|||||||
override fun toString(): String = "Ahrens-Dieter-Marsaglia-Tsang Gamma deviate"
|
override fun toString(): String = "Ahrens-Dieter-Marsaglia-Tsang Gamma deviate"
|
||||||
}
|
}
|
||||||
|
|
||||||
private class AhrensDieterGammaSampler internal constructor(alpha: Double, theta: Double) :
|
private class AhrensDieterGammaSampler(alpha: Double, theta: Double) :
|
||||||
BaseGammaSampler(alpha, theta) {
|
BaseGammaSampler(alpha, theta) {
|
||||||
private val oneOverAlpha: Double = 1.0 / alpha
|
private val oneOverAlpha: Double = 1.0 / alpha
|
||||||
private val bGSOptim: Double = 1.0 + alpha / E
|
private val bGSOptim: Double = 1.0 + alpha / E
|
||||||
@ -75,7 +72,7 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class MarsagliaTsangGammaSampler internal constructor(alpha: Double, theta: Double) :
|
private class MarsagliaTsangGammaSampler(alpha: Double, theta: Double) :
|
||||||
BaseGammaSampler(alpha, theta) {
|
BaseGammaSampler(alpha, theta) {
|
||||||
private val dOptim: Double
|
private val dOptim: Double
|
||||||
private val cOptim: Double
|
private val cOptim: Double
|
||||||
@ -110,11 +107,11 @@ class AhrensDieterMarsagliaTsangGammaSampler private constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = delegate.sample(generator)
|
public override fun sample(generator: RandomGenerator): Chain<Double> = delegate.sample(generator)
|
||||||
override fun toString(): String = delegate.toString()
|
public override fun toString(): String = delegate.toString()
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun of(
|
public fun of(
|
||||||
alpha: Double,
|
alpha: Double,
|
||||||
theta: Double
|
theta: Double
|
||||||
): Sampler<Double> = AhrensDieterMarsagliaTsangGammaSampler(alpha, theta)
|
): Sampler<Double> = AhrensDieterMarsagliaTsangGammaSampler(alpha, theta)
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import kotlin.jvm.JvmOverloads
|
|
||||||
import kotlin.math.ceil
|
import kotlin.math.ceil
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
@ -36,7 +35,7 @@ import kotlin.math.min
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.html
|
||||||
*/
|
*/
|
||||||
open class AliasMethodDiscreteSampler private constructor(
|
public open class AliasMethodDiscreteSampler private constructor(
|
||||||
// Deliberate direct storage of input arrays
|
// Deliberate direct storage of input arrays
|
||||||
protected val probability: LongArray,
|
protected val probability: LongArray,
|
||||||
protected val alias: IntArray
|
protected val alias: IntArray
|
||||||
@ -102,17 +101,16 @@ open class AliasMethodDiscreteSampler private constructor(
|
|||||||
if (generator.nextLong() ushr 11 < probability[j]) j else alias[j]
|
if (generator.nextLong() ushr 11 < probability[j]) j else alias[j]
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Alias method"
|
public override fun toString(): String = "Alias method"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
private const val DEFAULT_ALPHA = 0
|
private const val DEFAULT_ALPHA = 0
|
||||||
private const val ZERO = 0.0
|
private const val ZERO = 0.0
|
||||||
private const val ONE_AS_NUMERATOR = 1L shl 53
|
private const val ONE_AS_NUMERATOR = 1L shl 53
|
||||||
private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble()
|
private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble()
|
||||||
private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11
|
private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11
|
||||||
|
|
||||||
@JvmOverloads
|
public fun of(
|
||||||
fun of(
|
|
||||||
probabilities: DoubleArray,
|
probabilities: DoubleArray,
|
||||||
alpha: Int = DEFAULT_ALPHA
|
alpha: Int = DEFAULT_ALPHA
|
||||||
): Sampler<Int> {
|
): Sampler<Int> {
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -13,10 +13,10 @@ import kotlin.math.*
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html
|
||||||
*/
|
*/
|
||||||
class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
public class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||||
private var nextGaussian: Double = Double.NaN
|
private var nextGaussian: Double = Double.NaN
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
val random: Double
|
val random: Double
|
||||||
|
|
||||||
if (nextGaussian.isNaN()) {
|
if (nextGaussian.isNaN()) {
|
||||||
@ -40,9 +40,9 @@ class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGauss
|
|||||||
random
|
random
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
public override fun toString(): String = "Box-Muller normalized Gaussian deviate"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
|
public fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.map
|
import kscience.kmath.chains.map
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sampling from a Gaussian distribution with given mean and standard deviation.
|
* Sampling from a Gaussian distribution with given mean and standard deviation.
|
||||||
@ -11,24 +11,24 @@ import scientifik.kmath.prob.Sampler
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html
|
||||||
*/
|
*/
|
||||||
class GaussianSampler private constructor(
|
public class GaussianSampler private constructor(
|
||||||
private val mean: Double,
|
private val mean: Double,
|
||||||
private val standardDeviation: Double,
|
private val standardDeviation: Double,
|
||||||
private val normalized: NormalizedGaussianSampler
|
private val normalized: NormalizedGaussianSampler
|
||||||
) : Sampler<Double> {
|
) : Sampler<Double> {
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = normalized
|
public override fun sample(generator: RandomGenerator): Chain<Double> = normalized
|
||||||
.sample(generator)
|
.sample(generator)
|
||||||
.map { standardDeviation * it + mean }
|
.map { standardDeviation * it + mean }
|
||||||
|
|
||||||
override fun toString(): String = "Gaussian deviate [$normalized]"
|
override fun toString(): String = "Gaussian deviate [$normalized]"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun of(
|
public fun of(
|
||||||
mean: Double,
|
mean: Double,
|
||||||
standardDeviation: Double,
|
standardDeviation: Double,
|
||||||
normalized: NormalizedGaussianSampler
|
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler.of()
|
||||||
): GaussianSampler {
|
): GaussianSampler {
|
||||||
require(standardDeviation > 0) { "standard deviation is not strictly positive: $standardDeviation" }
|
require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" }
|
||||||
|
|
||||||
return GaussianSampler(
|
return GaussianSampler(
|
||||||
mean,
|
mean,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import kotlin.math.exp
|
import kotlin.math.exp
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -18,11 +18,11 @@ import kotlin.math.exp
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html
|
||||||
*/
|
*/
|
||||||
class KempSmallMeanPoissonSampler private constructor(
|
public class KempSmallMeanPoissonSampler private constructor(
|
||||||
private val p0: Double,
|
private val p0: Double,
|
||||||
private val mean: Double
|
private val mean: Double
|
||||||
) : Sampler<Int> {
|
) : Sampler<Int> {
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||||
// Note on the algorithm:
|
// Note on the algorithm:
|
||||||
// - X is the unknown sample deviate (the output of the algorithm)
|
// - X is the unknown sample deviate (the output of the algorithm)
|
||||||
// - x is the current value from the distribution
|
// - x is the current value from the distribution
|
||||||
@ -48,10 +48,10 @@ class KempSmallMeanPoissonSampler private constructor(
|
|||||||
x
|
x
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Kemp Small Mean Poisson deviate"
|
public override fun toString(): String = "Kemp Small Mean Poisson deviate"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun of(mean: Double): KempSmallMeanPoissonSampler {
|
public fun of(mean: Double): KempSmallMeanPoissonSampler {
|
||||||
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
require(mean > 0) { "Mean is not strictly positive: $mean" }
|
||||||
val p0 = exp(-mean)
|
val p0 = exp(-mean)
|
||||||
// Probability must be positive. As mean increases then p(0) decreases.
|
// Probability must be positive. As mean increases then p(0) decreases.
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.ConstantChain
|
import kscience.kmath.chains.ConstantChain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import scientifik.kmath.prob.next
|
import kscience.kmath.prob.next
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -19,7 +19,7 @@ import kotlin.math.*
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html
|
||||||
*/
|
*/
|
||||||
class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<Int> {
|
public class LargeMeanPoissonSampler private constructor(public val mean: Double) : Sampler<Int> {
|
||||||
private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
|
private val exponential: Sampler<Double> = AhrensDieterExponentialSampler.of(1.0)
|
||||||
private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler.of()
|
private val gaussian: Sampler<Double> = ZigguratNormalizedGaussianSampler.of()
|
||||||
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
|
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
|
||||||
@ -41,7 +41,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
|
|||||||
else // Not used.
|
else // Not used.
|
||||||
KempSmallMeanPoissonSampler.of(mean - lambda)
|
KempSmallMeanPoissonSampler.of(mean - lambda)
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||||
// This will never be null. It may be a no-op delegate that returns zero.
|
// This will never be null. It may be a no-op delegate that returns zero.
|
||||||
val y2 = smallMeanPoissonSampler.next(generator)
|
val y2 = smallMeanPoissonSampler.next(generator)
|
||||||
var x: Double
|
var x: Double
|
||||||
@ -114,7 +114,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
|
|||||||
|
|
||||||
override fun toString(): String = "Large Mean Poisson deviate"
|
override fun toString(): String = "Large Mean Poisson deviate"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
|
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
|
||||||
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
|
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ class LargeMeanPoissonSampler private constructor(val mean: Double) : Sampler<In
|
|||||||
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
override fun sample(generator: RandomGenerator): Chain<Int> = ConstantChain(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun of(mean: Double): LargeMeanPoissonSampler {
|
public fun of(mean: Double): LargeMeanPoissonSampler {
|
||||||
require(mean >= 1) { "mean is not >= 1: $mean" }
|
require(mean >= 1) { "mean is not >= 1: $mean" }
|
||||||
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
// The algorithm is not valid if Math.floor(mean) is not an integer.
|
||||||
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@ -15,10 +15,10 @@ import kotlin.math.sqrt
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html
|
||||||
*/
|
*/
|
||||||
class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
public class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
|
||||||
private var nextGaussian = Double.NaN
|
private var nextGaussian = Double.NaN
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
|
||||||
if (nextGaussian.isNaN()) {
|
if (nextGaussian.isNaN()) {
|
||||||
val alpha: Double
|
val alpha: Double
|
||||||
var x: Double
|
var x: Double
|
||||||
@ -53,9 +53,9 @@ class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGauss
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
public override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
|
public fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Marker interface for a sampler that generates values from an N(0,1)
|
* Marker interface for a sampler that generates values from an N(0,1)
|
||||||
* [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution).
|
* [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution).
|
||||||
*/
|
*/
|
||||||
interface NormalizedGaussianSampler : Sampler<Double>
|
public interface NormalizedGaussianSampler : Sampler<Double>
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sampler for the Poisson distribution.
|
* Sampler for the Poisson distribution.
|
||||||
@ -16,17 +16,15 @@ import scientifik.kmath.prob.Sampler
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html
|
||||||
*/
|
*/
|
||||||
class PoissonSampler private constructor(
|
public class PoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||||
mean: Double
|
|
||||||
) : Sampler<Int> {
|
|
||||||
private val poissonSamplerDelegate: Sampler<Int> = of(mean)
|
private val poissonSamplerDelegate: Sampler<Int> = of(mean)
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
public override fun sample(generator: RandomGenerator): Chain<Int> = poissonSamplerDelegate.sample(generator)
|
||||||
override fun toString(): String = poissonSamplerDelegate.toString()
|
public override fun toString(): String = poissonSamplerDelegate.toString()
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
private const val PIVOT = 40.0
|
private const val PIVOT = 40.0
|
||||||
|
|
||||||
fun of(mean: Double) =// Each sampler should check the input arguments.
|
public fun of(mean: Double): Sampler<Int> =// Each sampler should check the input arguments.
|
||||||
if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
|
if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import kotlin.math.ceil
|
import kotlin.math.ceil
|
||||||
import kotlin.math.exp
|
import kotlin.math.exp
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ import kotlin.math.exp
|
|||||||
*
|
*
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html
|
||||||
*/
|
*/
|
||||||
class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
public class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
||||||
private val p0: Double = exp(-mean)
|
private val p0: Double = exp(-mean)
|
||||||
|
|
||||||
private val limit: Int = (if (p0 > 0)
|
private val limit: Int = (if (p0 > 0)
|
||||||
@ -27,7 +27,7 @@ class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
|||||||
else
|
else
|
||||||
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt()
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
|
||||||
var n = 0
|
var n = 0
|
||||||
var r = 1.0
|
var r = 1.0
|
||||||
|
|
||||||
@ -39,10 +39,10 @@ class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler<Int> {
|
|||||||
n
|
n
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "Small Mean Poisson deviate"
|
public override fun toString(): String = "Small Mean Poisson deviate"
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
fun of(mean: Double): SmallMeanPoissonSampler {
|
public fun of(mean: Double): SmallMeanPoissonSampler {
|
||||||
require(mean > 0) { "mean is not strictly positive: $mean" }
|
require(mean > 0) { "mean is not strictly positive: $mean" }
|
||||||
return SmallMeanPoissonSampler(mean)
|
return SmallMeanPoissonSampler(mean)
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
package scientifik.kmath.prob.samplers
|
package kscience.kmath.prob.samplers
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.Sampler
|
import kscience.kmath.prob.Sampler
|
||||||
import scientifik.kmath.prob.chain
|
import kscience.kmath.prob.chain
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -14,7 +14,7 @@ import kotlin.math.*
|
|||||||
* Based on Commons RNG implementation.
|
* Based on Commons RNG implementation.
|
||||||
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html
|
* See https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html
|
||||||
*/
|
*/
|
||||||
class ZigguratNormalizedGaussianSampler private constructor() :
|
public class ZigguratNormalizedGaussianSampler private constructor() :
|
||||||
NormalizedGaussianSampler, Sampler<Double> {
|
NormalizedGaussianSampler, Sampler<Double> {
|
||||||
|
|
||||||
private fun sampleOne(generator: RandomGenerator): Double {
|
private fun sampleOne(generator: RandomGenerator): Double {
|
||||||
@ -23,9 +23,8 @@ class ZigguratNormalizedGaussianSampler private constructor() :
|
|||||||
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
|
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) }
|
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) }
|
||||||
|
public override fun toString(): String = "Ziggurat normalized Gaussian deviate"
|
||||||
override fun toString(): String = "Ziggurat normalized Gaussian deviate"
|
|
||||||
|
|
||||||
private fun fix(
|
private fun fix(
|
||||||
generator: RandomGenerator,
|
generator: RandomGenerator,
|
||||||
@ -59,7 +58,7 @@ class ZigguratNormalizedGaussianSampler private constructor() :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
public companion object {
|
||||||
private const val R: Double = 3.442619855899
|
private const val R: Double = 3.442619855899
|
||||||
private const val ONE_OVER_R: Double = 1 / R
|
private const val ONE_OVER_R: Double = 1 / R
|
||||||
private const val V: Double = 9.91256303526217e-3
|
private const val V: Double = 9.91256303526217e-3
|
||||||
@ -94,7 +93,7 @@ class ZigguratNormalizedGaussianSampler private constructor() :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
|
public fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler()
|
||||||
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,46 +0,0 @@
|
|||||||
package scientifik.kmath.prob
|
|
||||||
|
|
||||||
import scientifik.kmath.chains.Chain
|
|
||||||
import scientifik.kmath.chains.SimpleChain
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A multivariate distribution which takes a map of parameters
|
|
||||||
*/
|
|
||||||
interface NamedDistribution<T> : Distribution<Map<String, T>>
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A multivariate distribution that has independent distributions for separate axis
|
|
||||||
*/
|
|
||||||
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
|
|
||||||
override fun probability(arg: Map<String, T>): Double {
|
|
||||||
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
|
||||||
val chains = distributions.map { it.sample(generator) }
|
|
||||||
return SimpleChain<Map<String, T>> {
|
|
||||||
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class NamedDistributionWrapper<T : Any>(val name: String, val distribution: Distribution<T>) : NamedDistribution<T> {
|
|
||||||
override fun probability(arg: Map<String, T>): Double = distribution.probability(
|
|
||||||
arg[name] ?: error("Argument with name $name not found in input parameters")
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
|
||||||
val chain = distribution.sample(generator)
|
|
||||||
return SimpleChain {
|
|
||||||
mapOf(name to chain.next())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class DistributionBuilder<T: Any>{
|
|
||||||
private val distributions = ArrayList<NamedDistribution<T>>()
|
|
||||||
|
|
||||||
infix fun String.to(distribution: Distribution<T>){
|
|
||||||
distributions.add(NamedDistributionWrapper(this,distribution))
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,30 +0,0 @@
|
|||||||
package scientifik.kmath.prob
|
|
||||||
|
|
||||||
import scientifik.kmath.chains.BlockingIntChain
|
|
||||||
import scientifik.kmath.chains.BlockingRealChain
|
|
||||||
import scientifik.kmath.chains.Chain
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A possibly stateful chain producing random values.
|
|
||||||
*/
|
|
||||||
class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain<R> {
|
|
||||||
override suspend fun next(): R = generator.gen()
|
|
||||||
|
|
||||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
|
||||||
|
|
||||||
fun RandomChain<Double>.blocking(): BlockingRealChain = let {
|
|
||||||
object : BlockingRealChain() {
|
|
||||||
override suspend fun next(): Double = it.next()
|
|
||||||
override fun fork(): Chain<Double> = it.fork()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun RandomChain<Int>.blocking(): BlockingIntChain = let {
|
|
||||||
object : BlockingIntChain() {
|
|
||||||
override suspend fun next(): Int = it.next()
|
|
||||||
override fun fork(): Chain<Int> = it.fork()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,47 +0,0 @@
|
|||||||
package scientifik.kmath.prob
|
|
||||||
|
|
||||||
import kotlin.random.Random
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A basic generator
|
|
||||||
*/
|
|
||||||
interface RandomGenerator {
|
|
||||||
fun nextBoolean(): Boolean
|
|
||||||
fun nextDouble(): Double
|
|
||||||
fun nextInt(): Int
|
|
||||||
fun nextInt(until: Int): Int
|
|
||||||
fun nextLong(): Long
|
|
||||||
fun nextLong(until: Long): Long
|
|
||||||
fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
|
|
||||||
fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
|
|
||||||
* and vise versa). The statistical properties of new generator should be the same as for this one.
|
|
||||||
* For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run.
|
|
||||||
*
|
|
||||||
* The thread safety of this operation is not guaranteed since it could affect the state of the generator.
|
|
||||||
*/
|
|
||||||
fun fork(): RandomGenerator
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
val default by lazy { DefaultGenerator() }
|
|
||||||
fun default(seed: Long) = DefaultGenerator(Random(seed))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline class DefaultGenerator(private val random: Random = Random) : RandomGenerator {
|
|
||||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
|
||||||
override fun nextDouble(): Double = random.nextDouble()
|
|
||||||
override fun nextInt(): Int = random.nextInt()
|
|
||||||
override fun nextInt(until: Int): Int = random.nextInt(until)
|
|
||||||
override fun nextLong(): Long = random.nextLong()
|
|
||||||
override fun nextLong(until: Long): Long = random.nextLong(until)
|
|
||||||
|
|
||||||
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
|
||||||
random.nextBytes(array, fromIndex, toIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
|
||||||
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
|
||||||
}
|
|
@ -1,21 +1,22 @@
|
|||||||
package scientifik.kmath.prob
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kscience.kmath.prob.RandomGenerator
|
||||||
import org.apache.commons.rng.simple.RandomSource
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
|
|
||||||
class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
|
public class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
|
||||||
RandomGenerator {
|
RandomGenerator {
|
||||||
private val random = seed?.let {
|
private val random = seed?.let {
|
||||||
RandomSource.create(source, seed)
|
RandomSource.create(source, seed)
|
||||||
} ?: RandomSource.create(source)
|
} ?: RandomSource.create(source)
|
||||||
|
|
||||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
public override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||||
override fun nextDouble(): Double = random.nextDouble()
|
public override fun nextDouble(): Double = random.nextDouble()
|
||||||
override fun nextInt(): Int = random.nextInt()
|
public override fun nextInt(): Int = random.nextInt()
|
||||||
override fun nextInt(until: Int): Int = random.nextInt(until)
|
public override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||||
override fun nextLong(): Long = random.nextLong()
|
public override fun nextLong(): Long = random.nextLong()
|
||||||
override fun nextLong(until: Long): Long = random.nextLong(until)
|
public override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||||
|
|
||||||
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
public override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
||||||
require(toIndex > fromIndex)
|
require(toIndex > fromIndex)
|
||||||
random.nextBytes(array, fromIndex, toIndex - fromIndex)
|
random.nextBytes(array, fromIndex, toIndex - fromIndex)
|
||||||
}
|
}
|
||||||
@ -23,8 +24,8 @@ class RandomSourceGenerator(private val source: RandomSource, seed: Long?) :
|
|||||||
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
|
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
|
||||||
}
|
}
|
||||||
|
|
||||||
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
public fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
||||||
RandomSourceGenerator(source, seed)
|
RandomSourceGenerator(source, seed)
|
||||||
|
|
||||||
fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
|
public fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
|
||||||
fromSource(RandomSource.MT, seed)
|
fromSource(RandomSource.MT, seed)
|
||||||
|
@ -3,25 +3,24 @@ package kscience.kmath.prob
|
|||||||
import kotlinx.coroutines.flow.take
|
import kotlinx.coroutines.flow.take
|
||||||
import kotlinx.coroutines.flow.toList
|
import kotlinx.coroutines.flow.toList
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import kscience.kmath.prob.samplers.GaussianSampler
|
||||||
import org.junit.jupiter.api.Assertions
|
import org.junit.jupiter.api.Assertions
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
|
|
||||||
internal class CommonsDistributionsTest {
|
internal class CommonsDistributionsTest {
|
||||||
@Test
|
@Test
|
||||||
fun testNormalDistributionSuspend() {
|
fun testNormalDistributionSuspend() {
|
||||||
val distribution = Distribution.normal(7.0, 2.0)
|
val distribution = GaussianSampler.of(7.0, 2.0)
|
||||||
val generator = RandomGenerator.default(1)
|
val generator = RandomGenerator.default(1)
|
||||||
val sample = runBlocking {
|
val sample = runBlocking { distribution.sample(generator).take(1000).toList() }
|
||||||
distribution.sample(generator).take(1000).toList()
|
|
||||||
}
|
|
||||||
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testNormalDistributionBlocking() {
|
fun testNormalDistributionBlocking() {
|
||||||
val distribution = Distribution.normal(7.0, 2.0)
|
val distribution = GaussianSampler.of(7.0, 2.0)
|
||||||
val generator = RandomGenerator.default(1)
|
val generator = RandomGenerator.default(1)
|
||||||
val sample = distribution.sample(generator).nextBlock(1000)
|
val sample = runBlocking { distribution.sample(generator).blocking().nextBlock(1000) }
|
||||||
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user