Dev #280
@ -3,25 +3,24 @@ package scientifik.kmath.commons.prob
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
|
||||
import org.apache.commons.rng.simple.RandomSource
|
||||
import scientifik.kmath.chains.BlockingRealChain
|
||||
import scientifik.kmath.prob.*
|
||||
import scientifik.kmath.prob.RandomChain
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.fromSource
|
||||
import java.time.Duration
|
||||
import java.time.Instant
|
||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
|
||||
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
|
||||
|
||||
|
||||
private suspend fun runChain(): Duration {
|
||||
private suspend fun runKMathChained(): Duration {
|
||||
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
||||
|
||||
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat)
|
||||
val chain = normal.sample(generator) as BlockingRealChain
|
||||
|
||||
val normal = KMathZigguratNormalizedGaussianSampler.of()
|
||||
val chain = normal.sample(generator) as RandomChain<Double>
|
||||
val startTime = Instant.now()
|
||||
var sum = 0.0
|
||||
repeat(10000001) { counter ->
|
||||
|
||||
sum += chain.nextDouble()
|
||||
repeat(10000001) { counter ->
|
||||
sum += chain.next()
|
||||
|
||||
if (counter % 100000 == 0) {
|
||||
val duration = Duration.between(startTime, Instant.now())
|
||||
@ -32,9 +31,9 @@ private suspend fun runChain(): Duration {
|
||||
return Duration.between(startTime, Instant.now())
|
||||
}
|
||||
|
||||
private fun runDirect(): Duration {
|
||||
val provider = RandomSource.create(RandomSource.MT, 123L)
|
||||
val sampler = ZigguratNormalizedGaussianSampler(provider)
|
||||
private fun runApacheDirect(): Duration {
|
||||
val rng = RandomSource.create(RandomSource.MT, 123L)
|
||||
val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
|
||||
val startTime = Instant.now()
|
||||
|
||||
var sum = 0.0
|
||||
@ -56,16 +55,9 @@ private fun runDirect(): Duration {
|
||||
*/
|
||||
fun main() {
|
||||
runBlocking(Dispatchers.Default) {
|
||||
val chainJob = async {
|
||||
runChain()
|
||||
}
|
||||
|
||||
val directJob = async {
|
||||
runDirect()
|
||||
}
|
||||
|
||||
println("Chain: ${chainJob.await()}")
|
||||
println("Direct: ${directJob.await()}")
|
||||
val chainJob = async { runKMathChained() }
|
||||
val directJob = async { runApacheDirect() }
|
||||
println("KMath Chained: ${chainJob.await()}")
|
||||
println("Apache Direct: ${directJob.await()}")
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -3,9 +3,8 @@ package scientifik.kmath.commons.prob
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.collectWithState
|
||||
import scientifik.kmath.prob.Distribution
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import scientifik.kmath.prob.normal
|
||||
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler
|
||||
|
||||
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||
|
||||
@ -18,7 +17,7 @@ fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState()
|
||||
|
||||
|
||||
fun main() {
|
||||
val normal = Distribution.normal()
|
||||
val normal = ZigguratNormalizedGaussianSampler.of()
|
||||
val chain = normal.sample(RandomGenerator.default).mean()
|
||||
|
||||
runBlocking {
|
||||
|
@ -11,4 +11,4 @@ class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspen
|
||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||
}
|
||||
|
||||
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
||||
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
||||
|
@ -44,4 +44,4 @@ inline class DefaultGenerator(private val random: Random = Random) : RandomGener
|
||||
|
||||
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
||||
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user