Dev #280
@ -3,19 +3,20 @@ package kscience.kmath.commons.prob
|
|||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.async
|
import kotlinx.coroutines.async
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import org.apache.commons.rng.simple.RandomSource
|
|
||||||
import kscience.kmath.prob.RandomChain
|
|
||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
|
import kscience.kmath.prob.blocking
|
||||||
import kscience.kmath.prob.fromSource
|
import kscience.kmath.prob.fromSource
|
||||||
|
import kscience.kmath.prob.samplers.GaussianSampler
|
||||||
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
import java.time.Duration
|
import java.time.Duration
|
||||||
import java.time.Instant
|
import java.time.Instant
|
||||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
|
import org.apache.commons.rng.sampling.distribution.GaussianSampler as CMGaussianSampler
|
||||||
import kscience.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
|
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as CMZigguratNormalizedGaussianSampler
|
||||||
|
|
||||||
private suspend fun runKMathChained(): Duration {
|
private suspend fun runKMathChained(): Duration {
|
||||||
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
||||||
val normal = KMathZigguratNormalizedGaussianSampler.of()
|
val normal = GaussianSampler.of(7.0, 2.0)
|
||||||
val chain = normal.sample(generator) as RandomChain<Double>
|
val chain = normal.sample(generator).blocking()
|
||||||
val startTime = Instant.now()
|
val startTime = Instant.now()
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
|
|
||||||
@ -28,17 +29,23 @@ private suspend fun runKMathChained(): Duration {
|
|||||||
println("Chain sampler completed $counter elements in $duration: $meanValue")
|
println("Chain sampler completed $counter elements in $duration: $meanValue")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Duration.between(startTime, Instant.now())
|
return Duration.between(startTime, Instant.now())
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun runApacheDirect(): Duration {
|
private fun runApacheDirect(): Duration {
|
||||||
val rng = RandomSource.create(RandomSource.MT, 123L)
|
val rng = RandomSource.create(RandomSource.MT, 123L)
|
||||||
val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
|
|
||||||
|
val sampler = CMGaussianSampler.of(
|
||||||
|
CMZigguratNormalizedGaussianSampler.of(rng),
|
||||||
|
7.0,
|
||||||
|
2.0
|
||||||
|
)
|
||||||
|
|
||||||
val startTime = Instant.now()
|
val startTime = Instant.now()
|
||||||
|
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
repeat(10000001) { counter ->
|
|
||||||
|
|
||||||
|
repeat(10000001) { counter ->
|
||||||
sum += sampler.sample()
|
sum += sampler.sample()
|
||||||
|
|
||||||
if (counter % 100000 == 0) {
|
if (counter % 100000 == 0) {
|
||||||
@ -47,17 +54,16 @@ private fun runApacheDirect(): Duration {
|
|||||||
println("Direct sampler completed $counter elements in $duration: $meanValue")
|
println("Direct sampler completed $counter elements in $duration: $meanValue")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Duration.between(startTime, Instant.now())
|
return Duration.between(startTime, Instant.now())
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Comparing chain sampling performance with direct sampling performance
|
* Comparing chain sampling performance with direct sampling performance
|
||||||
*/
|
*/
|
||||||
fun main() {
|
fun main(): Unit = runBlocking(Dispatchers.Default) {
|
||||||
runBlocking(Dispatchers.Default) {
|
|
||||||
val chainJob = async { runKMathChained() }
|
val chainJob = async { runKMathChained() }
|
||||||
val directJob = async { runApacheDirect() }
|
val directJob = async { runApacheDirect() }
|
||||||
println("KMath Chained: ${chainJob.await()}")
|
println("KMath Chained: ${chainJob.await()}")
|
||||||
println("Apache Direct: ${directJob.await()}")
|
println("Apache Direct: ${directJob.await()}")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user