Update example

This commit is contained in:
Iaroslav Postovalov 2020-09-27 15:52:18 +07:00
parent 53ebec2e01
commit b83293a057
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7

View File

@ -3,19 +3,20 @@ package kscience.kmath.commons.prob
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.apache.commons.rng.simple.RandomSource
import kscience.kmath.prob.RandomChain
import kscience.kmath.prob.RandomGenerator
import kscience.kmath.prob.blocking
import kscience.kmath.prob.fromSource
import kscience.kmath.prob.samplers.GaussianSampler
import org.apache.commons.rng.simple.RandomSource
import java.time.Duration
import java.time.Instant
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
import org.apache.commons.rng.sampling.distribution.GaussianSampler as CMGaussianSampler
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as CMZigguratNormalizedGaussianSampler
private suspend fun runKMathChained(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = KMathZigguratNormalizedGaussianSampler.of()
val chain = normal.sample(generator) as RandomChain<Double>
val normal = GaussianSampler.of(7.0, 2.0)
val chain = normal.sample(generator).blocking()
val startTime = Instant.now()
var sum = 0.0
@ -28,17 +29,23 @@ private suspend fun runKMathChained(): Duration {
println("Chain sampler completed $counter elements in $duration: $meanValue")
}
}
return Duration.between(startTime, Instant.now())
}
private fun runApacheDirect(): Duration {
val rng = RandomSource.create(RandomSource.MT, 123L)
val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
val sampler = CMGaussianSampler.of(
CMZigguratNormalizedGaussianSampler.of(rng),
7.0,
2.0
)
val startTime = Instant.now()
var sum = 0.0
repeat(10000001) { counter ->
repeat(10000001) { counter ->
sum += sampler.sample()
if (counter % 100000 == 0) {
@ -47,17 +54,16 @@ private fun runApacheDirect(): Duration {
println("Direct sampler completed $counter elements in $duration: $meanValue")
}
}
return Duration.between(startTime, Instant.now())
}
/**
* Comparing chain sampling performance with direct sampling performance
*/
fun main() {
runBlocking(Dispatchers.Default) {
val chainJob = async { runKMathChained() }
val directJob = async { runApacheDirect() }
println("KMath Chained: ${chainJob.await()}")
println("Apache Direct: ${directJob.await()}")
}
fun main(): Unit = runBlocking(Dispatchers.Default) {
val chainJob = async { runKMathChained() }
val directJob = async { runApacheDirect() }
println("KMath Chained: ${chainJob.await()}")
println("Apache Direct: ${directJob.await()}")
}