Implement Commons RNG-like samplers in kmath-prob module for Multiplatform #164

Merged
CommanderTvis merged 44 commits from feature/mp-samplers into dev 2021-03-31 09:25:44 +03:00
4 changed files with 22 additions and 31 deletions
Showing only changes of commit 822f960e9c - Show all commits

View File

@ -3,25 +3,24 @@ package scientifik.kmath.commons.prob
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource import org.apache.commons.rng.simple.RandomSource
import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.prob.RandomChain
import scientifik.kmath.prob.* import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.fromSource
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
private suspend fun runKMathChained(): Duration {
private suspend fun runChain(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = KMathZigguratNormalizedGaussianSampler.of()
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat) val chain = normal.sample(generator) as RandomChain<Double>
val chain = normal.sample(generator) as BlockingRealChain
val startTime = Instant.now() val startTime = Instant.now()
var sum = 0.0 var sum = 0.0
repeat(10000001) { counter ->
sum += chain.nextDouble() repeat(10000001) { counter ->
sum += chain.next()
if (counter % 100000 == 0) { if (counter % 100000 == 0) {
val duration = Duration.between(startTime, Instant.now()) val duration = Duration.between(startTime, Instant.now())
@ -32,9 +31,9 @@ private suspend fun runChain(): Duration {
return Duration.between(startTime, Instant.now()) return Duration.between(startTime, Instant.now())
} }
private fun runDirect(): Duration { private fun runApacheDirect(): Duration {
val provider = RandomSource.create(RandomSource.MT, 123L) val rng = RandomSource.create(RandomSource.MT, 123L)
val sampler = ZigguratNormalizedGaussianSampler(provider) val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
val startTime = Instant.now() val startTime = Instant.now()
var sum = 0.0 var sum = 0.0
@ -56,16 +55,9 @@ private fun runDirect(): Duration {
*/ */
fun main() { fun main() {
runBlocking(Dispatchers.Default) { runBlocking(Dispatchers.Default) {
val chainJob = async { val chainJob = async { runKMathChained() }
runChain() val directJob = async { runApacheDirect() }
println("KMath Chained: ${chainJob.await()}")
println("Apache Direct: ${directJob.await()}")
} }
val directJob = async {
runDirect()
}
println("Chain: ${chainJob.await()}")
println("Direct: ${directJob.await()}")
}
} }

View File

@ -3,9 +3,8 @@ package scientifik.kmath.commons.prob
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import scientifik.kmath.chains.Chain import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.collectWithState import scientifik.kmath.chains.collectWithState
import scientifik.kmath.prob.Distribution
import scientifik.kmath.prob.RandomGenerator import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.normal import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
@ -18,7 +17,7 @@ fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState()
fun main() { fun main() {
val normal = Distribution.normal() val normal = ZigguratNormalizedGaussianSampler.of()
val chain = normal.sample(RandomGenerator.default).mean() val chain = normal.sample(RandomGenerator.default).mean()
runBlocking { runBlocking {