forked from kscience/kmath
Proof of concept for coroutine-safe random number generators
This commit is contained in:
parent
651f8323f3
commit
c267372c70
@ -8,7 +8,7 @@ fun main(args: Array<String>) {
|
|||||||
val dim = 1000
|
val dim = 1000
|
||||||
val n = 1000
|
val n = 1000
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build coroutineContext most suited for given type.
|
||||||
val autoField = NDField.auto(RealField, dim, dim)
|
val autoField = NDField.auto(RealField, dim, dim)
|
||||||
// specialized nd-field for Double. It works as generic Double field as well
|
// specialized nd-field for Double. It works as generic Double field as well
|
||||||
val specializedField = NDField.real(dim, dim)
|
val specializedField = NDField.real(dim, dim)
|
||||||
|
@ -111,7 +111,7 @@ fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(map
|
|||||||
fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
* A coroutineContext for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||||
*/
|
*/
|
||||||
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
||||||
override fun variable(name: String, default: Double?) =
|
override fun variable(name: String, default: Double?) =
|
||||||
|
@ -0,0 +1,58 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import kotlin.coroutines.CoroutineContext
|
||||||
|
import kotlin.coroutines.EmptyCoroutineContext
|
||||||
|
import kotlin.coroutines.coroutineContext
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A scope for a monte-carlo simulation or multi-coroutine random number generation
|
||||||
|
*/
|
||||||
|
class MCScope private constructor(val coroutineContext: CoroutineContext) {
|
||||||
|
fun mcScope(context: CoroutineContext): MCScope = MCScope(context + coroutineContext[Random]!!.split())
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
suspend fun init(generator: RandomGenerator): MCScope =
|
||||||
|
MCScope(coroutineContext + Random(generator))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline class Random(val generator: RandomGenerator) : CoroutineContext.Element {
|
||||||
|
override val key: CoroutineContext.Key<*> get() = Companion
|
||||||
|
|
||||||
|
fun split(): Random = Random(generator.fork())
|
||||||
|
|
||||||
|
companion object : CoroutineContext.Key<Random>
|
||||||
|
}
|
||||||
|
|
||||||
|
val CoroutineContext.random get() = this[Random]?.generator
|
||||||
|
|
||||||
|
val MCScope.random get() = coroutineContext.random!!
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Launches a supervised Monte-Carlo scope
|
||||||
|
*/
|
||||||
|
suspend fun <T> mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T =
|
||||||
|
supervisorScope {
|
||||||
|
MCScope.init(generator).block()
|
||||||
|
}
|
||||||
|
|
||||||
|
suspend fun <T> mc(seed: Long = -1, block: suspend MCScope.() -> T): T =
|
||||||
|
mc(SplitRandomWrapper(seed), block)
|
||||||
|
|
||||||
|
inline fun MCScope.launch(
|
||||||
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
|
crossinline block: suspend MCScope.() -> Unit
|
||||||
|
): Job = CoroutineScope(coroutineContext).launch(context + coroutineContext[Random]!!.split(), start) {
|
||||||
|
mcScope(coroutineContext).block()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <T> MCScope.async(
|
||||||
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
|
crossinline block: suspend MCScope.() -> T
|
||||||
|
): Deferred<T> = CoroutineScope(coroutineContext).async(context, start) {
|
||||||
|
mcScope(coroutineContext).block()
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,18 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import java.util.*
|
||||||
|
|
||||||
|
class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
|
||||||
|
|
||||||
|
constructor(seed: Long) : this(SplittableRandom(seed))
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
|
override fun nextInt(): Int = random.nextInt()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
|
||||||
|
override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) }
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator = SplitRandomWrapper(random.split())
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class MCScopeTest {
|
||||||
|
@Test
|
||||||
|
fun testGenerator(){
|
||||||
|
runBlocking(Dispatchers.Default) {
|
||||||
|
mc(1122){
|
||||||
|
launch {
|
||||||
|
repeat(9){
|
||||||
|
println("first:${random.nextInt()}")
|
||||||
|
}
|
||||||
|
assertEquals(690819834,random.nextInt())
|
||||||
|
}
|
||||||
|
launch {
|
||||||
|
repeat(9){
|
||||||
|
println("second:${random.nextInt()}")
|
||||||
|
}
|
||||||
|
assertEquals(691997530,random.nextInt())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user