Dev #280
@ -3,19 +3,20 @@ package kscience.kmath.commons.prob
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.apache.commons.rng.simple.RandomSource
|
||||
import kscience.kmath.prob.RandomChain
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.blocking
|
||||
import kscience.kmath.prob.fromSource
|
||||
import kscience.kmath.prob.samplers.GaussianSampler
|
||||
import org.apache.commons.rng.simple.RandomSource
|
||||
import java.time.Duration
|
||||
import java.time.Instant
|
||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
|
||||
import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
|
||||
import org.apache.commons.rng.sampling.distribution.GaussianSampler as CMGaussianSampler
|
||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as CMZigguratNormalizedGaussianSampler
|
||||
|
||||
private suspend fun runKMathChained(): Duration {
|
||||
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
||||
val normal = KMathZigguratNormalizedGaussianSampler.of()
|
||||
val chain = normal.sample(generator) as RandomChain<Double>
|
||||
val normal = GaussianSampler.of(7.0, 2.0)
|
||||
val chain = normal.sample(generator).blocking()
|
||||
val startTime = Instant.now()
|
||||
var sum = 0.0
|
||||
|
||||
@ -28,17 +29,23 @@ private suspend fun runKMathChained(): Duration {
|
||||
println("Chain sampler completed $counter elements in $duration: $meanValue")
|
||||
}
|
||||
}
|
||||
|
||||
return Duration.between(startTime, Instant.now())
|
||||
}
|
||||
|
||||
private fun runApacheDirect(): Duration {
|
||||
val rng = RandomSource.create(RandomSource.MT, 123L)
|
||||
val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
|
||||
|
||||
val sampler = CMGaussianSampler.of(
|
||||
CMZigguratNormalizedGaussianSampler.of(rng),
|
||||
7.0,
|
||||
2.0
|
||||
)
|
||||
|
||||
val startTime = Instant.now()
|
||||
|
||||
var sum = 0.0
|
||||
repeat(10000001) { counter ->
|
||||
|
||||
repeat(10000001) { counter ->
|
||||
sum += sampler.sample()
|
||||
|
||||
if (counter % 100000 == 0) {
|
||||
@ -47,17 +54,16 @@ private fun runApacheDirect(): Duration {
|
||||
println("Direct sampler completed $counter elements in $duration: $meanValue")
|
||||
}
|
||||
}
|
||||
|
||||
return Duration.between(startTime, Instant.now())
|
||||
}
|
||||
|
||||
/**
|
||||
* Comparing chain sampling performance with direct sampling performance
|
||||
*/
|
||||
fun main() {
|
||||
runBlocking(Dispatchers.Default) {
|
||||
fun main(): Unit = runBlocking(Dispatchers.Default) {
|
||||
val chainJob = async { runKMathChained() }
|
||||
val directJob = async { runApacheDirect() }
|
||||
println("KMath Chained: ${chainJob.await()}")
|
||||
println("Apache Direct: ${directJob.await()}")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user