Implement Commons RNG-like samplers in kmath-prob module for Multiplatform #164

Merged
CommanderTvis merged 44 commits from feature/mp-samplers into dev 2021-03-31 09:25:44 +03:00
4 changed files with 22 additions and 31 deletions
Showing only changes of commit 822f960e9c - Show all commits

View File

@ -3,25 +3,24 @@ package scientifik.kmath.commons.prob
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource
import scientifik.kmath.chains.BlockingRealChain
import scientifik.kmath.prob.*
import scientifik.kmath.prob.RandomChain
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.fromSource
import java.time.Duration
import java.time.Instant
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
private suspend fun runChain(): Duration {
private suspend fun runKMathChained(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat)
val chain = normal.sample(generator) as BlockingRealChain
val normal = KMathZigguratNormalizedGaussianSampler.of()
val chain = normal.sample(generator) as RandomChain<Double>
val startTime = Instant.now()
var sum = 0.0
repeat(10000001) { counter ->
sum += chain.nextDouble()
repeat(10000001) { counter ->
sum += chain.next()
if (counter % 100000 == 0) {
val duration = Duration.between(startTime, Instant.now())
@ -32,9 +31,9 @@ private suspend fun runChain(): Duration {
return Duration.between(startTime, Instant.now())
}
private fun runDirect(): Duration {
val provider = RandomSource.create(RandomSource.MT, 123L)
val sampler = ZigguratNormalizedGaussianSampler(provider)
private fun runApacheDirect(): Duration {
val rng = RandomSource.create(RandomSource.MT, 123L)
val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
val startTime = Instant.now()
var sum = 0.0
@ -56,16 +55,9 @@ private fun runDirect(): Duration {
*/
fun main() {
runBlocking(Dispatchers.Default) {
val chainJob = async {
runChain()
}
val directJob = async {
runDirect()
}
println("Chain: ${chainJob.await()}")
println("Direct: ${directJob.await()}")
val chainJob = async { runKMathChained() }
val directJob = async { runApacheDirect() }
println("KMath Chained: ${chainJob.await()}")
println("Apache Direct: ${directJob.await()}")
}
}
}

View File

@ -3,9 +3,8 @@ package scientifik.kmath.commons.prob
import kotlinx.coroutines.runBlocking
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.collectWithState
import scientifik.kmath.prob.Distribution
import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.normal
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
@ -18,7 +17,7 @@ fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState()
fun main() {
val normal = Distribution.normal()
val normal = ZigguratNormalizedGaussianSampler.of()
val chain = normal.sample(RandomGenerator.default).mean()
runBlocking {

View File

@ -11,4 +11,4 @@ class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspen
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
}
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)

View File

@ -44,4 +44,4 @@ inline class DefaultGenerator(private val random: Random = Random) : RandomGener
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
}
}