diff --git a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt new file mode 100644 index 000000000..b060cddb6 --- /dev/null +++ b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionBenchmark.kt @@ -0,0 +1,71 @@ +package scientifik.kmath.commons.prob + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.runBlocking +import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler +import org.apache.commons.rng.simple.RandomSource +import scientifik.kmath.chains.BlockingRealChain +import scientifik.kmath.prob.* +import java.time.Duration +import java.time.Instant + + +private suspend fun runChain(): Duration { + val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) + + val normal = Distribution.normal(NormalSamplerMethod.Ziggurat) + val chain = normal.sample(generator) as BlockingRealChain + + val startTime = Instant.now() + var sum = 0.0 + repeat(10000001) { counter -> + + sum += chain.nextDouble() + + if (counter % 100000 == 0) { + val duration = Duration.between(startTime, Instant.now()) + val meanValue = sum / counter + println("Chain sampler completed $counter elements in $duration: $meanValue") + } + } + return Duration.between(startTime, Instant.now()) +} + +private fun runDirect(): Duration { + val provider = RandomSource.create(RandomSource.MT, 123L) + val sampler = ZigguratNormalizedGaussianSampler(provider) + val startTime = Instant.now() + + var sum = 0.0 + repeat(10000001) { counter -> + + sum += sampler.sample() + + if (counter % 100000 == 0) { + val duration = Duration.between(startTime, Instant.now()) + val meanValue = sum / counter + println("Direct sampler completed $counter elements in $duration: $meanValue") + } + } + return Duration.between(startTime, Instant.now()) +} + +/** + * Comparing chain sampling performance with direct sampling performance + */ +fun main() { + runBlocking(Dispatchers.Default) { + val chainJob = async { + runChain() + } + + val directJob = async { + runDirect() + } + + println("Chain: ${chainJob.await()}") + println("Direct: ${directJob.await()}") + } + +} \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt new file mode 100644 index 000000000..6ec84d5c7 --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingIntChain.kt @@ -0,0 +1,12 @@ +package scientifik.kmath.chains + +/** + * Performance optimized chain for integer values + */ +abstract class BlockingIntChain : Chain { + abstract fun nextInt(): Int + + override suspend fun next(): Int = nextInt() + + fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() } +} \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt new file mode 100644 index 000000000..6b69d2734 --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/BlockingRealChain.kt @@ -0,0 +1,12 @@ +package scientifik.kmath.chains + +/** + * Performance optimized chain for real values + */ +abstract class BlockingRealChain : Chain { + abstract fun nextDouble(): Double + + override suspend fun next(): Double = nextDouble() + + fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() } +} \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 85db55480..5635499e5 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -48,7 +48,6 @@ interface Chain: Flow { } companion object - } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt index cf4d4cc17..bef21a680 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/streaming/BufferFlow.kt @@ -2,9 +2,11 @@ package scientifik.kmath.streaming import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.* +import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.DoubleBuffer +import scientifik.kmath.structures.asBuffer /** * Create a [Flow] from buffer @@ -45,20 +47,28 @@ fun Flow.chunked(bufferSize: Int, bufferFactory: BufferFactory): Flow< */ fun Flow.chunked(bufferSize: Int): Flow = flow { require(bufferSize > 0) { "Resulting chunk size must be more than zero" } - val array = DoubleArray(bufferSize) - var counter = 0 - this@chunked.collect { element -> - array[counter] = element - counter++ - if (counter == bufferSize) { - val buffer = DoubleBuffer(array) - emit(buffer) - counter = 0 + if (this@chunked is BlockingRealChain) { + //performance optimization for blocking primitive chain + while (true) { + emit(nextBlock(bufferSize).asBuffer()) + } + } else { + val array = DoubleArray(bufferSize) + var counter = 0 + + this@chunked.collect { element -> + array[counter] = element + counter++ + if (counter == bufferSize) { + val buffer = DoubleBuffer(array) + emit(buffer) + counter = 0 + } + } + if (counter > 0) { + emit(DoubleBuffer(counter) { array[it] }) } - } - if (counter > 0) { - emit(DoubleBuffer(counter) { array[it] }) } } diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt index 2a97fa4be..f5a73a08b 100644 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/RandomSourceGenerator.kt @@ -5,7 +5,7 @@ import org.apache.commons.rng.simple.RandomSource class RandomSourceGenerator(val source: RandomSource, seed: Long?) : RandomGenerator { internal val random: UniformRandomProvider = seed?.let { - RandomSource.create(source, seed, null) + RandomSource.create(source, seed) } ?: RandomSource.create(source) override fun nextBoolean(): Boolean = random.nextBoolean() @@ -59,3 +59,9 @@ fun RandomGenerator.asUniformRandomProvider(): UniformRandomProvider = if (this } else { RandomGeneratorProvider(this) } + +fun RandomGenerator.Companion.fromSource(source: RandomSource, seed: Long? = null): RandomSourceGenerator = + RandomSourceGenerator(source, seed) + +fun RandomGenerator.Companion.mersenneTwister(seed: Long? = null): RandomSourceGenerator = + fromSource(RandomSource.MT, seed) diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt index 59588d0ef..412454994 100644 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/distributions.kt @@ -2,6 +2,8 @@ package scientifik.kmath.prob import org.apache.commons.rng.UniformRandomProvider import org.apache.commons.rng.sampling.distribution.* +import scientifik.kmath.chains.BlockingIntChain +import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.chains.Chain import java.util.* import kotlin.math.PI @@ -11,32 +13,32 @@ import kotlin.math.sqrt abstract class ContinuousSamplerDistribution : Distribution { - private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain { - private val sampler = buildSampler(generator) + private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() { + private val sampler = buildCMSampler(generator) - override suspend fun next(): Double = sampler.sample() + override fun nextDouble(): Double = sampler.sample() override fun fork(): Chain = ContinuousSamplerChain(generator.fork()) } - protected abstract fun buildSampler(generator: RandomGenerator): ContinuousSampler + protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler - override fun sample(generator: RandomGenerator): Chain = ContinuousSamplerChain(generator) + override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator) } abstract class DiscreteSamplerDistribution : Distribution { - private inner class ContinuousSamplerChain(val generator: RandomGenerator) : Chain { + private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() { private val sampler = buildSampler(generator) - override suspend fun next(): Int = sampler.sample() + override fun nextInt(): Int = sampler.sample() override fun fork(): Chain = ContinuousSamplerChain(generator.fork()) } protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler - override fun sample(generator: RandomGenerator): Chain = ContinuousSamplerChain(generator) + override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator) } enum class NormalSamplerMethod { @@ -55,7 +57,7 @@ private fun normalSampler(method: NormalSamplerMethod, provider: UniformRandomPr fun Distribution.Companion.normal( method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat ): Distribution = object : ContinuousSamplerDistribution() { - override fun buildSampler(generator: RandomGenerator): ContinuousSampler { + override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { val provider: UniformRandomProvider = generator.asUniformRandomProvider() return normalSampler(method, provider) } @@ -69,11 +71,11 @@ fun Distribution.Companion.normal( mean: Double, sigma: Double, method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat -): Distribution = object : ContinuousSamplerDistribution() { +): ContinuousSamplerDistribution = object : ContinuousSamplerDistribution() { private val sigma2 = sigma.pow(2) private val norm = sigma * sqrt(PI * 2) - override fun buildSampler(generator: RandomGenerator): ContinuousSampler { + override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { val provider: UniformRandomProvider = generator.asUniformRandomProvider() val normalizedSampler = normalSampler(method, provider) return GaussianSampler(normalizedSampler, mean, sigma) @@ -86,7 +88,7 @@ fun Distribution.Companion.normal( fun Distribution.Companion.poisson( lambda: Double -): Distribution = object : DiscreteSamplerDistribution() { +): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() { override fun buildSampler(generator: RandomGenerator): DiscreteSampler { return PoissonSampler.of(generator.asUniformRandomProvider(), lambda) diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt index d6fd629b7..7638c695e 100644 --- a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt +++ b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/CommonsDistributionsTest.kt @@ -8,12 +8,21 @@ import org.junit.jupiter.api.Test class CommonsDistributionsTest { @Test - fun testNormalDistribution(){ - val distribution = Distribution.normal(7.0,2.0) + fun testNormalDistributionSuspend() { + val distribution = Distribution.normal(7.0, 2.0) val generator = RandomGenerator.default(1) val sample = runBlocking { distribution.sample(generator).take(1000).toList() } Assertions.assertEquals(7.0, sample.average(), 0.1) } + + @Test + fun testNormalDistributionBlocking() { + val distribution = Distribution.normal(7.0, 2.0) + val generator = RandomGenerator.default(1) + val sample = distribution.sample(generator).nextBlock(1000) + Assertions.assertEquals(7.0, sample.average(), 0.1) + } + } \ No newline at end of file