Update example

This commit is contained in:
Iaroslav Postovalov 2020-09-27 15:52:18 +07:00
parent 53ebec2e01
commit b83293a057
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7

View File

@ -3,19 +3,20 @@ package kscience.kmath.commons.prob
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.apache.commons.rng.simple.RandomSource
import kscience.kmath.prob.RandomChain
import kscience.kmath.prob.RandomGenerator import kscience.kmath.prob.RandomGenerator
import kscience.kmath.prob.blocking
import kscience.kmath.prob.fromSource import kscience.kmath.prob.fromSource
import kscience.kmath.prob.samplers.GaussianSampler
import org.apache.commons.rng.simple.RandomSource
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler import org.apache.commons.rng.sampling.distribution.GaussianSampler as CMGaussianSampler
import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as CMZigguratNormalizedGaussianSampler
private suspend fun runKMathChained(): Duration { private suspend fun runKMathChained(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = KMathZigguratNormalizedGaussianSampler.of() val normal = GaussianSampler.of(7.0, 2.0)
val chain = normal.sample(generator) as RandomChain<Double> val chain = normal.sample(generator).blocking()
val startTime = Instant.now() val startTime = Instant.now()
var sum = 0.0 var sum = 0.0
@ -28,17 +29,23 @@ private suspend fun runKMathChained(): Duration {
println("Chain sampler completed $counter elements in $duration: $meanValue") println("Chain sampler completed $counter elements in $duration: $meanValue")
} }
} }
return Duration.between(startTime, Instant.now()) return Duration.between(startTime, Instant.now())
} }
private fun runApacheDirect(): Duration { private fun runApacheDirect(): Duration {
val rng = RandomSource.create(RandomSource.MT, 123L) val rng = RandomSource.create(RandomSource.MT, 123L)
val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
val sampler = CMGaussianSampler.of(
CMZigguratNormalizedGaussianSampler.of(rng),
7.0,
2.0
)
val startTime = Instant.now() val startTime = Instant.now()
var sum = 0.0 var sum = 0.0
repeat(10000001) { counter ->
repeat(10000001) { counter ->
sum += sampler.sample() sum += sampler.sample()
if (counter % 100000 == 0) { if (counter % 100000 == 0) {
@ -47,17 +54,16 @@ private fun runApacheDirect(): Duration {
println("Direct sampler completed $counter elements in $duration: $meanValue") println("Direct sampler completed $counter elements in $duration: $meanValue")
} }
} }
return Duration.between(startTime, Instant.now()) return Duration.between(startTime, Instant.now())
} }
/** /**
* Comparing chain sampling performance with direct sampling performance * Comparing chain sampling performance with direct sampling performance
*/ */
fun main() { fun main(): Unit = runBlocking(Dispatchers.Default) {
runBlocking(Dispatchers.Default) {
val chainJob = async { runKMathChained() } val chainJob = async { runKMathChained() }
val directJob = async { runApacheDirect() } val directJob = async { runApacheDirect() }
println("KMath Chained: ${chainJob.await()}") println("KMath Chained: ${chainJob.await()}")
println("Apache Direct: ${directJob.await()}") println("Apache Direct: ${directJob.await()}")
}
} }