Implemented fork for default random generator
This commit is contained in:
parent
9c239ebfbd
commit
9efc6da74a
@ -4,7 +4,9 @@ import org.apache.commons.math3.random.JDKRandomGenerator
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
||||
|
||||
inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator {
|
||||
class CMRandomGeneratorWrapper(seed: Long?, val builder: (Long?) -> CMRandom) : RandomGenerator {
|
||||
val generator = builder(seed)
|
||||
|
||||
override fun nextDouble(): Double = generator.nextDouble()
|
||||
|
||||
override fun nextInt(): Int = generator.nextInt()
|
||||
@ -13,20 +15,12 @@ inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator
|
||||
|
||||
override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) }
|
||||
|
||||
override fun fork(): RandomGenerator {
|
||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||
}
|
||||
override fun fork(): RandomGenerator = CMRandomGeneratorWrapper(nextLong(), builder)
|
||||
}
|
||||
|
||||
fun CMRandom.asKmathGenerator(): RandomGenerator = CMRandomGeneratorWrapper(this)
|
||||
|
||||
fun RandomGenerator.asCMGenerator(): CMRandom =
|
||||
(this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper")
|
||||
|
||||
val RandomGenerator.Companion.default: RandomGenerator by lazy { JDKRandomGenerator().asKmathGenerator() }
|
||||
|
||||
fun RandomGenerator.Companion.jdk(seed: Int? = null): RandomGenerator = if (seed == null) {
|
||||
JDKRandomGenerator()
|
||||
} else {
|
||||
JDKRandomGenerator(seed)
|
||||
}.asKmathGenerator()
|
||||
fun RandomGenerator.Companion.jdk(seed: Long? = null): RandomGenerator =
|
||||
CMRandomGeneratorWrapper(seed) { JDKRandomGenerator() }
|
@ -5,7 +5,6 @@ import scientifik.kmath.chains.Chain
|
||||
|
||||
/**
|
||||
* A possibly stateful chain producing random values.
|
||||
* TODO make random chain properly fork generator
|
||||
*/
|
||||
class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain<R> {
|
||||
private val atomicValue = atomic<R?>(null)
|
||||
@ -13,4 +12,9 @@ class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspen
|
||||
override suspend fun next(): R = generator.gen().also { atomicValue.lazySet(it) }
|
||||
|
||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a chain of doubles from generator after forking it so the chain is not affected by operations on generator
|
||||
*/
|
||||
fun RandomGenerator.doubles(): Chain<Double> = RandomChain(fork()) { nextDouble() }
|
@ -1,5 +1,7 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import kotlin.random.Random
|
||||
|
||||
/**
|
||||
* A basic generator
|
||||
*/
|
||||
@ -10,9 +12,30 @@ interface RandomGenerator {
|
||||
fun nextBlock(size: Int): ByteArray
|
||||
|
||||
/**
|
||||
* Fork the current state of generator
|
||||
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
|
||||
* and vise versa). The statistical properties of new generator should be the same as for this one.
|
||||
* For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run.
|
||||
*
|
||||
* The thread safety of this operation is not guaranteed since it could affect the state of the generator.
|
||||
*/
|
||||
fun fork(): RandomGenerator
|
||||
|
||||
companion object
|
||||
companion object {
|
||||
val default by lazy { DefaultGenerator(Random.nextLong()) }
|
||||
}
|
||||
}
|
||||
|
||||
class DefaultGenerator(seed: Long?) : RandomGenerator {
|
||||
private val random = seed?.let { Random(it) } ?: Random
|
||||
|
||||
override fun nextDouble(): Double = random.nextDouble()
|
||||
|
||||
override fun nextInt(): Int = random.nextInt()
|
||||
|
||||
override fun nextLong(): Long = random.nextLong()
|
||||
|
||||
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
|
||||
|
||||
override fun fork(): RandomGenerator = DefaultGenerator(nextLong())
|
||||
|
||||
}
|
@ -1,20 +0,0 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlin.random.Random
|
||||
|
||||
class CountingRandomWrapper(val seed: Long) : RandomGenerator {
|
||||
private val counter = atomic(0)
|
||||
private val random = Random(seed)
|
||||
|
||||
override fun nextDouble(): Double = random.nextDouble()
|
||||
|
||||
override fun nextInt(): Int = random.nextInt()
|
||||
|
||||
override fun nextLong(): Long = random.nextLong()
|
||||
|
||||
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
|
||||
|
||||
override fun fork(): RandomGenerator = CountingRandomWrapper(seed + counter.addAndGet(10))
|
||||
|
||||
}
|
@ -8,7 +8,7 @@ import kotlin.coroutines.coroutineContext
|
||||
/**
|
||||
* A scope for a monte-carlo simulation or multi-coroutine random number generation
|
||||
*/
|
||||
class MCScope(val coroutineContext: CoroutineContext, val random: RandomGenerator)
|
||||
class MCScope(override val coroutineContext: CoroutineContext, val random: RandomGenerator): CoroutineScope
|
||||
|
||||
/**
|
||||
* Launches a supervised Monte-Carlo scope
|
||||
|
@ -11,7 +11,7 @@ data class RandomResult(val branch: String, val order: Int, val value: Int)
|
||||
typealias ATest = suspend CoroutineScope.() -> Set<RandomResult>
|
||||
|
||||
class MCScopeTest {
|
||||
val test: ATest = {
|
||||
val simpleTest: ATest = {
|
||||
mc(1111) {
|
||||
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||
|
||||
@ -39,7 +39,7 @@ class MCScopeTest {
|
||||
}
|
||||
}
|
||||
|
||||
val test2: ATest = {
|
||||
val testWithJoin: ATest = {
|
||||
mc(1111) {
|
||||
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||
|
||||
@ -48,14 +48,11 @@ class MCScopeTest {
|
||||
delay(10)
|
||||
res.add(RandomResult("first", it, random.nextInt()))
|
||||
}
|
||||
launch {
|
||||
"empty fork"
|
||||
}
|
||||
}
|
||||
launch {
|
||||
repeat(10) {
|
||||
delay(10)
|
||||
if(it == 4) job.join()
|
||||
if (it == 4) job.join()
|
||||
res.add(RandomResult("second", it, random.nextInt()))
|
||||
}
|
||||
}
|
||||
@ -65,19 +62,24 @@ class MCScopeTest {
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testParallel() {
|
||||
fun compareResult(test: ATest) {
|
||||
val res1 = runBlocking(Dispatchers.Default) { test() }
|
||||
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
||||
assertEquals(res1.find { it.branch=="first" && it.order==7 }?.value, res2.find { it.branch=="first" && it.order==7 }?.value)
|
||||
assertEquals(
|
||||
res1.find { it.branch == "first" && it.order == 7 }?.value,
|
||||
res2.find { it.branch == "first" && it.order == 7 }?.value
|
||||
)
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testParallel() {
|
||||
compareResult(simpleTest)
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testConditionalJoin() {
|
||||
val res1 = runBlocking(Dispatchers.Default) { test2() }
|
||||
val res2 = runBlocking(newSingleThreadContext("test")) { test2() }
|
||||
assertEquals(res1, res2)
|
||||
compareResult(testWithJoin)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user