forked from kscience/kmath
Optimized blocking chains for primitive numbers generation.
This commit is contained in:
parent
3f68c0c34e
commit
8533a31677
@ -0,0 +1,71 @@
|
|||||||
|
package scientifik.kmath.commons.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.async
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
|
||||||
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
|
import scientifik.kmath.chains.BlockingRealChain
|
||||||
|
import scientifik.kmath.prob.*
|
||||||
|
import java.time.Duration
|
||||||
|
import java.time.Instant
|
||||||
|
|
||||||
|
|
||||||
|
private suspend fun runChain(): Duration {
|
||||||
|
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
||||||
|
|
||||||
|
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat)
|
||||||
|
val chain = normal.sample(generator) as BlockingRealChain
|
||||||
|
|
||||||
|
val startTime = Instant.now()
|
||||||
|
var sum = 0.0
|
||||||
|
repeat(10000001) { counter ->
|
||||||
|
|
||||||
|
sum += chain.nextDouble()
|
||||||
|
|
||||||
|
if (counter % 100000 == 0) {
|
||||||
|
val duration = Duration.between(startTime, Instant.now())
|
||||||
|
val meanValue = sum / counter
|
||||||
|
println("Chain sampler completed $counter elements in $duration: $meanValue")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Duration.between(startTime, Instant.now())
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun runDirect(): Duration {
|
||||||
|
val provider = RandomSource.create(RandomSource.MT, 123L)
|
||||||
|
val sampler = ZigguratNormalizedGaussianSampler(provider)
|
||||||
|
val startTime = Instant.now()
|
||||||
|
|
||||||
|
var sum = 0.0
|
||||||
|
repeat(10000001) { counter ->
|
||||||
|
|
||||||
|
sum += sampler.sample()
|
||||||
|
|
||||||
|
if (counter % 100000 == 0) {
|
||||||
|
val duration = Duration.between(startTime, Instant.now())
|
||||||
|
val meanValue = sum / counter
|
||||||
|
println("Direct sampler completed $counter elements in $duration: $meanValue")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Duration.between(startTime, Instant.now())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Comparing chain sampling performance with direct sampling performance
|
||||||
|
*/
|
||||||
|
fun main() {
|
||||||
|
runBlocking(Dispatchers.Default) {
|
||||||
|
val chainJob = async {
|
||||||
|
runChain()
|
||||||
|
}
|
||||||
|
|
||||||
|
val directJob = async {
|
||||||
|
runDirect()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("Chain: ${chainJob.await()}")
|
||||||
|
println("Direct: ${directJob.await()}")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,12 @@
|
|||||||
|
package scientifik.kmath.chains
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performance optimized chain for integer values
|
||||||
|
*/
|
||||||
|
abstract class BlockingIntChain : Chain<Int> {
|
||||||
|
abstract fun nextInt(): Int
|
||||||
|
|
||||||
|
override suspend fun next(): Int = nextInt()
|
||||||
|
|
||||||
|
fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
|
||||||
|
}
|
@ -0,0 +1,12 @@
|
|||||||
|
package scientifik.kmath.chains
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performance optimized chain for real values
|
||||||
|
*/
|
||||||
|
abstract class BlockingRealChain : Chain<Double> {
|
||||||
|
abstract fun nextDouble(): Double
|
||||||
|
|
||||||
|
override suspend fun next(): Double = nextDouble()
|
||||||
|
|
||||||
|
fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
|
||||||
|
}
|
@ -48,7 +48,6 @@ interface Chain<out R>: Flow<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,9 +2,11 @@ package scientifik.kmath.streaming
|
|||||||
|
|
||||||
import kotlinx.coroutines.FlowPreview
|
import kotlinx.coroutines.FlowPreview
|
||||||
import kotlinx.coroutines.flow.*
|
import kotlinx.coroutines.flow.*
|
||||||
|
import scientifik.kmath.chains.BlockingRealChain
|
||||||
import scientifik.kmath.structures.Buffer
|
import scientifik.kmath.structures.Buffer
|
||||||
import scientifik.kmath.structures.BufferFactory
|
import scientifik.kmath.structures.BufferFactory
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
import scientifik.kmath.structures.DoubleBuffer
|
||||||
|
import scientifik.kmath.structures.asBuffer
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a [Flow] from buffer
|
* Create a [Flow] from buffer
|
||||||
@ -45,20 +47,28 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
|
|||||||
*/
|
*/
|
||||||
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
|
||||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||||
val array = DoubleArray(bufferSize)
|
|
||||||
var counter = 0
|
|
||||||
|
|
||||||
this@chunked.collect { element ->
|
if (this@chunked is BlockingRealChain) {
|
||||||
array[counter] = element
|
//performance optimization for blocking primitive chain
|
||||||
counter++
|
while (true) {
|
||||||
if (counter == bufferSize) {
|
emit(nextBlock(bufferSize).asBuffer())
|
||||||
val buffer = DoubleBuffer(array)
|
}
|
||||||
emit(buffer)
|
} else {
|
||||||
counter = 0
|
val array = DoubleArray(bufferSize)
|
||||||
|
var counter = 0
|
||||||
|
|
||||||
|
this@chunked.collect { element ->
|
||||||
|
array[counter] = element
|
||||||
|
counter++
|
||||||
|
if (counter == bufferSize) {
|
||||||
|
val buffer = DoubleBuffer(array)
|
||||||
|
emit(buffer)
|
||||||
|
counter = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (counter > 0) {
|
||||||
|
emit(DoubleBuffer(counter) { array[it] })
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if (counter > 0) {
|
|
||||||
emit(DoubleBuffer(counter) { array[it] })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import org.apache.commons.rng.simple.RandomSource
|
|||||||
|
|
||||||
class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator {
|
class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator {
|
||||||
internal val random: UniformRandomProvider = seed?.let {
|
internal val random: UniformRandomProvider = seed?.let {
|
||||||
RandomSource.create(source, seed, null)
|
RandomSource.create(source, seed)
|
||||||
} ?: RandomSource.create(source)
|
} ?: RandomSource.create(source)
|
||||||
|
|
||||||
override fun nextBoolean(): Boolean = random.nextBoolean()
|
override fun nextBoolean(): Boolean = random.nextBoolean()
|
||||||
@ -59,3 +59,9 @@ fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this
|
|||||||
} else {
|
} else {
|
||||||
RandomGeneratorProvider(this)
|
RandomGeneratorProvider(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
|
||||||
|
RandomSourceGenerator(source, seed)
|
||||||
|
|
||||||
|
fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
|
||||||
|
fromSource(RandomSource.MT, seed)
|
||||||
|
@ -2,6 +2,8 @@ package scientifik.kmath.prob
|
|||||||
|
|
||||||
import org.apache.commons.rng.UniformRandomProvider
|
import org.apache.commons.rng.UniformRandomProvider
|
||||||
import org.apache.commons.rng.sampling.distribution.*
|
import org.apache.commons.rng.sampling.distribution.*
|
||||||
|
import scientifik.kmath.chains.BlockingIntChain
|
||||||
|
import scientifik.kmath.chains.BlockingRealChain
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
@ -11,32 +13,32 @@ import kotlin.math.sqrt
|
|||||||
|
|
||||||
abstract class ContinuousSamplerDistribution : Distribution<Double> {
|
abstract class ContinuousSamplerDistribution : Distribution<Double> {
|
||||||
|
|
||||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Double> {
|
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
|
||||||
private val sampler = buildSampler(generator)
|
private val sampler = buildCMSampler(generator)
|
||||||
|
|
||||||
override suspend fun next(): Double = sampler.sample()
|
override fun nextDouble(): Double = sampler.sample()
|
||||||
|
|
||||||
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
|
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract fun buildSampler(generator: RandomGenerator): ContinuousSampler
|
protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Double> = ContinuousSamplerChain(generator)
|
override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract class DiscreteSamplerDistribution : Distribution<Int> {
|
abstract class DiscreteSamplerDistribution : Distribution<Int> {
|
||||||
|
|
||||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Int> {
|
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
|
||||||
private val sampler = buildSampler(generator)
|
private val sampler = buildSampler(generator)
|
||||||
|
|
||||||
override suspend fun next(): Int = sampler.sample()
|
override fun nextInt(): Int = sampler.sample()
|
||||||
|
|
||||||
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
|
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
|
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<Int> = ContinuousSamplerChain(generator)
|
override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class NormalSamplerMethod {
|
enum class NormalSamplerMethod {
|
||||||
@ -55,7 +57,7 @@ private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomPr
|
|||||||
fun Distribution.Companion.normal(
|
fun Distribution.Companion.normal(
|
||||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||||
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
||||||
override fun buildSampler(generator: RandomGenerator): ContinuousSampler {
|
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||||
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||||
return normalSampler(method, provider)
|
return normalSampler(method, provider)
|
||||||
}
|
}
|
||||||
@ -69,11 +71,11 @@ fun Distribution.Companion.normal(
|
|||||||
mean: Double,
|
mean: Double,
|
||||||
sigma: Double,
|
sigma: Double,
|
||||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||||
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() {
|
||||||
private val sigma2 = sigma.pow(2)
|
private val sigma2 = sigma.pow(2)
|
||||||
private val norm = sigma * sqrt(PI * 2)
|
private val norm = sigma * sqrt(PI * 2)
|
||||||
|
|
||||||
override fun buildSampler(generator: RandomGenerator): ContinuousSampler {
|
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||||
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||||
val normalizedSampler = normalSampler(method, provider)
|
val normalizedSampler = normalSampler(method, provider)
|
||||||
return GaussianSampler(normalizedSampler, mean, sigma)
|
return GaussianSampler(normalizedSampler, mean, sigma)
|
||||||
@ -86,7 +88,7 @@ fun Distribution.Companion.normal(
|
|||||||
|
|
||||||
fun Distribution.Companion.poisson(
|
fun Distribution.Companion.poisson(
|
||||||
lambda: Double
|
lambda: Double
|
||||||
): Distribution<Int> = object : DiscreteSamplerDistribution() {
|
): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() {
|
||||||
|
|
||||||
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
|
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
|
||||||
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
||||||
|
@ -8,12 +8,21 @@ import org.junit.jupiter.api.Test
|
|||||||
|
|
||||||
class CommonsDistributionsTest {
|
class CommonsDistributionsTest {
|
||||||
@Test
|
@Test
|
||||||
fun testNormalDistribution(){
|
fun testNormalDistributionSuspend() {
|
||||||
val distribution = Distribution.normal(7.0,2.0)
|
val distribution = Distribution.normal(7.0, 2.0)
|
||||||
val generator = RandomGenerator.default(1)
|
val generator = RandomGenerator.default(1)
|
||||||
val sample = runBlocking {
|
val sample = runBlocking {
|
||||||
distribution.sample(generator).take(1000).toList()
|
distribution.sample(generator).take(1000).toList()
|
||||||
}
|
}
|
||||||
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testNormalDistributionBlocking() {
|
||||||
|
val distribution = Distribution.normal(7.0, 2.0)
|
||||||
|
val generator = RandomGenerator.default(1)
|
||||||
|
val sample = distribution.sample(generator).nextBlock(1000)
|
||||||
|
Assertions.assertEquals(7.0, sample.average(), 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user