forked from kscience/kmath
Compare commits
4 Commits
dev
...
split-rand
Author | SHA1 | Date | |
---|---|---|---|
9efc6da74a | |||
9c239ebfbd | |||
2892ea11e2 | |||
c267372c70 |
@ -8,7 +8,7 @@ fun main(args: Array<String>) {
|
|||||||
val dim = 1000
|
val dim = 1000
|
||||||
val n = 1000
|
val n = 1000
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build coroutineContext most suited for given type.
|
||||||
val autoField = NDField.auto(RealField, dim, dim)
|
val autoField = NDField.auto(RealField, dim, dim)
|
||||||
// specialized nd-field for Double. It works as generic Double field as well
|
// specialized nd-field for Double. It works as generic Double field as well
|
||||||
val specializedField = NDField.real(dim, dim)
|
val specializedField = NDField.real(dim, dim)
|
||||||
|
@ -111,7 +111,7 @@ fun DiffExpression.derivative(vararg orders: Pair<String, Int>) = derivative(map
|
|||||||
fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
fun DiffExpression.derivative(name: String) = derivative(name to 1)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
* A coroutineContext for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||||
*/
|
*/
|
||||||
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
object DiffExpressionContext : ExpressionContext<Double>, Field<DiffExpression> {
|
||||||
override fun variable(name: String, default: Double?) =
|
override fun variable(name: String, default: Double?) =
|
||||||
|
@ -4,7 +4,9 @@ import org.apache.commons.math3.random.JDKRandomGenerator
|
|||||||
import scientifik.kmath.prob.RandomGenerator
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
import org.apache.commons.math3.random.RandomGenerator as CMRandom
|
||||||
|
|
||||||
inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator {
|
class CMRandomGeneratorWrapper(seed: Long?, val builder: (Long?) -> CMRandom) : RandomGenerator {
|
||||||
|
val generator = builder(seed)
|
||||||
|
|
||||||
override fun nextDouble(): Double = generator.nextDouble()
|
override fun nextDouble(): Double = generator.nextDouble()
|
||||||
|
|
||||||
override fun nextInt(): Int = generator.nextInt()
|
override fun nextInt(): Int = generator.nextInt()
|
||||||
@ -13,20 +15,12 @@ inline class CMRandomGeneratorWrapper(val generator: CMRandom) : RandomGenerator
|
|||||||
|
|
||||||
override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) }
|
override fun nextBlock(size: Int): ByteArray = ByteArray(size).apply { generator.nextBytes(this) }
|
||||||
|
|
||||||
override fun fork(): RandomGenerator {
|
override fun fork(): RandomGenerator = CMRandomGeneratorWrapper(nextLong(), builder)
|
||||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun CMRandom.asKmathGenerator(): RandomGenerator = CMRandomGeneratorWrapper(this)
|
|
||||||
|
|
||||||
fun RandomGenerator.asCMGenerator(): CMRandom =
|
fun RandomGenerator.asCMGenerator(): CMRandom =
|
||||||
(this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper")
|
(this as? CMRandomGeneratorWrapper)?.generator ?: TODO("Implement reverse CM wrapper")
|
||||||
|
|
||||||
val RandomGenerator.Companion.default: RandomGenerator by lazy { JDKRandomGenerator().asKmathGenerator() }
|
|
||||||
|
|
||||||
fun RandomGenerator.Companion.jdk(seed: Int? = null): RandomGenerator = if (seed == null) {
|
fun RandomGenerator.Companion.jdk(seed: Long? = null): RandomGenerator =
|
||||||
JDKRandomGenerator()
|
CMRandomGeneratorWrapper(seed) { JDKRandomGenerator() }
|
||||||
} else {
|
|
||||||
JDKRandomGenerator(seed)
|
|
||||||
}.asKmathGenerator()
|
|
@ -8,19 +8,19 @@ kotlin.sourceSets {
|
|||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-common:${Versions.coroutinesVersion}")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
api("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jvmMain {
|
jvmMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutinesVersion}")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
api("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jsMain {
|
jsMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
|
api("org.jetbrains.kotlinx:kotlinx-coroutines-core-js:${Versions.coroutinesVersion}")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
api("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,19 +7,19 @@ kotlin.sourceSets {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
//compileOnly("org.jetbrains.kotlinx:atomicfu-common:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jvmMain {
|
jvmMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
|
// https://mvnrepository.com/artifact/org.apache.commons/commons-rng-simple
|
||||||
//api("org.apache.commons:commons-rng-sampling:1.2")
|
//api("org.apache.commons:commons-rng-sampling:1.2")
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
// compileOnly("org.jetbrains.kotlinx:atomicfu:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jsMain {
|
jsMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
// compileOnly("org.jetbrains.kotlinx:atomicfu-js:${Versions.atomicfuVersion}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ import scientifik.kmath.chains.Chain
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* A possibly stateful chain producing random values.
|
* A possibly stateful chain producing random values.
|
||||||
* TODO make random chain properly fork generator
|
|
||||||
*/
|
*/
|
||||||
class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain<R> {
|
class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspend RandomGenerator.() -> R) : Chain<R> {
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val atomicValue = atomic<R?>(null)
|
||||||
@ -13,4 +12,9 @@ class RandomChain<out R>(val generator: RandomGenerator, private val gen: suspen
|
|||||||
override suspend fun next(): R = generator.gen().also { atomicValue.lazySet(it) }
|
override suspend fun next(): R = generator.gen().also { atomicValue.lazySet(it) }
|
||||||
|
|
||||||
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a chain of doubles from generator after forking it so the chain is not affected by operations on generator
|
||||||
|
*/
|
||||||
|
fun RandomGenerator.doubles(): Chain<Double> = RandomChain(fork()) { nextDouble() }
|
@ -1,5 +1,7 @@
|
|||||||
package scientifik.kmath.prob
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlin.random.Random
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A basic generator
|
* A basic generator
|
||||||
*/
|
*/
|
||||||
@ -10,9 +12,30 @@ interface RandomGenerator {
|
|||||||
fun nextBlock(size: Int): ByteArray
|
fun nextBlock(size: Int): ByteArray
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fork the current state of generator
|
* Create a new generator which is independent from current generator (operations on new generator do not affect this one
|
||||||
|
* and vise versa). The statistical properties of new generator should be the same as for this one.
|
||||||
|
* For pseudo-random generator, the fork is keeping the same sequence of numbers for given call order for each run.
|
||||||
|
*
|
||||||
|
* The thread safety of this operation is not guaranteed since it could affect the state of the generator.
|
||||||
*/
|
*/
|
||||||
fun fork(): RandomGenerator
|
fun fork(): RandomGenerator
|
||||||
|
|
||||||
companion object
|
companion object {
|
||||||
|
val default by lazy { DefaultGenerator(Random.nextLong()) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class DefaultGenerator(seed: Long?) : RandomGenerator {
|
||||||
|
private val random = seed?.let { Random(it) } ?: Random
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
|
override fun nextInt(): Int = random.nextInt()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
|
||||||
|
override fun nextBlock(size: Int): ByteArray = random.nextBytes(size)
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator = DefaultGenerator(nextLong())
|
||||||
|
|
||||||
}
|
}
|
@ -0,0 +1,44 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import kotlin.coroutines.CoroutineContext
|
||||||
|
import kotlin.coroutines.EmptyCoroutineContext
|
||||||
|
import kotlin.coroutines.coroutineContext
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A scope for a monte-carlo simulation or multi-coroutine random number generation
|
||||||
|
*/
|
||||||
|
class MCScope(override val coroutineContext: CoroutineContext, val random: RandomGenerator): CoroutineScope
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Launches a supervised Monte-Carlo scope
|
||||||
|
*/
|
||||||
|
suspend fun <T> mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T =
|
||||||
|
MCScope(coroutineContext, generator).block()
|
||||||
|
|
||||||
|
suspend fun <T> mc(seed: Long = -1, block: suspend MCScope.() -> T): T =
|
||||||
|
mc(SplitRandomWrapper(seed), block)
|
||||||
|
|
||||||
|
inline fun MCScope.launch(
|
||||||
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
|
crossinline block: suspend MCScope.() -> Unit
|
||||||
|
): Job {
|
||||||
|
|
||||||
|
val newRandom = synchronized(this){random.fork()}
|
||||||
|
return CoroutineScope(coroutineContext).launch(context, start) {
|
||||||
|
MCScope(coroutineContext, newRandom).block()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <T> MCScope.async(
|
||||||
|
context: CoroutineContext = EmptyCoroutineContext,
|
||||||
|
start: CoroutineStart = CoroutineStart.DEFAULT,
|
||||||
|
crossinline block: suspend MCScope.() -> T
|
||||||
|
): Deferred<T> {
|
||||||
|
val newRandom = synchronized(this){random.fork()}
|
||||||
|
return CoroutineScope(coroutineContext).async(context, start) {
|
||||||
|
MCScope(coroutineContext, newRandom).block()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,25 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import java.util.*
|
||||||
|
|
||||||
|
class SplitRandomWrapper(random: SplittableRandom) : RandomGenerator {
|
||||||
|
|
||||||
|
var random = random
|
||||||
|
private set
|
||||||
|
|
||||||
|
constructor(seed: Long) : this(SplittableRandom(seed))
|
||||||
|
|
||||||
|
override fun nextDouble(): Double = random.nextDouble()
|
||||||
|
|
||||||
|
override fun nextInt(): Int = random.nextInt()
|
||||||
|
|
||||||
|
override fun nextLong(): Long = random.nextLong()
|
||||||
|
|
||||||
|
override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) }
|
||||||
|
|
||||||
|
override fun fork(): RandomGenerator {
|
||||||
|
synchronized(this) {
|
||||||
|
return SplitRandomWrapper(random.split()).also { this.random = random.split() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,85 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import kotlinx.coroutines.*
|
||||||
|
import java.util.*
|
||||||
|
import kotlin.collections.HashSet
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
data class RandomResult(val branch: String, val order: Int, val value: Int)
|
||||||
|
|
||||||
|
typealias ATest = suspend CoroutineScope.() -> Set<RandomResult>
|
||||||
|
|
||||||
|
class MCScopeTest {
|
||||||
|
val simpleTest: ATest = {
|
||||||
|
mc(1111) {
|
||||||
|
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||||
|
|
||||||
|
launch {
|
||||||
|
//println(random)
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("first", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
launch {
|
||||||
|
"empty fork"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
launch {
|
||||||
|
//println(random)
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("second", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val testWithJoin: ATest = {
|
||||||
|
mc(1111) {
|
||||||
|
val res = Collections.synchronizedSet(HashSet<RandomResult>())
|
||||||
|
|
||||||
|
val job = launch {
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
res.add(RandomResult("first", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
launch {
|
||||||
|
repeat(10) {
|
||||||
|
delay(10)
|
||||||
|
if (it == 4) job.join()
|
||||||
|
res.add(RandomResult("second", it, random.nextInt()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun compareResult(test: ATest) {
|
||||||
|
val res1 = runBlocking(Dispatchers.Default) { test() }
|
||||||
|
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
||||||
|
assertEquals(
|
||||||
|
res1.find { it.branch == "first" && it.order == 7 }?.value,
|
||||||
|
res2.find { it.branch == "first" && it.order == 7 }?.value
|
||||||
|
)
|
||||||
|
assertEquals(res1, res2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testParallel() {
|
||||||
|
compareResult(simpleTest)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testConditionalJoin() {
|
||||||
|
compareResult(testWithJoin)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user