Optimized blocking chains for primitive numbers generation.

This commit is contained in:
Alexander Nozik 2020-05-23 10:45:02 +03:00
parent 3f68c0c34e
commit 8533a31677
8 changed files with 149 additions and 28 deletions

View File

@ -0,0 +1,71 @@
package scientifik.kmath.commons.prob
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.prob.*
import java.time.Duration
import java.time.Instant
private suspend fun runChain(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat)
val chain = normal.sample(generator) as BlockingRealChain
val startTime = Instant.now()
var sum = 0.0
repeat(10000001) { counter ->
sum += chain.nextDouble()
if (counter % 100000 == 0) {
val duration = Duration.between(startTime, Instant.now())
val meanValue = sum / counter
println("Chain sampler completed $counter elements in $duration: $meanValue")
}
}
return Duration.between(startTime, Instant.now())
}
private fun runDirect(): Duration {
val provider = RandomSource.create(RandomSource.MT, 123L)
val sampler = ZigguratNormalizedGaussianSampler(provider)
val startTime = Instant.now()
var sum = 0.0
repeat(10000001) { counter ->
sum += sampler.sample()
if (counter % 100000 == 0) {
val duration = Duration.between(startTime, Instant.now())
val meanValue = sum / counter
println("Direct sampler completed $counter elements in $duration: $meanValue")
}
}
return Duration.between(startTime, Instant.now())
}
/**
* Comparing chain sampling performance with direct sampling performance
*/
fun main() {
runBlocking(Dispatchers.Default) {
val chainJob = async {
runChain()
}
val directJob = async {
runDirect()
}
println("Chain: ${chainJob.await()}")
println("Direct: ${directJob.await()}")
}
}

View File

@ -0,0 +1,12 @@
package scientifik.kmath.chains
/**
* Performance optimized chain for integer values
*/
abstract class BlockingIntChain : Chain<Int> {
abstract fun nextInt(): Int
override suspend fun next(): Int = nextInt()
fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
}

View File

@ -0,0 +1,12 @@
package scientifik.kmath.chains
/**
* Performance optimized chain for real values
*/
abstract class BlockingRealChain : Chain<Double> {
abstract fun nextDouble(): Double
override suspend fun next(): Double = nextDouble()
fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
}

View File

@ -48,7 +48,6 @@ interface Chain<out R>: Flow<R> {
}
companion object
}

View File

@ -2,9 +2,11 @@ package scientifik.kmath.streaming
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.*
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory
import scientifik.kmath.structures.DoubleBuffer
import scientifik.kmath.structures.asBuffer
/**
* Create a [Flow] from buffer
@ -45,20 +47,28 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
*/
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
val array = DoubleArray(bufferSize)
var counter = 0
this@chunked.collect { element ->
array[counter] = element
counter++
if (counter == bufferSize) {
val buffer = DoubleBuffer(array)
emit(buffer)
counter = 0
if (this@chunked is BlockingRealChain) {
//performance optimization for blocking primitive chain
while (true) {
emit(nextBlock(bufferSize).asBuffer())
}
} else {
val array = DoubleArray(bufferSize)
var counter = 0
this@chunked.collect { element ->
array[counter] = element
counter++
if (counter == bufferSize) {
val buffer = DoubleBuffer(array)
emit(buffer)
counter = 0
}
}
if (counter > 0) {
emit(DoubleBuffer(counter) { array[it] })
}
}
if (counter > 0) {
emit(DoubleBuffer(counter) { array[it] })
}
}

View File

@ -5,7 +5,7 @@ import org.apache.commons.rng.simple.RandomSource
class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator {
internal val random: UniformRandomProvider = seed?.let {
RandomSource.create(source, seed, null)
RandomSource.create(source, seed)
} ?: RandomSource.create(source)
override fun nextBoolean(): Boolean = random.nextBoolean()
@ -59,3 +59,9 @@ fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this
} else {
RandomGeneratorProvider(this)
}
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
RandomSourceGenerator(source, seed)
fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
fromSource(RandomSource.MT, seed)

View File

@ -2,6 +2,8 @@ package scientifik.kmath.prob
import org.apache.commons.rng.UniformRandomProvider
import org.apache.commons.rng.sampling.distribution.*
import scientifik.kmath.chains.BlockingIntChain
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.chains.Chain
import java.util.*
import kotlin.math.PI
@ -11,32 +13,32 @@ import kotlin.math.sqrt
abstract class ContinuousSamplerDistribution : Distribution<Double> {
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Double> {
private val sampler = buildSampler(generator)
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
private val sampler = buildCMSampler(generator)
override suspend fun next(): Double = sampler.sample()
override fun nextDouble(): Double = sampler.sample()
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
}
protected abstract fun buildSampler(generator: RandomGenerator): ContinuousSampler
protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler
override fun sample(generator: RandomGenerator): Chain<Double> = ContinuousSamplerChain(generator)
override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
}
abstract class DiscreteSamplerDistribution : Distribution<Int> {
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Int> {
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
private val sampler = buildSampler(generator)
override suspend fun next(): Int = sampler.sample()
override fun nextInt(): Int = sampler.sample()
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
}
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
override fun sample(generator: RandomGenerator): Chain<Int> = ContinuousSamplerChain(generator)
override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
}
enum class NormalSamplerMethod {
@ -55,7 +57,7 @@ private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomPr
fun Distribution.Companion.normal(
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
): Distribution<Double> = object : ContinuousSamplerDistribution() {
override fun buildSampler(generator: RandomGenerator): ContinuousSampler {
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
return normalSampler(method, provider)
}
@ -69,11 +71,11 @@ fun Distribution.Companion.normal(
mean: Double,
sigma: Double,
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
): Distribution<Double> = object : ContinuousSamplerDistribution() {
): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() {
private val sigma2 = sigma.pow(2)
private val norm = sigma * sqrt(PI * 2)
override fun buildSampler(generator: RandomGenerator): ContinuousSampler {
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
val normalizedSampler = normalSampler(method, provider)
return GaussianSampler(normalizedSampler, mean, sigma)
@ -86,7 +88,7 @@ fun Distribution.Companion.normal(
fun Distribution.Companion.poisson(
lambda: Double
): Distribution<Int> = object : DiscreteSamplerDistribution() {
): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() {
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)

View File

@ -8,12 +8,21 @@ import org.junit.jupiter.api.Test
class CommonsDistributionsTest {
@Test
fun testNormalDistribution(){
val distribution = Distribution.normal(7.0,2.0)
fun testNormalDistributionSuspend() {
val distribution = Distribution.normal(7.0, 2.0)
val generator = RandomGenerator.default(1)
val sample = runBlocking {
distribution.sample(generator).take(1000).toList()
}
Assertions.assertEquals(7.0, sample.average(), 0.1)
}
@Test
fun testNormalDistributionBlocking() {
val distribution = Distribution.normal(7.0, 2.0)
val generator = RandomGenerator.default(1)
val sample = distribution.sample(generator).nextBlock(1000)
Assertions.assertEquals(7.0, sample.average(), 0.1)
}
}