Fixed race on generator fork
This commit is contained in:
parent
2892ea11e2
commit
9c239ebfbd
@ -8,19 +8,19 @@ kotlin.sourceSets {
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||
api("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
jvmMain {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||
api("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
jsMain {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||
api("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,19 +7,19 @@ kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-coroutines"))
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||
//compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
jvmMain {
|
||||
dependencies {
|
||||
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
|
||||
//api("org.apache.commons:commons-rng-sampling:1.2")
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||
// compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
jsMain {
|
||||
dependencies {
|
||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,20 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlin.random.Random
|
||||
|
||||
class CountingRandomWrapper(val seed: Long) : RandomGenerator {
|
||||
private val counter = atomic(0)
|
||||
private val random = Random(seed)
|
||||
|
||||
override fun nextDouble(): Double = random.nextDouble()
|
||||
|
||||
override fun nextInt(): Int = random.nextInt()
|
||||
|
||||
override fun nextLong(): Long = random.nextLong()
|
||||
|
||||
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
|
||||
|
||||
override fun fork(): RandomGenerator = CountingRandomWrapper(seed + counter.addAndGet(10))
|
||||
|
||||
}
|
@ -8,32 +8,13 @@ import kotlin.coroutines.coroutineContext
|
||||
/**
|
||||
* A scope for a monte-carlo simulation or multi-coroutine random number generation
|
||||
*/
|
||||
class MCScope private constructor(val coroutineContext: CoroutineContext) {
|
||||
fun mcScope(context: CoroutineContext): MCScope = MCScope(context + coroutineContext[Random]!!.split())
|
||||
|
||||
companion object {
|
||||
suspend fun init(generator: RandomGenerator): MCScope =
|
||||
MCScope(coroutineContext + Random(generator))
|
||||
}
|
||||
}
|
||||
|
||||
inline class Random(val generator: RandomGenerator) : CoroutineContext.Element {
|
||||
override val key: CoroutineContext.Key<*> get() = Companion
|
||||
|
||||
fun split(): Random = Random(generator.fork())
|
||||
|
||||
companion object : CoroutineContext.Key<Random>
|
||||
}
|
||||
|
||||
val CoroutineContext.random get() = this[Random]?.generator
|
||||
|
||||
val MCScope.random get() = coroutineContext.random!!
|
||||
class MCScope(val coroutineContext: CoroutineContext, val random: RandomGenerator)
|
||||
|
||||
/**
|
||||
* Launches a supervised Monte-Carlo scope
|
||||
*/
|
||||
suspend fun <T> mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T =
|
||||
MCScope.init(generator).block()
|
||||
MCScope(coroutineContext, generator).block()
|
||||
|
||||
suspend fun <T> mc(seed: Long = -1, block: suspend MCScope.() -> T): T =
|
||||
mc(SplitRandomWrapper(seed), block)
|
||||
@ -42,15 +23,22 @@ inline fun MCScope.launch(
|
||||
context: CoroutineContext = EmptyCoroutineContext,
|
||||
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||
crossinline block: suspend MCScope.() -> Unit
|
||||
): Job = CoroutineScope(coroutineContext).launch(context + coroutineContext[Random]!!.split(), start) {
|
||||
mcScope(coroutineContext).block()
|
||||
): Job {
|
||||
|
||||
val newRandom = synchronized(this){random.fork()}
|
||||
return CoroutineScope(coroutineContext).launch(context, start) {
|
||||
MCScope(coroutineContext, newRandom).block()
|
||||
}
|
||||
}
|
||||
|
||||
inline fun <T> MCScope.async(
|
||||
context: CoroutineContext = EmptyCoroutineContext,
|
||||
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||
crossinline block: suspend MCScope.() -> T
|
||||
): Deferred<T> = CoroutineScope(coroutineContext).async(context, start) {
|
||||
mcScope(coroutineContext).block()
|
||||
): Deferred<T> {
|
||||
val newRandom = synchronized(this){random.fork()}
|
||||
return CoroutineScope(coroutineContext).async(context, start) {
|
||||
MCScope(coroutineContext, newRandom).block()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,10 @@ package scientifik.kmath.prob
|
||||
|
||||
import java.util.*
|
||||
|
||||
class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
|
||||
class SplitRandomWrapper(random: SplittableRandom) : RandomGenerator {
|
||||
|
||||
var random = random
|
||||
private set
|
||||
|
||||
constructor(seed: Long) : this(SplittableRandom(seed))
|
||||
|
||||
@ -14,5 +17,9 @@ class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
|
||||
|
||||
override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) }
|
||||
|
||||
override fun fork(): RandomGenerator = SplitRandomWrapper(random.split())
|
||||
override fun fork(): RandomGenerator {
|
||||
synchronized(this) {
|
||||
return SplitRandomWrapper(random.split()).also { this.random = random.split() }
|
||||
}
|
||||
}
|
||||
}
|
@ -1,9 +1,6 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.newSingleThreadContext
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.*
|
||||
import java.util.*
|
||||
import kotlin.collections.HashSet
|
||||
import kotlin.test.Test
|
||||
@ -15,11 +12,40 @@ typealias ATest = suspend CoroutineScope.() -> Set<RandomResult>
|
||||
|
||||
class MCScopeTest {
|
||||
val test: ATest = {
|
||||
mc(1122) {
|
||||
mc(1111) {
|
||||
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||
|
||||
launch {
|
||||
//println(random)
|
||||
repeat(10) {
|
||||
delay(10)
|
||||
res.add(RandomResult("first", it, random.nextInt()))
|
||||
}
|
||||
launch {
|
||||
"empty fork"
|
||||
}
|
||||
}
|
||||
|
||||
launch {
|
||||
//println(random)
|
||||
repeat(10) {
|
||||
delay(10)
|
||||
res.add(RandomResult("second", it, random.nextInt()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
val test2: ATest = {
|
||||
mc(1111) {
|
||||
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||
|
||||
val job = launch {
|
||||
repeat(10) {
|
||||
delay(10)
|
||||
res.add(RandomResult("first", it, random.nextInt()))
|
||||
}
|
||||
launch {
|
||||
@ -28,18 +54,30 @@ class MCScopeTest {
|
||||
}
|
||||
launch {
|
||||
repeat(10) {
|
||||
delay(10)
|
||||
if(it == 4) job.join()
|
||||
res.add(RandomResult("second", it, random.nextInt()))
|
||||
}
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testGenerator() {
|
||||
fun testParallel() {
|
||||
val res1 = runBlocking(Dispatchers.Default) { test() }
|
||||
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
||||
assertEquals(res1.find { it.branch=="first" && it.order==7 }?.value, res2.find { it.branch=="first" && it.order==7 }?.value)
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testConditionalJoin() {
|
||||
val res1 = runBlocking(Dispatchers.Default) { test2() }
|
||||
val res2 = runBlocking(newSingleThreadContext("test")) { test2() }
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user