From 9c239ebfbda35be2dcb37a3c5ae2ea1b5a5aeae8 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 16 Jun 2019 11:24:55 +0300 Subject: [PATCH] Fixed race on generator fork --- kmath-coroutines/build.gradle.kts | 6 +-- kmath-prob/build.gradle.kts | 6 +-- .../kmath/prob/CountingRandomWrapper.kt | 20 +++++++ .../kotlin/scientifik/kmath/prob/MCScope.kt | 38 +++++-------- .../kmath/prob/SplitRandomWrapper.kt | 11 +++- .../scientifik/kmath/prob/MCScopeTest.kt | 54 ++++++++++++++++--- 6 files changed, 94 insertions(+), 41 deletions(-) create mode 100644 kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/CountingRandomWrapper.kt diff --git a/kmath-coroutines/build.gradle.kts b/kmath-coroutines/build.gradle.kts index 6ffaf4df6..451471698 100644 --- a/kmath-coroutines/build.gradle.kts +++ b/kmath-coroutines/build.gradle.kts @@ -8,19 +8,19 @@ kotlin.sourceSets { dependencies { api(project(":kmath-core")) api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}") - compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}") + api("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}") } } jvmMain { dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}") - compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}") + api("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}") } } jsMain { dependencies { api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}") - compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}") + api("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}") } } } diff --git a/kmath-prob/build.gradle.kts b/kmath-prob/build.gradle.kts index 972b9bba0..e595347fd 100644 --- a/kmath-prob/build.gradle.kts +++ b/kmath-prob/build.gradle.kts @@ -7,19 +7,19 @@ kotlin.sourceSets { commonMain { dependencies { api(project(":kmath-coroutines")) - compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}") + //compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}") } } jvmMain { dependencies { // https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple //api("org.apache.commons:commons-rng-sampling:1.2") - compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}") + // compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}") } } jsMain { dependencies { - compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}") + // compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}") } } diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/CountingRandomWrapper.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/CountingRandomWrapper.kt new file mode 100644 index 000000000..c1cfbe727 --- /dev/null +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/CountingRandomWrapper.kt @@ -0,0 +1,20 @@ +package scientifik.kmath.prob + +import kotlinx.atomicfu.atomic +import kotlin.random.Random + +class CountingRandomWrapper(val seed: Long) : RandomGenerator { + private val counter = atomic(0) + private val random = Random(seed) + + override fun nextDouble(): Double = random.nextDouble() + + override fun nextInt(): Int = random.nextInt() + + override fun nextLong(): Long = random.nextLong() + + override fun nextBlock(size: Int): ByteArray = random.nextBytes(size) + + override fun fork(): RandomGenerator = CountingRandomWrapper(seed + counter.addAndGet(10)) + +} \ No newline at end of file diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt index e65a589f3..53717d11b 100644 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt @@ -8,32 +8,13 @@ import kotlin.coroutines.coroutineContext /** * A scope for a monte-carlo simulation or multi-coroutine random number generation */ -class MCScope private constructor(val coroutineContext: CoroutineContext) { - fun mcScope(context: CoroutineContext): MCScope = MCScope(context + coroutineContext[Random]!!.split()) - - companion object { - suspend fun init(generator: RandomGenerator): MCScope = - MCScope(coroutineContext + Random(generator)) - } -} - -inline class Random(val generator: RandomGenerator) : CoroutineContext.Element { - override val key: CoroutineContext.Key<*> get() = Companion - - fun split(): Random = Random(generator.fork()) - - companion object : CoroutineContext.Key -} - -val CoroutineContext.random get() = this[Random]?.generator - -val MCScope.random get() = coroutineContext.random!! +class MCScope(val coroutineContext: CoroutineContext, val random: RandomGenerator) /** * Launches a supervised Monte-Carlo scope */ suspend fun mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T = - MCScope.init(generator).block() + MCScope(coroutineContext, generator).block() suspend fun mc(seed: Long = -1, block: suspend MCScope.() -> T): T = mc(SplitRandomWrapper(seed), block) @@ -42,15 +23,22 @@ inline fun MCScope.launch( context: CoroutineContext = EmptyCoroutineContext, start: CoroutineStart = CoroutineStart.DEFAULT, crossinline block: suspend MCScope.() -> Unit -): Job = CoroutineScope(coroutineContext).launch(context + coroutineContext[Random]!!.split(), start) { - mcScope(coroutineContext).block() +): Job { + + val newRandom = synchronized(this){random.fork()} + return CoroutineScope(coroutineContext).launch(context, start) { + MCScope(coroutineContext, newRandom).block() + } } inline fun MCScope.async( context: CoroutineContext = EmptyCoroutineContext, start: CoroutineStart = CoroutineStart.DEFAULT, crossinline block: suspend MCScope.() -> T -): Deferred = CoroutineScope(coroutineContext).async(context, start) { - mcScope(coroutineContext).block() +): Deferred { + val newRandom = synchronized(this){random.fork()} + return CoroutineScope(coroutineContext).async(context, start) { + MCScope(coroutineContext, newRandom).block() + } } diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/SplitRandomWrapper.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/SplitRandomWrapper.kt index ca488a262..a0702e6db 100644 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/SplitRandomWrapper.kt +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/SplitRandomWrapper.kt @@ -2,7 +2,10 @@ package scientifik.kmath.prob import java.util.* -class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator { +class SplitRandomWrapper(random: SplittableRandom) : RandomGenerator { + + var random = random + private set constructor(seed: Long) : this(SplittableRandom(seed)) @@ -14,5 +17,9 @@ class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator { override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) } - override fun fork(): RandomGenerator = SplitRandomWrapper(random.split()) + override fun fork(): RandomGenerator { + synchronized(this) { + return SplitRandomWrapper(random.split()).also { this.random = random.split() } + } + } } \ No newline at end of file diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt index 69480ec22..2038101e1 100644 --- a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt +++ b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt @@ -1,9 +1,6 @@ package scientifik.kmath.prob -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.newSingleThreadContext -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.* import java.util.* import kotlin.collections.HashSet import kotlin.test.Test @@ -15,11 +12,40 @@ typealias ATest = suspend CoroutineScope.() -> Set class MCScopeTest { val test: ATest = { - mc(1122) { + mc(1111) { + val res = Collections.synchronizedSet(HashSet()) + + launch { + //println(random) + repeat(10) { + delay(10) + res.add(RandomResult("first", it, random.nextInt())) + } + launch { + "empty fork" + } + } + + launch { + //println(random) + repeat(10) { + delay(10) + res.add(RandomResult("second", it, random.nextInt())) + } + } + + + res + } + } + + val test2: ATest = { + mc(1111) { val res = Collections.synchronizedSet(HashSet()) val job = launch { repeat(10) { + delay(10) res.add(RandomResult("first", it, random.nextInt())) } launch { @@ -28,18 +54,30 @@ class MCScopeTest { } launch { repeat(10) { + delay(10) + if(it == 4) job.join() res.add(RandomResult("second", it, random.nextInt())) } } + res } } @Test - fun testGenerator() { + fun testParallel() { val res1 = runBlocking(Dispatchers.Default) { test() } - val res2 = runBlocking(newSingleThreadContext("test")) {test()} - assertEquals(res1,res2) + val res2 = runBlocking(newSingleThreadContext("test")) { test() } + assertEquals(res1.find { it.branch=="first" && it.order==7 }?.value, res2.find { it.branch=="first" && it.order==7 }?.value) + assertEquals(res1, res2) + } + + + @Test + fun testConditionalJoin() { + val res1 = runBlocking(Dispatchers.Default) { test2() } + val res2 = runBlocking(newSingleThreadContext("test")) { test2() } + assertEquals(res1, res2) } } \ No newline at end of file