diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CMRandomGeneratorWrapper.kt index 74e035ecb..5a0aabe3c 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CMRandomGeneratorWrapper.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/prob/CMRandomGeneratorWrapper.kt @@ -4,7 +4,9 @@ import org.apache.commons.math3.random.JDKRandomGenerator import scientifik.kmath.prob.RandomGenerator import org.apache.commons.math3.random.RandomGenerator as CMRandom -inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator { +class CMRandomGeneratorWrapper(seed: Long?, val builder: (Long?) -> CMRandom) : RandomGenerator { + val generator = builder(seed) + override fun nextDouble(): Double = generator.nextDouble() override fun nextInt(): Int = generator.nextInt() @@ -13,20 +15,12 @@ inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) } - override fun fork(): RandomGenerator { - TODO("not implemented") //To change body of created functions use File | Settings | File Templates. - } + override fun fork(): RandomGenerator = CMRandomGeneratorWrapper(nextLong(), builder) } -fun CMRandom.asKmathGenerator(): RandomGenerator = CMRandomGeneratorWrapper(this) - fun RandomGenerator.asCMGenerator(): CMRandom = (this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper") -val RandomGenerator.Companion.default: RandomGenerator by lazy { JDKRandomGenerator().asKmathGenerator() } -fun RandomGenerator.Companion.jdk(seed: Int? = null): RandomGenerator = if (seed == null) { - JDKRandomGenerator() -} else { - JDKRandomGenerator(seed) -}.asKmathGenerator() \ No newline at end of file +fun RandomGenerator.Companion.jdk(seed: Long? = null): RandomGenerator = + CMRandomGeneratorWrapper(seed) { JDKRandomGenerator() } \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt index 61e0dae4e..d3e6d77f1 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomChain.kt @@ -5,7 +5,6 @@ import scientifik.kmath.chains.Chain /** * A possibly stateful chain producing random values. - * TODO make random chain properly fork generator */ class RandomChain(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain { private val atomicValue = atomic(null) @@ -13,4 +12,9 @@ class RandomChain(val generator: RandomGenerator, private val gen: suspen override suspend fun next(): R = generator.gen().also { atomicValue.lazySet(it) } override fun fork(): Chain = RandomChain(generator.fork(), gen) -} \ No newline at end of file +} + +/** + * Create a chain of doubles from generator after forking it so the chain is not affected by operations on generator + */ +fun RandomGenerator.doubles(): Chain = RandomChain(fork()) { nextDouble() } \ No newline at end of file diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt index 381bd9736..13983c6b2 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/RandomGenerator.kt @@ -1,5 +1,7 @@ package scientifik.kmath.prob +import kotlin.random.Random + /** * A basic generator */ @@ -10,9 +12,30 @@ interface RandomGenerator { fun nextBlock(size: Int): ByteArray /** - * Fork the current state of generator + * Create a new generator which is independent from current generator (operations on new generator do not affect this one + * and vise versa). The statistical properties of new generator should be the same as for this one. + * For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run. + * + * The thread safety of this operation is not guaranteed since it could affect the state of the generator. */ fun fork(): RandomGenerator - companion object + companion object { + val default by lazy { DefaultGenerator(Random.nextLong()) } + } +} + +class DefaultGenerator(seed: Long?) : RandomGenerator { + private val random = seed?.let { Random(it) } ?: Random + + override fun nextDouble(): Double = random.nextDouble() + + override fun nextInt(): Int = random.nextInt() + + override fun nextLong(): Long = random.nextLong() + + override fun nextBlock(size: Int): ByteArray = random.nextBytes(size) + + override fun fork(): RandomGenerator = DefaultGenerator(nextLong()) + } \ No newline at end of file diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/CountingRandomWrapper.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/CountingRandomWrapper.kt deleted file mode 100644 index c1cfbe727..000000000 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/CountingRandomWrapper.kt +++ /dev/null @@ -1,20 +0,0 @@ -package scientifik.kmath.prob - -import kotlinx.atomicfu.atomic -import kotlin.random.Random - -class CountingRandomWrapper(val seed: Long) : RandomGenerator { - private val counter = atomic(0) - private val random = Random(seed) - - override fun nextDouble(): Double = random.nextDouble() - - override fun nextInt(): Int = random.nextInt() - - override fun nextLong(): Long = random.nextLong() - - override fun nextBlock(size: Int): ByteArray = random.nextBytes(size) - - override fun fork(): RandomGenerator = CountingRandomWrapper(seed + counter.addAndGet(10)) - -} \ No newline at end of file diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt index 53717d11b..bc9ae5cb5 100644 --- a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt @@ -8,7 +8,7 @@ import kotlin.coroutines.coroutineContext /** * A scope for a monte-carlo simulation or multi-coroutine random number generation */ -class MCScope(val coroutineContext: CoroutineContext, val random: RandomGenerator) +class MCScope(override val coroutineContext: CoroutineContext, val random: RandomGenerator): CoroutineScope /** * Launches a supervised Monte-Carlo scope diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt index 2038101e1..b60506136 100644 --- a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt +++ b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt @@ -11,7 +11,7 @@ data class RandomResult(val branch: String, val order: Int, val value: Int) typealias ATest = suspend CoroutineScope.() -> Set class MCScopeTest { - val test: ATest = { + val simpleTest: ATest = { mc(1111) { val res = Collections.synchronizedSet(HashSet()) @@ -39,7 +39,7 @@ class MCScopeTest { } } - val test2: ATest = { + val testWithJoin: ATest = { mc(1111) { val res = Collections.synchronizedSet(HashSet()) @@ -48,14 +48,11 @@ class MCScopeTest { delay(10) res.add(RandomResult("first", it, random.nextInt())) } - launch { - "empty fork" - } } launch { repeat(10) { delay(10) - if(it == 4) job.join() + if (it == 4) job.join() res.add(RandomResult("second", it, random.nextInt())) } } @@ -65,19 +62,24 @@ class MCScopeTest { } - @Test - fun testParallel() { + fun compareResult(test: ATest) { val res1 = runBlocking(Dispatchers.Default) { test() } val res2 = runBlocking(newSingleThreadContext("test")) { test() } - assertEquals(res1.find { it.branch=="first" && it.order==7 }?.value, res2.find { it.branch=="first" && it.order==7 }?.value) + assertEquals( + res1.find { it.branch == "first" && it.order == 7 }?.value, + res2.find { it.branch == "first" && it.order == 7 }?.value + ) assertEquals(res1, res2) } + @Test + fun testParallel() { + compareResult(simpleTest) + } + @Test fun testConditionalJoin() { - val res1 = runBlocking(Dispatchers.Default) { test2() } - val res2 = runBlocking(newSingleThreadContext("test")) { test2() } - assertEquals(res1, res2) + compareResult(testWithJoin) } } \ No newline at end of file