Refactor, remove unused files, remove BasicSampler

This commit is contained in:
Iaroslav Postovalov 2020-10-16 16:49:47 +07:00
parent d4aa4587a9
commit 612f6f0082
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
22 changed files with 194 additions and 127 deletions

View File

@ -1,3 +1,5 @@
import ru.mipt.npm.gradle.Maturity
plugins { plugins {
id("ru.mipt.npm.mpp") id("ru.mipt.npm.mpp")
id("ru.mipt.npm.native") id("ru.mipt.npm.native")
@ -11,33 +13,39 @@ kotlin.sourceSets.commonMain {
readme { readme {
description = "Core classes, algebra definitions, basic linear algebra" description = "Core classes, algebra definitions, basic linear algebra"
maturity = ru.mipt.npm.gradle.Maturity.DEVELOPMENT maturity = Maturity.DEVELOPMENT
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature( feature(
id = "algebras", id = "algebras",
description = "Algebraic structures: contexts and elements", description = "Algebraic structures: contexts and elements",
ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt" ref = "src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt"
) )
feature( feature(
id = "nd", id = "nd",
description = "Many-dimensional structures", description = "Many-dimensional structures",
ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt" ref = "src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt"
) )
feature( feature(
id = "buffers", id = "buffers",
description = "One-dimensional structure", description = "One-dimensional structure",
ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt" ref = "src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt"
) )
feature( feature(
id = "expressions", id = "expressions",
description = "Functional Expressions", description = "Functional Expressions",
ref = "src/commonMain/kotlin/kscience/kmath/expressions" ref = "src/commonMain/kotlin/kscience/kmath/expressions"
) )
feature( feature(
id = "domains", id = "domains",
description = "Domains", description = "Domains",
ref = "src/commonMain/kotlin/kscience/kmath/domains" ref = "src/commonMain/kotlin/kscience/kmath/domains"
) )
feature( feature(
id = "autodif", id = "autodif",
description = "Automatic differentiation", description = "Automatic differentiation",

View File

@ -237,18 +237,18 @@ public class BigInt internal constructor(
) )
private fun compareMagnitudes(mag1: Magnitude, mag2: Magnitude): Int { private fun compareMagnitudes(mag1: Magnitude, mag2: Magnitude): Int {
when { return when {
mag1.size > mag2.size -> return 1 mag1.size > mag2.size -> 1
mag1.size < mag2.size -> return -1 mag1.size < mag2.size -> -1
else -> { else -> {
for (i in mag1.size - 1 downTo 0) { for (i in mag1.size - 1 downTo 0) return when {
if (mag1[i] > mag2[i]) { mag1[i] > mag2[i] -> 1
return 1 mag1[i] < mag2[i] -> -1
} else if (mag1[i] < mag2[i]) { else -> continue
return -1
}
} }
return 0
0
} }
} }
} }
@ -298,10 +298,11 @@ public class BigInt internal constructor(
var carry = 0uL var carry = 0uL
for (i in mag.indices) { for (i in mag.indices) {
val cur: ULong = carry + mag[i].toULong() * x.toULong() val cur = carry + mag[i].toULong() * x.toULong()
result[i] = (cur and BASE).toUInt() result[i] = (cur and BASE).toUInt()
carry = cur shr BASE_SIZE carry = cur shr BASE_SIZE
} }
result[resultLength - 1] = (carry and BASE).toUInt() result[resultLength - 1] = (carry and BASE).toUInt()
return stripLeadingZeros(result) return stripLeadingZeros(result)

View File

@ -71,7 +71,7 @@ public interface Buffer<T> {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <T : Any> auto(type: KClass<T>, size: Int, initializer: (Int) -> T): Buffer<T> = public inline fun <T : Any> auto(type: KClass<T>, size: Int, initializer: (Int) -> T): Buffer<T> =
when (type) { when (type) {
Double::class -> RealBuffer(size) { initializer(it) as Double } as Buffer<T> Double::class -> real(size) { initializer(it) as Double } as Buffer<T>
Short::class -> ShortBuffer(size) { initializer(it) as Short } as Buffer<T> Short::class -> ShortBuffer(size) { initializer(it) as Short } as Buffer<T>
Int::class -> IntBuffer(size) { initializer(it) as Int } as Buffer<T> Int::class -> IntBuffer(size) { initializer(it) as Int } as Buffer<T>
Long::class -> LongBuffer(size) { initializer(it) as Long } as Buffer<T> Long::class -> LongBuffer(size) { initializer(it) as Long } as Buffer<T>

View File

@ -22,7 +22,7 @@ public interface Structure2D<T> : NDStructure<T> {
override fun elements(): Sequence<Pair<IntArray, T>> = sequence { override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in (0 until rowNum)) for (i in (0 until rowNum))
for (j in (0 until colNum)) yield(intArrayOf(i, j) to get(i, j)) for (j in (0 until colNum)) yield(intArrayOf(i, j) to this@Structure2D[i, j])
} }
public companion object public companion object
@ -35,7 +35,6 @@ private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Stru
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override operator fun get(i: Int, j: Int): T = structure[i, j] override operator fun get(i: Int, j: Int): T = structure[i, j]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
} }

View File

@ -29,12 +29,11 @@ class NumberNDFieldTest {
val array = real2D(3, 3) { i, j -> (i * 10 + j).toDouble() } val array = real2D(3, 3) { i, j -> (i * 10 + j).toDouble() }
for (i in 0..2) { for (i in 0..2)
for (j in 0..2) { for (j in 0..2) {
val expected = (i * 10 + j).toDouble() val expected = (i * 10 + j).toDouble()
assertEquals(expected, array[i, j], "Error at index [$i, $j]") assertEquals(expected, array[i, j], "Error at index [$i, $j]")
} }
}
} }
@Test @Test

View File

@ -63,12 +63,10 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
public fun value(): R? = value public fun value(): R? = value
public override suspend fun next(): R { public override suspend fun next(): R = mutex.withLock {
mutex.withLock { val newValue = gen(value ?: seed())
val newValue = gen(value ?: seed()) value = newValue
value = newValue newValue
return newValue
}
} }
public override fun fork(): Chain<R> = MarkovChain(seed = { value ?: seed() }, gen = gen) public override fun fork(): Chain<R> = MarkovChain(seed = { value ?: seed() }, gen = gen)
@ -90,12 +88,10 @@ public class StatefulChain<S, out R>(
public fun value(): R? = value public fun value(): R? = value
public override suspend fun next(): R { public override suspend fun next(): R = mutex.withLock {
mutex.withLock { val newValue = state.gen(value ?: state.seed())
val newValue = state.gen(value ?: state.seed()) value = newValue
value = newValue newValue
return newValue
}
} }
public override fun fork(): Chain<R> = StatefulChain(forkState(state), seed, forkState, gen) public override fun fork(): Chain<R> = StatefulChain(forkState(state), seed, forkState, gen)

View File

@ -28,7 +28,7 @@ public fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>)
var counter = 0 var counter = 0
this@chunked.collect { element -> this@chunked.collect { element ->
list.add(element) list += element
counter++ counter++
if (counter == bufferSize) { if (counter == bufferSize) {

View File

@ -48,11 +48,9 @@ public class RingBuffer<T>(
/** /**
* A safe snapshot operation * A safe snapshot operation
*/ */
public suspend fun snapshot(): Buffer<T> { public suspend fun snapshot(): Buffer<T> = mutex.withLock {
mutex.withLock { val copy = buffer.copy()
val copy = buffer.copy() VirtualBuffer(size) { i -> copy[startIndex.forward(i)] as T }
return VirtualBuffer(size) { i -> copy[startIndex.forward(i)] as T }
}
} }
public suspend fun push(element: T) { public suspend fun push(element: T) {

View File

@ -23,8 +23,8 @@ public class OrderedPiecewisePolynomial<T : Comparable<T>>(delimiter: T) :
*/ */
public fun putRight(right: T, piece: Polynomial<T>) { public fun putRight(right: T, piece: Polynomial<T>) {
require(right > delimiters.last()) { "New delimiter should be to the right of old one" } require(right > delimiters.last()) { "New delimiter should be to the right of old one" }
delimiters.add(right) delimiters += right
pieces.add(piece) pieces += piece
} }
public fun putLeft(left: T, piece: Polynomial<T>) { public fun putLeft(left: T, piece: Polynomial<T>) {

View File

@ -5,8 +5,18 @@ import kscience.kmath.chains.Chain
import kscience.kmath.chains.collect import kscience.kmath.chains.collect
import kscience.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import kscience.kmath.structures.BufferFactory import kscience.kmath.structures.BufferFactory
import kscience.kmath.structures.IntBuffer
/**
* Sampler that generates chains of values of type [T].
*/
public fun interface Sampler<T : Any> { public fun interface Sampler<T : Any> {
/**
* Generates a chain of samples.
*
* @param generator the randomness provider.
* @return the new chain.
*/
public fun sample(generator: RandomGenerator): Chain<T> public fun sample(generator: RandomGenerator): Chain<T>
} }
@ -59,16 +69,25 @@ public fun <T : Any> Sampler<T>.sampleBuffer(
//clear list from previous run //clear list from previous run
tmp.clear() tmp.clear()
//Fill list //Fill list
repeat(size) { tmp.add(chain.next()) } repeat(size) { tmp += chain.next() }
//return new buffer with elements from tmp //return new buffer with elements from tmp
bufferFactory(size) { tmp[it] } bufferFactory(size) { tmp[it] }
} }
} }
/**
* Samples one value from this [Sampler].
*/
public suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator): T = sample(generator).first() public suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator): T = sample(generator).first()
/** /**
* Generate a bunch of samples from real distributions * Generates [size] real samples and chunks them into some buffers.
*/ */
public fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int): Chain<Buffer<Double>> = public fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int): Chain<Buffer<Double>> =
sampleBuffer(generator, size, Buffer.Companion::real) sampleBuffer(generator, size, Buffer.Companion::real)
/**
* Generates [size] integer samples and chunks them into some buffers.
*/
public fun Sampler<Int>.sampleBuffer(generator: RandomGenerator, size: Int): Chain<Buffer<Int>> =
sampleBuffer(generator, size, ::IntBuffer)

View File

@ -38,6 +38,6 @@ public class DistributionBuilder<T : Any> {
private val distributions = ArrayList<NamedDistribution<T>>() private val distributions = ArrayList<NamedDistribution<T>>()
public infix fun String.to(distribution: Distribution<T>) { public infix fun String.to(distribution: Distribution<T>) {
distributions.add(NamedDistributionWrapper(this, distribution)) distributions += NamedDistributionWrapper(this, distribution)
} }
} }

View File

@ -82,6 +82,8 @@ public interface RandomGenerator {
/** /**
* Implements [RandomGenerator] by delegating all operations to [Random]. * Implements [RandomGenerator] by delegating all operations to [Random].
*
* @property random the underlying [Random] object.
*/ */
public inline class DefaultGenerator(public val random: Random = Random) : RandomGenerator { public inline class DefaultGenerator(public val random: Random = Random) : RandomGenerator {
public override fun nextBoolean(): Boolean = random.nextBoolean() public override fun nextBoolean(): Boolean = random.nextBoolean()

View File

@ -7,25 +7,28 @@ import kscience.kmath.chains.zip
import kscience.kmath.operations.Space import kscience.kmath.operations.Space
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
public class BasicSampler<T : Any>(public val chainBuilder: (RandomGenerator) -> Chain<T>) : Sampler<T> { /**
public override fun sample(generator: RandomGenerator): Chain<T> = chainBuilder(generator) * Implements [Sampler] by sampling only certain [value].
} *
* @property value the value to sample.
*/
public class ConstantSampler<T : Any>(public val value: T) : Sampler<T> { public class ConstantSampler<T : Any>(public val value: T) : Sampler<T> {
public override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value) public override fun sample(generator: RandomGenerator): Chain<T> = ConstantChain(value)
} }
/** /**
* A space for samplers. Allows to perform simple operations on distributions * A space of samplers. Allows to perform simple operations on distributions.
*
* @property space the space to provide addition and scalar multiplication for [T].
*/ */
public class SamplerSpace<T : Any>(public val space: Space<T>) : Space<Sampler<T>> { public class SamplerSpace<T : Any>(public val space: Space<T>) : Space<Sampler<T>> {
public override val zero: Sampler<T> = ConstantSampler(space.zero) public override val zero: Sampler<T> = ConstantSampler(space.zero)
public override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator -> public override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = Sampler { generator ->
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } } a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> space { aValue + bValue } }
} }
public override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = BasicSampler { generator -> public override fun multiply(a: Sampler<T>, k: Number): Sampler<T> = Sampler { generator ->
a.sample(generator).map { space { it * k.toDouble() } } a.sample(generator).map { space { it * k.toDouble() } }
} }
} }

View File

@ -98,28 +98,21 @@ internal object InternalGamma {
private const val INV_GAMMA1P_M1_C12 = .113302723198169588237412962033074E-05 private const val INV_GAMMA1P_M1_C12 = .113302723198169588237412962033074E-05
private const val INV_GAMMA1P_M1_C13 = -.205633841697760710345015413002057E-06 private const val INV_GAMMA1P_M1_C13 = -.205633841697760710345015413002057E-06
fun logGamma(x: Double): Double { fun logGamma(x: Double): Double = when {
val ret: Double x.isNaN() || x <= 0.0 -> Double.NaN
x < 0.5 -> logGamma1p(x) - ln(x)
x <= 2.5 -> logGamma1p(x - 0.5 - 0.5)
when { x <= 8.0 -> {
x.isNaN() || x <= 0.0 -> ret = Double.NaN val n = floor(x - 1.5).toInt()
x < 0.5 -> return logGamma1p(x) - ln(x) val prod = (1..n).fold(1.0, { prod, i -> prod * (x - i) })
x <= 2.5 -> return logGamma1p(x - 0.5 - 0.5) logGamma1p(x - (n + 1)) + ln(prod)
x <= 8.0 -> {
val n = floor(x - 1.5).toInt()
var prod = 1.0
(1..n).forEach { i -> prod *= x - i }
return logGamma1p(x - (n + 1)) + ln(prod)
}
else -> {
val tmp = x + LANCZOS_G + .5
ret = (x + .5) * ln(tmp) - tmp + HALF_LOG_2_PI + ln(lanczos(x) / x)
}
} }
return ret else -> {
val tmp = x + LANCZOS_G + .5
(x + .5) * ln(tmp) - tmp + HALF_LOG_2_PI + ln(lanczos(x) / x)
}
} }
private fun regularizedGammaP( private fun regularizedGammaP(

View File

@ -20,24 +20,17 @@ internal object InternalUtils {
fun validateProbabilities(probabilities: DoubleArray?): Double { fun validateProbabilities(probabilities: DoubleArray?): Double {
require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." } require(!(probabilities == null || probabilities.isEmpty())) { "Probabilities must not be empty." }
var sumProb = 0.0
probabilities.forEach { prob -> val sumProb = probabilities.sumByDouble { prob ->
validateProbability(prob) require(!(prob < 0 || prob.isInfinite() || prob.isNaN())) { "Invalid probability: $prob" }
sumProb += prob prob
} }
require(!(sumProb.isInfinite() || sumProb <= 0)) { "Invalid sum of probabilities: $sumProb" } require(!(sumProb.isInfinite() || sumProb <= 0)) { "Invalid sum of probabilities: $sumProb" }
return sumProb return sumProb
} }
private fun validateProbability(probability: Double): Unit = class FactorialLog private constructor(numValues: Int, cache: DoubleArray?) {
require(!(probability < 0 || probability.isInfinite() || probability.isNaN())) { "Invalid probability: $probability" }
class FactorialLog private constructor(
numValues: Int,
cache: DoubleArray?
) {
private val logFactorials: DoubleArray = DoubleArray(numValues) private val logFactorials: DoubleArray = DoubleArray(numValues)
init { init {
@ -46,14 +39,15 @@ internal object InternalUtils {
if (cache != null && cache.size > BEGIN_LOG_FACTORIALS) { if (cache != null && cache.size > BEGIN_LOG_FACTORIALS) {
// Copy available values. // Copy available values.
endCopy = min(cache.size, numValues) endCopy = min(cache.size, numValues)
cache.copyInto( cache.copyInto(
logFactorials, logFactorials,
BEGIN_LOG_FACTORIALS, BEGIN_LOG_FACTORIALS,
BEGIN_LOG_FACTORIALS, endCopy BEGIN_LOG_FACTORIALS, endCopy
) )
} } else
// All values to be computed // All values to be computed
else endCopy = BEGIN_LOG_FACTORIALS endCopy = BEGIN_LOG_FACTORIALS
// Compute remaining values. // Compute remaining values.
(endCopy until numValues).forEach { i -> (endCopy until numValues).forEach { i ->
@ -65,9 +59,7 @@ internal object InternalUtils {
} }
fun value(n: Int): Double { fun value(n: Int): Double {
if (n < logFactorials.size) if (n < logFactorials.size) return logFactorials[n]
return logFactorials[n]
return if (n < FACTORIALS.size) ln(FACTORIALS[n].toDouble()) else InternalGamma.logGamma(n + 1.0) return if (n < FACTORIALS.size) ln(FACTORIALS[n].toDouble()) else InternalGamma.logGamma(n + 1.0)
} }

View File

@ -42,7 +42,7 @@ public open class AliasMethodDiscreteSampler private constructor(
protected val alias: IntArray protected val alias: IntArray
) : Sampler<Int> { ) : Sampler<Int> {
private class SmallTableAliasMethodDiscreteSampler internal constructor( private class SmallTableAliasMethodDiscreteSampler(
probability: LongArray, probability: LongArray,
alias: IntArray alias: IntArray
) : AliasMethodDiscreteSampler(probability, alias) { ) : AliasMethodDiscreteSampler(probability, alias) {
@ -71,7 +71,7 @@ public open class AliasMethodDiscreteSampler private constructor(
} }
} }
override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain { public override fun sample(generator: RandomGenerator): Chain<Int> = generator.chain {
// This implements the algorithm as per Vose (1991): // This implements the algorithm as per Vose (1991):
// v = uniform() in [0, 1) // v = uniform() in [0, 1)
// j = uniform(n) in [0, n) // j = uniform(n) in [0, n)

View File

@ -1,2 +0,0 @@
package kscience.kmath.prob.samplers

View File

@ -1,2 +0,0 @@
package kscience.kmath.prob.samplers

View File

@ -26,35 +26,24 @@ public class ZigguratNormalizedGaussianSampler private constructor() :
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) } public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain { sampleOne(this) }
public override fun toString(): String = "Ziggurat normalized Gaussian deviate" public override fun toString(): String = "Ziggurat normalized Gaussian deviate"
private fun fix( private fun fix(generator: RandomGenerator, hz: Long, iz: Int): Double {
generator: RandomGenerator, var x = hz * W[iz]
hz: Long,
iz: Int
): Double {
var x: Double
var y: Double
x = hz * W[iz]
return if (iz == 0) { return when {
// Base strip. iz == 0 -> {
// This branch is called about 5.7624515E-4 times per sample. var y: Double
do {
y = -ln(generator.nextDouble())
x = -ln(generator.nextDouble()) * ONE_OVER_R
} while (y + y < x * x)
val out = R + x do {
if (hz > 0) out else -out y = -ln(generator.nextDouble())
} else { x = -ln(generator.nextDouble()) * ONE_OVER_R
// Wedge of other strips. } while (y + y < x * x)
// This branch is called about 0.027323 times per sample.
// else val out = R + x
// Try again. if (hz > 0) out else -out
// This branch is called about 0.012362 times per sample. }
if (F[iz] + generator.nextDouble() * (F[iz - 1] - F[iz]) < gauss(
x F[iz] + generator.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x) -> x
) else -> sampleOne(generator)
) x else sampleOne(generator)
} }
} }

View File

@ -3,10 +3,14 @@ package kscience.kmath.prob
import org.apache.commons.rng.UniformRandomProvider import org.apache.commons.rng.UniformRandomProvider
import org.apache.commons.rng.simple.RandomSource import org.apache.commons.rng.simple.RandomSource
public class RandomSourceGenerator(public val source: RandomSource, seed: Long?) : RandomGenerator { /**
internal val random: UniformRandomProvider = seed?.let { * Implements [RandomGenerator] by delegating all operations to [RandomSource].
RandomSource.create(source, seed) *
} ?: RandomSource.create(source) * @property source the underlying [RandomSource] object.
*/
public class RandomSourceGenerator internal constructor(public val source: RandomSource, seed: Long?) : RandomGenerator {
internal val random: UniformRandomProvider = seed?.let { RandomSource.create(source, seed) }
?: RandomSource.create(source)
public override fun nextBoolean(): Boolean = random.nextBoolean() public override fun nextBoolean(): Boolean = random.nextBoolean()
public override fun nextDouble(): Double = random.nextDouble() public override fun nextDouble(): Double = random.nextDouble()
@ -23,19 +27,84 @@ public class RandomSourceGenerator(public val source: RandomSource, seed: Long?)
public override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong()) public override fun fork(): RandomGenerator = RandomSourceGenerator(source, nextLong())
} }
/**
* Implements [UniformRandomProvider] by delegating all operations to [RandomGenerator].
*
* @property generator the underlying [RandomGenerator] object.
*/
public inline class RandomGeneratorProvider(public val generator: RandomGenerator) : UniformRandomProvider { public inline class RandomGeneratorProvider(public val generator: RandomGenerator) : UniformRandomProvider {
/**
* Generates a [Boolean] value.
*
* @return the next random value.
*/
public override fun nextBoolean(): Boolean = generator.nextBoolean() public override fun nextBoolean(): Boolean = generator.nextBoolean()
/**
* Generates a [Float] value between 0 and 1.
*
* @return the next random value between 0 and 1.
*/
public override fun nextFloat(): Float = generator.nextDouble().toFloat() public override fun nextFloat(): Float = generator.nextDouble().toFloat()
/**
* Generates [Byte] values and places them into a user-supplied array.
*
* The number of random bytes produced is equal to the length of the the byte array.
*
* @param bytes byte array in which to put the random bytes.
*/
public override fun nextBytes(bytes: ByteArray): Unit = generator.fillBytes(bytes) public override fun nextBytes(bytes: ByteArray): Unit = generator.fillBytes(bytes)
/**
* Generates [Byte] values and places them into a user-supplied array.
*
* The array is filled with bytes extracted from random integers. This implies that the number of random bytes
* generated may be larger than the length of the byte array.
*
* @param bytes the array in which to put the generated bytes.
* @param start the index at which to start inserting the generated bytes.
* @param len the number of bytes to insert.
*/
public override fun nextBytes(bytes: ByteArray, start: Int, len: Int) { public override fun nextBytes(bytes: ByteArray, start: Int, len: Int) {
generator.fillBytes(bytes, start, start + len) generator.fillBytes(bytes, start, start + len)
} }
/**
* Generates an [Int] value.
*
* @return the next random value.
*/
public override fun nextInt(): Int = generator.nextInt() public override fun nextInt(): Int = generator.nextInt()
/**
* Generates an [Int] value between 0 (inclusive) and the specified value (exclusive).
*
* @param n the bound on the random number to be returned. Must be positive.
* @return a random integer between 0 (inclusive) and [n] (exclusive).
*/
public override fun nextInt(n: Int): Int = generator.nextInt(n) public override fun nextInt(n: Int): Int = generator.nextInt(n)
/**
* Generates a [Double] value between 0 and 1.
*
* @return the next random value between 0 and 1.
*/
public override fun nextDouble(): Double = generator.nextDouble() public override fun nextDouble(): Double = generator.nextDouble()
/**
* Generates a [Long] value.
*
* @return the next random value.
*/
public override fun nextLong(): Long = generator.nextLong() public override fun nextLong(): Long = generator.nextLong()
/**
* Generates a [Long] value between 0 (inclusive) and the specified value (exclusive).
*
* @param n Bound on the random number to be returned. Must be positive.
* @return a random long value between 0 (inclusive) and [n] (exclusive).
*/
public override fun nextLong(n: Long): Long = generator.nextLong(n) public override fun nextLong(n: Long): Long = generator.nextLong(n)
} }
@ -48,8 +117,14 @@ public fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if
else else
RandomGeneratorProvider(this) RandomGeneratorProvider(this)
/**
* Returns [RandomSourceGenerator] with given [RandomSource] and [seed].
*/
public fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator = public fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
RandomSourceGenerator(source, seed) RandomSourceGenerator(source, seed)
/**
* Returns [RandomSourceGenerator] with [RandomSource.MT] algorithm and given [seed].
*/
public fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator = public fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
fromSource(RandomSource.MT, seed) fromSource(RandomSource.MT, seed)

View File

@ -7,11 +7,8 @@ class SamplerTest {
@Test @Test
fun bufferSamplerTest() { fun bufferSamplerTest() {
val sampler: Sampler<Double> = val sampler = Sampler { it.chain { nextDouble() } }
BasicSampler { it.chain { nextDouble() } }
val data = sampler.sampleBuffer(RandomGenerator.default, 100) val data = sampler.sampleBuffer(RandomGenerator.default, 100)
runBlocking { runBlocking { println(data.next()) }
println(data.next())
}
} }
} }