Implement Commons RNG-like samplers in kmath-prob module for Multiplatform #164
@ -3,25 +3,24 @@ package scientifik.kmath.commons.prob
|
|||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.async
|
import kotlinx.coroutines.async
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
|
|
||||||
import org.apache.commons.rng.simple.RandomSource
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
import scientifik.kmath.chains.BlockingRealChain
|
import scientifik.kmath.prob.RandomChain
|
||||||
import scientifik.kmath.prob.*
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
import scientifik.kmath.prob.fromSource
|
||||||
import java.time.Duration
|
import java.time.Duration
|
||||||
import java.time.Instant
|
import java.time.Instant
|
||||||
|
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
|
||||||
|
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
|
||||||
|
|
||||||
|
private suspend fun runKMathChained(): Duration {
|
||||||
private suspend fun runChain(): Duration {
|
|
||||||
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
|
||||||
|
val normal = KMathZigguratNormalizedGaussianSampler.of()
|
||||||
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat)
|
val chain = normal.sample(generator) as RandomChain<Double>
|
||||||
val chain = normal.sample(generator) as BlockingRealChain
|
|
||||||
|
|
||||||
val startTime = Instant.now()
|
val startTime = Instant.now()
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
repeat(10000001) { counter ->
|
|
||||||
|
|
||||||
sum += chain.nextDouble()
|
repeat(10000001) { counter ->
|
||||||
|
sum += chain.next()
|
||||||
|
|
||||||
if (counter % 100000 == 0) {
|
if (counter % 100000 == 0) {
|
||||||
val duration = Duration.between(startTime, Instant.now())
|
val duration = Duration.between(startTime, Instant.now())
|
||||||
@ -32,9 +31,9 @@ private suspend fun runChain(): Duration {
|
|||||||
return Duration.between(startTime, Instant.now())
|
return Duration.between(startTime, Instant.now())
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun runDirect(): Duration {
|
private fun runApacheDirect(): Duration {
|
||||||
val provider = RandomSource.create(RandomSource.MT, 123L)
|
val rng = RandomSource.create(RandomSource.MT, 123L)
|
||||||
val sampler = ZigguratNormalizedGaussianSampler(provider)
|
val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
|
||||||
val startTime = Instant.now()
|
val startTime = Instant.now()
|
||||||
|
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
@ -56,16 +55,9 @@ private fun runDirect(): Duration {
|
|||||||
*/
|
*/
|
||||||
fun main() {
|
fun main() {
|
||||||
runBlocking(Dispatchers.Default) {
|
runBlocking(Dispatchers.Default) {
|
||||||
val chainJob = async {
|
val chainJob = async { runKMathChained() }
|
||||||
runChain()
|
val directJob = async { runApacheDirect() }
|
||||||
}
|
println("KMath Chained: ${chainJob.await()}")
|
||||||
|
println("Apache Direct: ${directJob.await()}")
|
||||||
val directJob = async {
|
|
||||||
runDirect()
|
|
||||||
}
|
|
||||||
|
|
||||||
println("Chain: ${chainJob.await()}")
|
|
||||||
println("Direct: ${directJob.await()}")
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
|
||||||
|
@ -3,9 +3,8 @@ package scientifik.kmath.commons.prob
|
|||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.collectWithState
|
import scientifik.kmath.chains.collectWithState
|
||||||
import scientifik.kmath.prob.Distribution
|
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
import scientifik.kmath.prob.normal
|
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler
|
||||||
|
|
||||||
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||||
|
|
||||||
@ -18,7 +17,7 @@ fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState()
|
|||||||
|
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val normal = Distribution.normal()
|
val normal = ZigguratNormalizedGaussianSampler.of()
|
||||||
val chain = normal.sample(RandomGenerator.default).mean()
|
val chain = normal.sample(RandomGenerator.default).mean()
|
||||||
|
|
||||||
runBlocking {
|
runBlocking {
|
||||||
|
@ -11,4 +11,4 @@ class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspen
|
|||||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
|
||||||
|
@ -44,4 +44,4 @@ inline class DefaultGenerator(private val random: Random = Random) : RandomGener
|
|||||||
|
|
||||||
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
|
||||||
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user