Moved probability distributions to commons-rng and to prob module
This commit is contained in:
parent
c15f77acef
commit
048a1ceaed
@ -2,7 +2,7 @@ plugins {
|
|||||||
id("scientifik.publish") apply false
|
id("scientifik.publish") apply false
|
||||||
}
|
}
|
||||||
|
|
||||||
val kmathVersion by extra("0.1.4-dev-6")
|
val kmathVersion by extra("0.1.4-dev-7")
|
||||||
|
|
||||||
val bintrayRepo by extra("scientifik")
|
val bintrayRepo by extra("scientifik")
|
||||||
val githubProject by extra("kmath")
|
val githubProject by extra("kmath")
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
package scientifik.kmath.commons.prob
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator
|
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
|
||||||
|
|
||||||
inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator {
|
|
||||||
override fun nextDouble(): Double = generator.nextDouble()
|
|
||||||
|
|
||||||
override fun nextInt(): Int = generator.nextInt()
|
|
||||||
|
|
||||||
override fun nextLong(): Long = generator.nextLong()
|
|
||||||
|
|
||||||
override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) }
|
|
||||||
|
|
||||||
override fun fork(): RandomGenerator {
|
|
||||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun CMRandom.asKmathGenerator(): RandomGenerator = CMRandomGeneratorWrapper(this)
|
|
||||||
|
|
||||||
fun RandomGenerator.asCMGenerator(): CMRandom =
|
|
||||||
(this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper")
|
|
||||||
|
|
||||||
val RandomGenerator.Companion.default: RandomGenerator by lazy { JDKRandomGenerator().asKmathGenerator() }
|
|
||||||
|
|
||||||
fun RandomGenerator.Companion.jdk(seed: Int? = null): RandomGenerator = if (seed == null) {
|
|
||||||
JDKRandomGenerator()
|
|
||||||
} else {
|
|
||||||
JDKRandomGenerator(seed)
|
|
||||||
}.asKmathGenerator()
|
|
@ -1,82 +0,0 @@
|
|||||||
package scientifik.kmath.commons.prob
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.*
|
|
||||||
import scientifik.kmath.prob.Distribution
|
|
||||||
import scientifik.kmath.prob.RandomChain
|
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
|
||||||
import scientifik.kmath.prob.UnivariateDistribution
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
|
||||||
|
|
||||||
class CMRealDistributionWrapper(val builder: (CMRandom?) -> RealDistribution) : UnivariateDistribution<Double> {
|
|
||||||
|
|
||||||
private val defaultDistribution by lazy { builder(null) }
|
|
||||||
|
|
||||||
override fun probability(arg: Double): Double = defaultDistribution.probability(arg)
|
|
||||||
|
|
||||||
override fun cumulative(arg: Double): Double = defaultDistribution.cumulativeProbability(arg)
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): RandomChain<Double> {
|
|
||||||
val distribution = builder(generator.asCMGenerator())
|
|
||||||
return RandomChain(generator) { distribution.sample() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class CMIntDistributionWrapper(val builder: (CMRandom?) -> IntegerDistribution) : UnivariateDistribution<Int> {
|
|
||||||
|
|
||||||
private val defaultDistribution by lazy { builder(null) }
|
|
||||||
|
|
||||||
override fun probability(arg: Int): Double = defaultDistribution.probability(arg)
|
|
||||||
|
|
||||||
override fun cumulative(arg: Int): Double = defaultDistribution.cumulativeProbability(arg)
|
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): RandomChain<Int> {
|
|
||||||
val distribution = builder(generator.asCMGenerator())
|
|
||||||
return RandomChain(generator) { distribution.sample() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
fun Distribution.Companion.normal(mean: Double = 0.0, sigma: Double = 1.0): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator -> NormalDistribution(generator, mean, sigma) }
|
|
||||||
|
|
||||||
fun Distribution.Companion.poisson(mean: Double): UnivariateDistribution<Int> = CMIntDistributionWrapper { generator ->
|
|
||||||
PoissonDistribution(
|
|
||||||
generator,
|
|
||||||
mean,
|
|
||||||
PoissonDistribution.DEFAULT_EPSILON,
|
|
||||||
PoissonDistribution.DEFAULT_MAX_ITERATIONS
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.binomial(trials: Int, p: Double): UnivariateDistribution<Int> =
|
|
||||||
CMIntDistributionWrapper { generator ->
|
|
||||||
BinomialDistribution(generator, trials, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.student(degreesOfFreedom: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
TDistribution(generator, degreesOfFreedom, TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.chi2(degreesOfFreedom: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
ChiSquaredDistribution(generator, degreesOfFreedom)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.fisher(
|
|
||||||
numeratorDegreesOfFreedom: Double,
|
|
||||||
denominatorDegreesOfFreedom: Double
|
|
||||||
): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
FDistribution(generator, numeratorDegreesOfFreedom, denominatorDegreesOfFreedom)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.exponential(mean: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
ExponentialDistribution(generator, mean)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Distribution.Companion.uniform(a: Double, b: Double): UnivariateDistribution<Double> =
|
|
||||||
CMRealDistributionWrapper { generator ->
|
|
||||||
UniformRealDistribution(generator, a, b)
|
|
||||||
}
|
|
@ -38,9 +38,13 @@ interface Chain<out R>: Flow<R> {
|
|||||||
*/
|
*/
|
||||||
fun fork(): Chain<R>
|
fun fork(): Chain<R>
|
||||||
|
|
||||||
@InternalCoroutinesApi
|
@OptIn(InternalCoroutinesApi::class)
|
||||||
override suspend fun collect(collector: FlowCollector<R>) {
|
override suspend fun collect(collector: FlowCollector<R>) {
|
||||||
kotlinx.coroutines.flow.flow { while (true) emit(next()) }.collect(collector)
|
kotlinx.coroutines.flow.flow {
|
||||||
|
while (true){
|
||||||
|
emit(next())
|
||||||
|
}
|
||||||
|
}.collect(collector)
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
|
@ -69,3 +69,9 @@ inline fun Memory.write(block: MemoryWriter.() -> Unit) {
|
|||||||
* Allocate the most effective platform-specific memory
|
* Allocate the most effective platform-specific memory
|
||||||
*/
|
*/
|
||||||
expect fun Memory.Companion.allocate(length: Int): Memory
|
expect fun Memory.Companion.allocate(length: Int): Memory
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wrap a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied
|
||||||
|
* and could be mutated independently from the resulting [Memory]
|
||||||
|
*/
|
||||||
|
expect fun Memory.Companion.wrap(array: ByteArray): Memory
|
||||||
|
@ -2,14 +2,7 @@ package scientifik.memory
|
|||||||
|
|
||||||
import org.khronos.webgl.ArrayBuffer
|
import org.khronos.webgl.ArrayBuffer
|
||||||
import org.khronos.webgl.DataView
|
import org.khronos.webgl.DataView
|
||||||
|
import org.khronos.webgl.Int8Array
|
||||||
/**
|
|
||||||
* Allocate the most effective platform-specific memory
|
|
||||||
*/
|
|
||||||
actual fun Memory.Companion.allocate(length: Int): Memory {
|
|
||||||
val buffer = ArrayBuffer(length)
|
|
||||||
return DataViewMemory(DataView(buffer, 0, length))
|
|
||||||
}
|
|
||||||
|
|
||||||
class DataViewMemory(val view: DataView) : Memory {
|
class DataViewMemory(val view: DataView) : Memory {
|
||||||
|
|
||||||
@ -89,3 +82,16 @@ class DataViewMemory(val view: DataView) : Memory {
|
|||||||
override fun writer(): MemoryWriter = writer
|
override fun writer(): MemoryWriter = writer
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allocate the most effective platform-specific memory
|
||||||
|
*/
|
||||||
|
actual fun Memory.Companion.allocate(length: Int): Memory {
|
||||||
|
val buffer = ArrayBuffer(length)
|
||||||
|
return DataViewMemory(DataView(buffer, 0, length))
|
||||||
|
}
|
||||||
|
|
||||||
|
actual fun Memory.Companion.wrap(array: ByteArray): Memory {
|
||||||
|
@Suppress("CAST_NEVER_SUCCEEDS") val int8Array = array as Int8Array
|
||||||
|
return DataViewMemory(DataView(int8Array.buffer, int8Array.byteOffset, int8Array.length))
|
||||||
|
}
|
@ -7,14 +7,6 @@ import java.nio.file.Path
|
|||||||
import java.nio.file.StandardOpenOption
|
import java.nio.file.StandardOpenOption
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Allocate the most effective platform-specific memory
|
|
||||||
*/
|
|
||||||
actual fun Memory.Companion.allocate(length: Int): Memory {
|
|
||||||
val buffer = ByteBuffer.allocate(length)
|
|
||||||
return ByteBufferMemory(buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
private class ByteBufferMemory(
|
private class ByteBufferMemory(
|
||||||
val buffer: ByteBuffer,
|
val buffer: ByteBuffer,
|
||||||
val startOffset: Int = 0,
|
val startOffset: Int = 0,
|
||||||
@ -96,6 +88,22 @@ private class ByteBufferMemory(
|
|||||||
override fun writer(): MemoryWriter = writer
|
override fun writer(): MemoryWriter = writer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allocate the most effective platform-specific memory
|
||||||
|
*/
|
||||||
|
actual fun Memory.Companion.allocate(length: Int): Memory {
|
||||||
|
val buffer = ByteBuffer.allocate(length)
|
||||||
|
return ByteBufferMemory(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual fun Memory.Companion.wrap(array: ByteArray): Memory {
|
||||||
|
val buffer = ByteBuffer.wrap(array)
|
||||||
|
return ByteBufferMemory(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun ByteBuffer.asMemory(startOffset: Int = 0, size: Int = limit()): Memory =
|
||||||
|
ByteBufferMemory(this, startOffset, size)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use direct memory-mapped buffer from file to read something and close it afterwards.
|
* Use direct memory-mapped buffer from file to read something and close it afterwards.
|
||||||
*/
|
*/
|
||||||
|
@ -8,4 +8,10 @@ kotlin.sourceSets {
|
|||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
jvmMain{
|
||||||
|
dependencies{
|
||||||
|
api("org.apache.commons:commons-rng-sampling:1.3")
|
||||||
|
api("org.apache.commons:commons-rng-simple:1.3")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -6,10 +6,16 @@ import kotlin.random.Random
|
|||||||
* A basic generator
|
* A basic generator
|
||||||
*/
|
*/
|
||||||
interface RandomGenerator {
|
interface RandomGenerator {
|
||||||
|
fun nextBoolean(): Boolean
|
||||||
|
|
||||||
fun nextDouble(): Double
|
fun nextDouble(): Double
|
||||||
fun nextInt(): Int
|
fun nextInt(): Int
|
||||||
|
fun nextInt(until: Int): Int
|
||||||
fun nextLong(): Long
|
fun nextLong(): Long
|
||||||
fun nextBlock(size: Int): ByteArray
|
fun nextLong(until: Long): Long
|
||||||
|
|
||||||
|
fun fillBytes(array: ByteArray, fromIndex: Int = 0, toIndex: Int = array.size)
|
||||||
|
fun nextBytes(size: Int): ByteArray = ByteArray(size).also { fillBytes(it) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
|
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
|
||||||
@ -21,21 +27,29 @@ interface RandomGenerator {
|
|||||||
fun fork(): RandomGenerator
|
fun fork(): RandomGenerator
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
val default by lazy { DefaultGenerator(Random.nextLong()) }
|
val default by lazy { DefaultGenerator() }
|
||||||
|
|
||||||
|
fun default(seed: Long) = DefaultGenerator(Random(seed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class DefaultGenerator(seed: Long?) : RandomGenerator {
|
inline class DefaultGenerator(val random: Random = Random) : RandomGenerator {
|
||||||
private val random = seed?.let { Random(it) } ?: Random
|
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||||
|
|
||||||
override fun nextDouble(): Double = random.nextDouble()
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
override fun nextInt(): Int = random.nextInt()
|
override fun nextInt(): Int = random.nextInt()
|
||||||
|
override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||||
|
|
||||||
override fun nextLong(): Long = random.nextLong()
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
|
||||||
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
|
override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||||
|
|
||||||
override fun fork(): RandomGenerator = DefaultGenerator(nextLong())
|
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
||||||
|
random.nextBytes(array, fromIndex, toIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
||||||
}
|
}
|
@ -29,3 +29,6 @@ class UniformDistribution(val range: ClosedFloatingPointRange<Double>) : Univari
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.uniform(range: ClosedFloatingPointRange<Double>): UniformDistribution =
|
||||||
|
UniformDistribution(range)
|
@ -0,0 +1,61 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import org.apache.commons.rng.UniformRandomProvider
|
||||||
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
|
|
||||||
|
class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator {
|
||||||
|
internal val random: UniformRandomProvider = seed?.let {
|
||||||
|
RandomSource.create(source, seed, null)
|
||||||
|
} ?: RandomSource.create(source)
|
||||||
|
|
||||||
|
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
|
override fun nextInt(): Int = random.nextInt()
|
||||||
|
override fun nextInt(until: Int): Int = random.nextInt(until)
|
||||||
|
|
||||||
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
override fun nextLong(until: Long): Long = random.nextLong(until)
|
||||||
|
|
||||||
|
override fun fillBytes(array: ByteArray, fromIndex: Int, toIndex: Int) {
|
||||||
|
require(toIndex > fromIndex)
|
||||||
|
random.nextBytes(array, fromIndex, toIndex - fromIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
|
||||||
|
}
|
||||||
|
|
||||||
|
inline class RandomGeneratorProvider(val generator: RandomGenerator) : UniformRandomProvider {
|
||||||
|
override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||||
|
|
||||||
|
override fun nextFloat(): Float = generator.nextDouble().toFloat()
|
||||||
|
|
||||||
|
override fun nextBytes(bytes: ByteArray) {
|
||||||
|
generator.fillBytes(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nextBytes(bytes: ByteArray, start: Int, len: Int) {
|
||||||
|
generator.fillBytes(bytes, start, start + len)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun nextInt(): Int = generator.nextInt()
|
||||||
|
|
||||||
|
override fun nextInt(n: Int): Int = generator.nextInt(n)
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = generator.nextDouble()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = generator.nextLong()
|
||||||
|
|
||||||
|
override fun nextLong(n: Long): Long = generator.nextLong(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent this [RandomGenerator] as commons-rng [UniformRandomProvider] preserving and mirroring its current state.
|
||||||
|
* Getting new value from one of those changes the state of another.
|
||||||
|
*/
|
||||||
|
fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this is RandomSourceGenerator) {
|
||||||
|
random
|
||||||
|
} else {
|
||||||
|
RandomGeneratorProvider(this)
|
||||||
|
}
|
@ -0,0 +1,107 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import org.apache.commons.rng.UniformRandomProvider
|
||||||
|
import org.apache.commons.rng.sampling.distribution.*
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import java.util.*
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.math.exp
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
abstract class ContinuousSamplerDistribution : Distribution<Double> {
|
||||||
|
|
||||||
|
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Double> {
|
||||||
|
private val sampler = buildSampler(generator)
|
||||||
|
|
||||||
|
override suspend fun next(): Double = sampler.sample()
|
||||||
|
|
||||||
|
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract fun buildSampler(generator: RandomGenerator): ContinuousSampler
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Double> = ContinuousSamplerChain(generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract class DiscreteSamplerDistribution : Distribution<Int> {
|
||||||
|
|
||||||
|
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Int> {
|
||||||
|
private val sampler = buildSampler(generator)
|
||||||
|
|
||||||
|
override suspend fun next(): Int = sampler.sample()
|
||||||
|
|
||||||
|
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Int> = ContinuousSamplerChain(generator)
|
||||||
|
}
|
||||||
|
|
||||||
|
enum class NormalSamplerMethod {
|
||||||
|
BoxMuller,
|
||||||
|
Marsaglia,
|
||||||
|
Ziggurat
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomProvider): NormalizedGaussianSampler =
|
||||||
|
when (method) {
|
||||||
|
NormalSamplerMethod.BoxMuller -> BoxMullerNormalizedGaussianSampler(provider)
|
||||||
|
NormalSamplerMethod.Marsaglia -> MarsagliaNormalizedGaussianSampler(provider)
|
||||||
|
NormalSamplerMethod.Ziggurat -> ZigguratNormalizedGaussianSampler(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.normal(
|
||||||
|
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||||
|
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
||||||
|
override fun buildSampler(generator: RandomGenerator): ContinuousSampler {
|
||||||
|
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||||
|
return normalSampler(method, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun probability(arg: Double): Double {
|
||||||
|
return exp(-arg.pow(2) / 2) / sqrt(PI * 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.normal(
|
||||||
|
mean: Double,
|
||||||
|
sigma: Double,
|
||||||
|
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||||
|
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
||||||
|
private val sigma2 = sigma.pow(2)
|
||||||
|
private val norm = sigma * sqrt(PI * 2)
|
||||||
|
|
||||||
|
override fun buildSampler(generator: RandomGenerator): ContinuousSampler {
|
||||||
|
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||||
|
val normalizedSampler = normalSampler(method, provider)
|
||||||
|
return GaussianSampler(normalizedSampler, mean, sigma)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun probability(arg: Double): Double {
|
||||||
|
return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun Distribution.Companion.poisson(
|
||||||
|
lambda: Double
|
||||||
|
): Distribution<Int> = object : DiscreteSamplerDistribution() {
|
||||||
|
|
||||||
|
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
|
||||||
|
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val computedProb: HashMap<Int, Double> = hashMapOf(0 to exp(-lambda))
|
||||||
|
|
||||||
|
override fun probability(arg: Int): Double {
|
||||||
|
require(arg >= 0) { "The argument must be >= 0" }
|
||||||
|
return if (arg > 40) {
|
||||||
|
exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda)
|
||||||
|
} else {
|
||||||
|
computedProb.getOrPut(arg) {
|
||||||
|
probability(arg - 1) * lambda / arg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.flow.take
|
||||||
|
import kotlinx.coroutines.flow.toList
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import org.junit.jupiter.api.Assertions
|
||||||
|
import org.junit.jupiter.api.Test
|
||||||
|
|
||||||
|
class CommonsDistributionsTest {
|
||||||
|
@Test
|
||||||
|
fun testNormalDistribution(){
|
||||||
|
val distribution = Distribution.normal(7.0,2.0)
|
||||||
|
val generator = RandomGenerator.default(1)
|
||||||
|
val sample = runBlocking {
|
||||||
|
distribution.sample(generator).take(1000).toList()
|
||||||
|
}
|
||||||
|
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
||||||
|
}
|
||||||
|
}
|
@ -9,7 +9,7 @@ import kotlin.test.Test
|
|||||||
|
|
||||||
class StatisticTest {
|
class StatisticTest {
|
||||||
//create a random number generator.
|
//create a random number generator.
|
||||||
val generator = DefaultGenerator(1)
|
val generator = RandomGenerator.default(1)
|
||||||
//Create a stateless chain from generator.
|
//Create a stateless chain from generator.
|
||||||
val data = generator.chain { nextDouble() }
|
val data = generator.chain { nextDouble() }
|
||||||
//Convert a chaint to Flow and break it into chunks.
|
//Convert a chaint to Flow and break it into chunks.
|
||||||
|
Loading…
Reference in New Issue
Block a user