diff --git a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt index 8f68a9d3f..e4a5bc534 100644 --- a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt @@ -3,19 +3,20 @@ package kscience.kmath.commons.prob import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.runBlocking -import org.apache.commons.rng.simple.RandomSource -import kscience.kmath.prob.RandomChain import kscience.kmath.prob.RandomGenerator +import kscience.kmath.prob.blocking import kscience.kmath.prob.fromSource +import kscience.kmath.prob.samplers.GaussianSampler +import org.apache.commons.rng.simple.RandomSource import java.time.Duration import java.time.Instant -import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler -import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler +import org.apache.commons.rng.sampling.distribution.GaussianSampler as CMGaussianSampler +import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as CMZigguratNormalizedGaussianSampler private suspend fun runKMathChained(): Duration { val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) - val normal = KMathZigguratNormalizedGaussianSampler.of() - val chain = normal.sample(generator) as RandomChain + val normal = GaussianSampler.of(7.0, 2.0) + val chain = normal.sample(generator).blocking() val startTime = Instant.now() var sum = 0.0 @@ -28,17 +29,23 @@ private suspend fun runKMathChained(): Duration { println("Chain sampler completed $counter elements in $duration: $meanValue") } } + return Duration.between(startTime, Instant.now()) } private fun runApacheDirect(): Duration { val rng = RandomSource.create(RandomSource.MT, 123L) - val sampler = ApacheZigguratNormalizedGaussianSampler.of(rng) + + val sampler = CMGaussianSampler.of( + CMZigguratNormalizedGaussianSampler.of(rng), + 7.0, + 2.0 + ) + val startTime = Instant.now() - var sum = 0.0 - repeat(10000001) { counter -> + repeat(10000001) { counter -> sum += sampler.sample() if (counter % 100000 == 0) { @@ -47,17 +54,16 @@ private fun runApacheDirect(): Duration { println("Direct sampler completed $counter elements in $duration: $meanValue") } } + return Duration.between(startTime, Instant.now()) } /** * Comparing chain sampling performance with direct sampling performance */ -fun main() { - runBlocking(Dispatchers.Default) { - val chainJob = async { runKMathChained() } - val directJob = async { runApacheDirect() } - println("KMath Chained: ${chainJob.await()}") - println("Apache Direct: ${directJob.await()}") - } +fun main(): Unit = runBlocking(Dispatchers.Default) { + val chainJob = async { runKMathChained() } + val directJob = async { runApacheDirect() } + println("KMath Chained: ${chainJob.await()}") + println("Apache Direct: ${directJob.await()}") }