Fixed race on generator fork

This commit is contained in:
Alexander Nozik 2019-06-16 11:24:55 +03:00
parent 2892ea11e2
commit 9c239ebfbd
6 changed files with 94 additions and 41 deletions

View File

@ -8,19 +8,19 @@ kotlin.sourceSets {
dependencies {
api(project(":kmath-core"))
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
api("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
}
}
jvmMain {
dependencies {
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
api("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
}
}
jsMain {
dependencies {
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
api("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
}
}
}

View File

@ -7,19 +7,19 @@ kotlin.sourceSets {
commonMain {
dependencies {
api(project(":kmath-coroutines"))
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
//compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
}
}
jvmMain {
dependencies {
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
//api("org.apache.commons:commons-rng-sampling:1.2")
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
// compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
}
}
jsMain {
dependencies {
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
}
}

View File

@ -0,0 +1,20 @@
package scientifik.kmath.prob
import kotlinx.atomicfu.atomic
import kotlin.random.Random
class CountingRandomWrapper(val seed: Long) : RandomGenerator {
private val counter = atomic(0)
private val random = Random(seed)
override fun nextDouble(): Double = random.nextDouble()
override fun nextInt(): Int = random.nextInt()
override fun nextLong(): Long = random.nextLong()
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
override fun fork(): RandomGenerator = CountingRandomWrapper(seed + counter.addAndGet(10))
}

View File

@ -8,32 +8,13 @@ import kotlin.coroutines.coroutineContext
/**
* A scope for a monte-carlo simulation or multi-coroutine random number generation
*/
class MCScope private constructor(val coroutineContext: CoroutineContext) {
fun mcScope(context: CoroutineContext): MCScope = MCScope(context + coroutineContext[Random]!!.split())
companion object {
suspend fun init(generator: RandomGenerator): MCScope =
MCScope(coroutineContext + Random(generator))
}
}
inline class Random(val generator: RandomGenerator) : CoroutineContext.Element {
override val key: CoroutineContext.Key<*> get() = Companion
fun split(): Random = Random(generator.fork())
companion object : CoroutineContext.Key<Random>
}
val CoroutineContext.random get() = this[Random]?.generator
val MCScope.random get() = coroutineContext.random!!
class MCScope(val coroutineContext: CoroutineContext, val random: RandomGenerator)
/**
* Launches a supervised Monte-Carlo scope
*/
suspend fun <T> mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T =
MCScope.init(generator).block()
MCScope(coroutineContext, generator).block()
suspend fun <T> mc(seed: Long = -1, block: suspend MCScope.() -> T): T =
mc(SplitRandomWrapper(seed), block)
@ -42,15 +23,22 @@ inline fun MCScope.launch(
context: CoroutineContext = EmptyCoroutineContext,
start: CoroutineStart = CoroutineStart.DEFAULT,
crossinline block: suspend MCScope.() -> Unit
): Job = CoroutineScope(coroutineContext).launch(context + coroutineContext[Random]!!.split(), start) {
mcScope(coroutineContext).block()
): Job {
val newRandom = synchronized(this){random.fork()}
return CoroutineScope(coroutineContext).launch(context, start) {
MCScope(coroutineContext, newRandom).block()
}
}
inline fun <T> MCScope.async(
context: CoroutineContext = EmptyCoroutineContext,
start: CoroutineStart = CoroutineStart.DEFAULT,
crossinline block: suspend MCScope.() -> T
): Deferred<T> = CoroutineScope(coroutineContext).async(context, start) {
mcScope(coroutineContext).block()
): Deferred<T> {
val newRandom = synchronized(this){random.fork()}
return CoroutineScope(coroutineContext).async(context, start) {
MCScope(coroutineContext, newRandom).block()
}
}

View File

@ -2,7 +2,10 @@ package scientifik.kmath.prob
import java.util.*
class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
class SplitRandomWrapper(random: SplittableRandom) : RandomGenerator {
var random = random
private set
constructor(seed: Long) : this(SplittableRandom(seed))
@ -14,5 +17,9 @@ class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) }
override fun fork(): RandomGenerator = SplitRandomWrapper(random.split())
override fun fork(): RandomGenerator {
synchronized(this) {
return SplitRandomWrapper(random.split()).also { this.random = random.split() }
}
}
}

View File

@ -1,9 +1,6 @@
package scientifik.kmath.prob
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.*
import java.util.*
import kotlin.collections.HashSet
import kotlin.test.Test
@ -15,11 +12,40 @@ typealias ATest = suspend CoroutineScope.() -> Set<RandomResult>
class MCScopeTest {
val test: ATest = {
mc(1122) {
mc(1111) {
val res = Collections.synchronizedSet(HashSet<RandomResult>())
launch {
//println(random)
repeat(10) {
delay(10)
res.add(RandomResult("first", it, random.nextInt()))
}
launch {
"empty fork"
}
}
launch {
//println(random)
repeat(10) {
delay(10)
res.add(RandomResult("second", it, random.nextInt()))
}
}
res
}
}
val test2: ATest = {
mc(1111) {
val res = Collections.synchronizedSet(HashSet<RandomResult>())
val job = launch {
repeat(10) {
delay(10)
res.add(RandomResult("first", it, random.nextInt()))
}
launch {
@ -28,18 +54,30 @@ class MCScopeTest {
}
launch {
repeat(10) {
delay(10)
if(it == 4) job.join()
res.add(RandomResult("second", it, random.nextInt()))
}
}
res
}
}
@Test
fun testGenerator() {
fun testParallel() {
val res1 = runBlocking(Dispatchers.Default) { test() }
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
assertEquals(res1.find { it.branch=="first" && it.order==7 }?.value, res2.find { it.branch=="first" && it.order==7 }?.value)
assertEquals(res1, res2)
}
@Test
fun testConditionalJoin() {
val res1 = runBlocking(Dispatchers.Default) { test2() }
val res2 = runBlocking(newSingleThreadContext("test")) { test2() }
assertEquals(res1, res2)
}
}