forked from kscience/kmath
Fixed race on generator fork
This commit is contained in:
parent
2892ea11e2
commit
9c239ebfbd
@ -8,19 +8,19 @@ kotlin.sourceSets {
|
|||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
api("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jvmMain {
|
jvmMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
api("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jsMain {
|
jsMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
api("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,19 +7,19 @@ kotlin.sourceSets {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
//compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jvmMain {
|
jvmMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
|
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
|
||||||
//api("org.apache.commons:commons-rng-sampling:1.2")
|
//api("org.apache.commons:commons-rng-sampling:1.2")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
// compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jsMain {
|
jsMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.atomicfu.atomic
|
||||||
|
import kotlin.random.Random
|
||||||
|
|
||||||
|
class CountingRandomWrapper(val seed: Long) : RandomGenerator {
|
||||||
|
private val counter = atomic(0)
|
||||||
|
private val random = Random(seed)
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
|
override fun nextInt(): Int = random.nextInt()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
|
||||||
|
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator = CountingRandomWrapper(seed + counter.addAndGet(10))
|
||||||
|
|
||||||
|
}
|
@ -8,32 +8,13 @@ import kotlin.coroutines.coroutineContext
|
|||||||
/**
|
/**
|
||||||
* A scope for a monte-carlo simulation or multi-coroutine random number generation
|
* A scope for a monte-carlo simulation or multi-coroutine random number generation
|
||||||
*/
|
*/
|
||||||
class MCScope private constructor(val coroutineContext: CoroutineContext) {
|
class MCScope(val coroutineContext: CoroutineContext, val random: RandomGenerator)
|
||||||
fun mcScope(context: CoroutineContext): MCScope = MCScope(context + coroutineContext[Random]!!.split())
|
|
||||||
|
|
||||||
companion object {
|
|
||||||
suspend fun init(generator: RandomGenerator): MCScope =
|
|
||||||
MCScope(coroutineContext + Random(generator))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline class Random(val generator: RandomGenerator) : CoroutineContext.Element {
|
|
||||||
override val key: CoroutineContext.Key<*> get() = Companion
|
|
||||||
|
|
||||||
fun split(): Random = Random(generator.fork())
|
|
||||||
|
|
||||||
companion object : CoroutineContext.Key<Random>
|
|
||||||
}
|
|
||||||
|
|
||||||
val CoroutineContext.random get() = this[Random]?.generator
|
|
||||||
|
|
||||||
val MCScope.random get() = coroutineContext.random!!
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Launches a supervised Monte-Carlo scope
|
* Launches a supervised Monte-Carlo scope
|
||||||
*/
|
*/
|
||||||
suspend fun <T> mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T =
|
suspend fun <T> mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T =
|
||||||
MCScope.init(generator).block()
|
MCScope(coroutineContext, generator).block()
|
||||||
|
|
||||||
suspend fun <T> mc(seed: Long = -1, block: suspend MCScope.() -> T): T =
|
suspend fun <T> mc(seed: Long = -1, block: suspend MCScope.() -> T): T =
|
||||||
mc(SplitRandomWrapper(seed), block)
|
mc(SplitRandomWrapper(seed), block)
|
||||||
@ -42,15 +23,22 @@ inline fun MCScope.launch(
|
|||||||
context: CoroutineContext = EmptyCoroutineContext,
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
start: CoroutineStart = CoroutineStart.DEFAULT,
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
crossinline block: suspend MCScope.() -> Unit
|
crossinline block: suspend MCScope.() -> Unit
|
||||||
): Job = CoroutineScope(coroutineContext).launch(context + coroutineContext[Random]!!.split(), start) {
|
): Job {
|
||||||
mcScope(coroutineContext).block()
|
|
||||||
|
val newRandom = synchronized(this){random.fork()}
|
||||||
|
return CoroutineScope(coroutineContext).launch(context, start) {
|
||||||
|
MCScope(coroutineContext, newRandom).block()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <T> MCScope.async(
|
inline fun <T> MCScope.async(
|
||||||
context: CoroutineContext = EmptyCoroutineContext,
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
start: CoroutineStart = CoroutineStart.DEFAULT,
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
crossinline block: suspend MCScope.() -> T
|
crossinline block: suspend MCScope.() -> T
|
||||||
): Deferred<T> = CoroutineScope(coroutineContext).async(context, start) {
|
): Deferred<T> {
|
||||||
mcScope(coroutineContext).block()
|
val newRandom = synchronized(this){random.fork()}
|
||||||
|
return CoroutineScope(coroutineContext).async(context, start) {
|
||||||
|
MCScope(coroutineContext, newRandom).block()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,7 +2,10 @@ package scientifik.kmath.prob
|
|||||||
|
|
||||||
import java.util.*
|
import java.util.*
|
||||||
|
|
||||||
class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
|
class SplitRandomWrapper(random: SplittableRandom) : RandomGenerator {
|
||||||
|
|
||||||
|
var random = random
|
||||||
|
private set
|
||||||
|
|
||||||
constructor(seed: Long) : this(SplittableRandom(seed))
|
constructor(seed: Long) : this(SplittableRandom(seed))
|
||||||
|
|
||||||
@ -14,5 +17,9 @@ class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator {
|
|||||||
|
|
||||||
override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) }
|
override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) }
|
||||||
|
|
||||||
override fun fork(): RandomGenerator = SplitRandomWrapper(random.split())
|
override fun fork(): RandomGenerator {
|
||||||
|
synchronized(this) {
|
||||||
|
return SplitRandomWrapper(random.split()).also { this.random = random.split() }
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -1,9 +1,6 @@
|
|||||||
package scientifik.kmath.prob
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineScope
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.Dispatchers
|
|
||||||
import kotlinx.coroutines.newSingleThreadContext
|
|
||||||
import kotlinx.coroutines.runBlocking
|
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import kotlin.collections.HashSet
|
import kotlin.collections.HashSet
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -15,11 +12,40 @@ typealias ATest = suspend CoroutineScope.() -> Set<RandomResult>
|
|||||||
|
|
||||||
class MCScopeTest {
|
class MCScopeTest {
|
||||||
val test: ATest = {
|
val test: ATest = {
|
||||||
mc(1122) {
|
mc(1111) {
|
||||||
|
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||||
|
|
||||||
|
launch {
|
||||||
|
//println(random)
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("first", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
launch {
|
||||||
|
"empty fork"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
launch {
|
||||||
|
//println(random)
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("second", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val test2: ATest = {
|
||||||
|
mc(1111) {
|
||||||
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||||
|
|
||||||
val job = launch {
|
val job = launch {
|
||||||
repeat(10) {
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
res.add(RandomResult("first", it, random.nextInt()))
|
res.add(RandomResult("first", it, random.nextInt()))
|
||||||
}
|
}
|
||||||
launch {
|
launch {
|
||||||
@ -28,18 +54,30 @@ class MCScopeTest {
|
|||||||
}
|
}
|
||||||
launch {
|
launch {
|
||||||
repeat(10) {
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
if(it == 4) job.join()
|
||||||
res.add(RandomResult("second", it, random.nextInt()))
|
res.add(RandomResult("second", it, random.nextInt()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testGenerator() {
|
fun testParallel() {
|
||||||
val res1 = runBlocking(Dispatchers.Default) { test() }
|
val res1 = runBlocking(Dispatchers.Default) { test() }
|
||||||
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
||||||
|
assertEquals(res1.find { it.branch=="first" && it.order==7 }?.value, res2.find { it.branch=="first" && it.order==7 }?.value)
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testConditionalJoin() {
|
||||||
|
val res1 = runBlocking(Dispatchers.Default) { test2() }
|
||||||
|
val res2 = runBlocking(newSingleThreadContext("test")) { test2() }
|
||||||
assertEquals(res1, res2)
|
assertEquals(res1, res2)
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user