Optimized blocking chains for primitive numbers generation.

This commit is contained in:
Alexander Nozik 2020-05-23 10:45:02 +03:00
parent 3f68c0c34e
commit 8533a31677
8 changed files with 149 additions and 28 deletions

View File

@ -0,0 +1,71 @@
package scientifik.kmath.commons.prob
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.prob.*
import java.time.Duration
import java.time.Instant
private suspend fun runChain(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat)
val chain = normal.sample(generator) as BlockingRealChain
val startTime = Instant.now()
var sum = 0.0
repeat(10000001) { counter ->
sum += chain.nextDouble()
if (counter % 100000 == 0) {
val duration = Duration.between(startTime, Instant.now())
val meanValue = sum / counter
println("Chain sampler completed $counter elements in $duration: $meanValue")
}
}
return Duration.between(startTime, Instant.now())
}
private fun runDirect(): Duration {
val provider = RandomSource.create(RandomSource.MT, 123L)
val sampler = ZigguratNormalizedGaussianSampler(provider)
val startTime = Instant.now()
var sum = 0.0
repeat(10000001) { counter ->
sum += sampler.sample()
if (counter % 100000 == 0) {
val duration = Duration.between(startTime, Instant.now())
val meanValue = sum / counter
println("Direct sampler completed $counter elements in $duration: $meanValue")
}
}
return Duration.between(startTime, Instant.now())
}
/**
* Comparing chain sampling performance with direct sampling performance
*/
fun main() {
runBlocking(Dispatchers.Default) {
val chainJob = async {
runChain()
}
val directJob = async {
runDirect()
}
println("Chain: ${chainJob.await()}")
println("Direct: ${directJob.await()}")
}
}

View File

@ -0,0 +1,12 @@
package scientifik.kmath.chains
/**
* Performance optimized chain for integer values
*/
abstract class BlockingIntChain : Chain<Int> {
abstract fun nextInt(): Int
override suspend fun next(): Int = nextInt()
fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
}

View File

@ -0,0 +1,12 @@
package scientifik.kmath.chains
/**
* Performance optimized chain for real values
*/
abstract class BlockingRealChain : Chain<Double> {
abstract fun nextDouble(): Double
override suspend fun next(): Double = nextDouble()
fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
}

View File

@ -48,7 +48,6 @@ interface Chain<out R>: Flow<R> {
} }
companion object companion object
} }

View File

@ -2,9 +2,11 @@ package scientifik.kmath.streaming
import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.* import kotlinx.coroutines.flow.*
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.DoubleBuffer
import scientifik.kmath.structures.asBuffer
/** /**
* Create a [Flow] from buffer * Create a [Flow] from buffer
@ -45,6 +47,13 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
*/ */
fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow { fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
require(bufferSize > 0) { "Resulting chunk size must be more than zero" } require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
if (this@chunked is BlockingRealChain) {
//performance optimization for blocking primitive chain
while (true) {
emit(nextBlock(bufferSize).asBuffer())
}
} else {
val array = DoubleArray(bufferSize) val array = DoubleArray(bufferSize)
var counter = 0 var counter = 0
@ -60,6 +69,7 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
if (counter > 0) { if (counter > 0) {
emit(DoubleBuffer(counter) { array[it] }) emit(DoubleBuffer(counter) { array[it] })
} }
}
} }
/** /**

View File

@ -5,7 +5,7 @@ import org.apache.commons.rng.simple.RandomSource
class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator { class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator {
internal val random: UniformRandomProvider = seed?.let { internal val random: UniformRandomProvider = seed?.let {
RandomSource.create(source, seed, null) RandomSource.create(source, seed)
} ?: RandomSource.create(source) } ?: RandomSource.create(source)
override fun nextBoolean(): Boolean = random.nextBoolean() override fun nextBoolean(): Boolean = random.nextBoolean()
@ -59,3 +59,9 @@ fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this
} else { } else {
RandomGeneratorProvider(this) RandomGeneratorProvider(this)
} }
fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator =
RandomSourceGenerator(source, seed)
fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator =
fromSource(RandomSource.MT, seed)

View File

@ -2,6 +2,8 @@ package scientifik.kmath.prob
import org.apache.commons.rng.UniformRandomProvider import org.apache.commons.rng.UniformRandomProvider
import org.apache.commons.rng.sampling.distribution.* import org.apache.commons.rng.sampling.distribution.*
import scientifik.kmath.chains.BlockingIntChain
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.chains.Chain import scientifik.kmath.chains.Chain
import java.util.* import java.util.*
import kotlin.math.PI import kotlin.math.PI
@ -11,32 +13,32 @@ import kotlin.math.sqrt
abstract class ContinuousSamplerDistribution : Distribution<Double> { abstract class ContinuousSamplerDistribution : Distribution<Double> {
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Double> { private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
private val sampler = buildSampler(generator) private val sampler = buildCMSampler(generator)
override suspend fun next(): Double = sampler.sample() override fun nextDouble(): Double = sampler.sample()
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork()) override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
} }
protected abstract fun buildSampler(generator: RandomGenerator): ContinuousSampler protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler
override fun sample(generator: RandomGenerator): Chain<Double> = ContinuousSamplerChain(generator) override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
} }
abstract class DiscreteSamplerDistribution : Distribution<Int> { abstract class DiscreteSamplerDistribution : Distribution<Int> {
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain<Int> { private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
private val sampler = buildSampler(generator) private val sampler = buildSampler(generator)
override suspend fun next(): Int = sampler.sample() override fun nextInt(): Int = sampler.sample()
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork()) override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
} }
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
override fun sample(generator: RandomGenerator): Chain<Int> = ContinuousSamplerChain(generator) override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
} }
enum class NormalSamplerMethod { enum class NormalSamplerMethod {
@ -55,7 +57,7 @@ private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomPr
fun Distribution.Companion.normal( fun Distribution.Companion.normal(
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
): Distribution<Double> = object : ContinuousSamplerDistribution() { ): Distribution<Double> = object : ContinuousSamplerDistribution() {
override fun buildSampler(generator: RandomGenerator): ContinuousSampler { override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
val provider: UniformRandomProvider = generator.asUniformRandomProvider() val provider: UniformRandomProvider = generator.asUniformRandomProvider()
return normalSampler(method, provider) return normalSampler(method, provider)
} }
@ -69,11 +71,11 @@ fun Distribution.Companion.normal(
mean: Double, mean: Double,
sigma: Double, sigma: Double,
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
): Distribution<Double> = object : ContinuousSamplerDistribution() { ): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() {
private val sigma2 = sigma.pow(2) private val sigma2 = sigma.pow(2)
private val norm = sigma * sqrt(PI * 2) private val norm = sigma * sqrt(PI * 2)
override fun buildSampler(generator: RandomGenerator): ContinuousSampler { override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
val provider: UniformRandomProvider = generator.asUniformRandomProvider() val provider: UniformRandomProvider = generator.asUniformRandomProvider()
val normalizedSampler = normalSampler(method, provider) val normalizedSampler = normalSampler(method, provider)
return GaussianSampler(normalizedSampler, mean, sigma) return GaussianSampler(normalizedSampler, mean, sigma)
@ -86,7 +88,7 @@ fun Distribution.Companion.normal(
fun Distribution.Companion.poisson( fun Distribution.Companion.poisson(
lambda: Double lambda: Double
): Distribution<Int> = object : DiscreteSamplerDistribution() { ): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() {
override fun buildSampler(generator: RandomGenerator): DiscreteSampler { override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda) return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)

View File

@ -8,12 +8,21 @@ import org.junit.jupiter.api.Test
class CommonsDistributionsTest { class CommonsDistributionsTest {
@Test @Test
fun testNormalDistribution(){ fun testNormalDistributionSuspend() {
val distribution = Distribution.normal(7.0,2.0) val distribution = Distribution.normal(7.0, 2.0)
val generator = RandomGenerator.default(1) val generator = RandomGenerator.default(1)
val sample = runBlocking { val sample = runBlocking {
distribution.sample(generator).take(1000).toList() distribution.sample(generator).take(1000).toList()
} }
Assertions.assertEquals(7.0, sample.average(), 0.1) Assertions.assertEquals(7.0, sample.average(), 0.1)
} }
@Test
fun testNormalDistributionBlocking() {
val distribution = Distribution.normal(7.0, 2.0)
val generator = RandomGenerator.default(1)
val sample = distribution.sample(generator).nextBlock(1000)
Assertions.assertEquals(7.0, sample.average(), 0.1)
}
} }