Fix broken demos, add newlines at the end of files

This commit is contained in:
Iaroslav 2020-06-08 18:19:18 +07:00
parent 246feacd72
commit 822f960e9c
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
4 changed files with 22 additions and 31 deletions

View File

@ -3,25 +3,24 @@ package scientifik.kmath.commons.prob
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource import org.apache.commons.rng.simple.RandomSource
import scientifik.kmath.chains.BlockingRealChain import scientifik.kmath.prob.RandomChain
import scientifik.kmath.prob.* import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.fromSource
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler as ApacheZigguratNormalizedGaussianSampler
import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler as KMathZigguratNormalizedGaussianSampler
private suspend fun runKMathChained(): Duration {
private suspend fun runChain(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = KMathZigguratNormalizedGaussianSampler.of()
val normal = Distribution.normal(NormalSamplerMethod.Ziggurat) val chain = normal.sample(generator) as RandomChain<Double>
val chain = normal.sample(generator) as BlockingRealChain
val startTime = Instant.now() val startTime = Instant.now()
var sum = 0.0 var sum = 0.0
repeat(10000001) { counter ->
sum += chain.nextDouble() repeat(10000001) { counter ->
sum += chain.next()
if (counter % 100000 == 0) { if (counter % 100000 == 0) {
val duration = Duration.between(startTime, Instant.now()) val duration = Duration.between(startTime, Instant.now())
@ -32,9 +31,9 @@ private suspend fun runChain(): Duration {
return Duration.between(startTime, Instant.now()) return Duration.between(startTime, Instant.now())
} }
private fun runDirect(): Duration { private fun runApacheDirect(): Duration {
val provider = RandomSource.create(RandomSource.MT, 123L) val rng = RandomSource.create(RandomSource.MT, 123L)
val sampler = ZigguratNormalizedGaussianSampler(provider) val sampler = ApacheZigguratNormalizedGaussianSampler.of<ApacheZigguratNormalizedGaussianSampler>(rng)
val startTime = Instant.now() val startTime = Instant.now()
var sum = 0.0 var sum = 0.0
@ -56,16 +55,9 @@ private fun runDirect(): Duration {
*/ */
fun main() { fun main() {
runBlocking(Dispatchers.Default) { runBlocking(Dispatchers.Default) {
val chainJob = async { val chainJob = async { runKMathChained() }
runChain() val directJob = async { runApacheDirect() }
} println("KMath Chained: ${chainJob.await()}")
println("Apache Direct: ${directJob.await()}")
val directJob = async {
runDirect()
}
println("Chain: ${chainJob.await()}")
println("Direct: ${directJob.await()}")
} }
}
}

View File

@ -3,9 +3,8 @@ package scientifik.kmath.commons.prob
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import scientifik.kmath.chains.Chain import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.collectWithState import scientifik.kmath.chains.collectWithState
import scientifik.kmath.prob.Distribution
import scientifik.kmath.prob.RandomGenerator import scientifik.kmath.prob.RandomGenerator
import scientifik.kmath.prob.normal import scientifik.kmath.prob.samplers.ZigguratNormalizedGaussianSampler
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
@ -18,7 +17,7 @@ fun Chain<Double>.mean(): Chain<Double> = collectWithState(AveragingChainState()
fun main() { fun main() {
val normal = Distribution.normal() val normal = ZigguratNormalizedGaussianSampler.of()
val chain = normal.sample(RandomGenerator.default).mean() val chain = normal.sample(RandomGenerator.default).mean()
runBlocking { runBlocking {

View File

@ -11,4 +11,4 @@ class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspen
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen) override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
} }
fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen) fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)

View File

@ -44,4 +44,4 @@ inline class DefaultGenerator(private val random: Random = Random) : RandomGener
override fun nextBytes(size: Int): ByteArray = random.nextBytes(size) override fun nextBytes(size: Int): ByteArray = random.nextBytes(size)
override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong()) override fun fork(): RandomGenerator = RandomGenerator.default(random.nextLong())
} }