Proof of concept for coroutine-safe random number generators

This commit is contained in:
Alexander Nozik 2019-06-15 15:17:21 +03:00
parent 651f8323f3
commit c267372c70
5 changed files with 106 additions and 2 deletions

View File

@ -8,7 +8,7 @@ fun main(args: Array<String>) {
val dim = 1000
val n = 1000
// automatically build context most suited for given type.
// automatically build coroutineContext most suited for given type.
val autoField = NDField.auto(RealField, dim, dim)
// specialized nd-field for Double. It works as generic Double field as well
val specializedField = NDField.real(dim, dim)

View File

@ -111,7 +111,7 @@ fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(map
fun DiffExpression.derivative(name: String) = derivative(name to 1)
/**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
* A coroutineContext for [DiffExpression] (not to be confused with [DerivativeStructure])
*/
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
override fun variable(name: String, default: Double?) =

View File

@ -0,0 +1,58 @@
package scientifik.kmath.prob
import kotlinx.coroutines.*
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.coroutineContext
/**
* A scope for a monte-carlo simulation or multi-coroutine random number generation
*/
class MCScope private constructor(val coroutineContext: CoroutineContext) {
fun mcScope(context: CoroutineContext): MCScope = MCScope(context + coroutineContext[Random]!!.split())
companion object {
suspend fun init(generator: RandomGenerator): MCScope =
MCScope(coroutineContext + Random(generator))
}
}
inline class Random(val generator: RandomGenerator) : CoroutineContext.Element {
override val key: CoroutineContext.Key<*> get() = Companion
fun split(): Random = Random(generator.fork())
companion object : CoroutineContext.Key<Random>
}
val CoroutineContext.random get() = this[Random]?.generator
val MCScope.random get() = coroutineContext.random!!
/**
* Launches a supervised Monte-Carlo scope
*/
suspend fun <T> mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T =
supervisorScope {
MCScope.init(generator).block()
}
suspend fun <T> mc(seed: Long = -1, block: suspend MCScope.() -> T): T =
mc(SplitRandomWrapper(seed), block)
inline fun MCScope.launch(
context: CoroutineContext = EmptyCoroutineContext,
start: CoroutineStart = CoroutineStart.DEFAULT,
crossinline block: suspend MCScope.() -> Unit
): Job = CoroutineScope(coroutineContext).launch(context + coroutineContext[Random]!!.split(), start) {
mcScope(coroutineContext).block()
}
inline fun <T> MCScope.async(
context: CoroutineContext = EmptyCoroutineContext,
start: CoroutineStart = CoroutineStart.DEFAULT,
crossinline block: suspend MCScope.() -> T
): Deferred<T> = CoroutineScope(coroutineContext).async(context, start) {
mcScope(coroutineContext).block()
}

View File

@ -0,0 +1,18 @@
package scientifik.kmath.prob
import java.util.*
class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
constructor(seed: Long) : this(SplittableRandom(seed))
override fun nextDouble(): Double = random.nextDouble()
override fun nextInt(): Int = random.nextInt()
override fun nextLong(): Long = random.nextLong()
override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) }
override fun fork(): RandomGenerator = SplitRandomWrapper(random.split())
}

View File

@ -0,0 +1,28 @@
package scientifik.kmath.prob
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import kotlin.test.Test
import kotlin.test.assertEquals
class MCScopeTest {
@Test
fun testGenerator(){
runBlocking(Dispatchers.Default) {
mc(1122){
launch {
repeat(9){
println("first:${random.nextInt()}")
}
assertEquals(690819834,random.nextInt())
}
launch {
repeat(9){
println("second:${random.nextInt()}")
}
assertEquals(691997530,random.nextInt())
}
}
}
}
}