diff --git a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt index 7e8c433c0..5ea91143c 100644 --- a/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/scientifik/kmath/structures/NDField.kt @@ -8,7 +8,7 @@ fun main(args: Array) { val dim = 1000 val n = 1000 - // automatically build context most suited for given type. + // automatically build coroutineContext most suited for given type. val autoField = NDField.auto(RealField, dim, dim) // specialized nd-field for Double. It works as generic Double field as well val specializedField = NDField.real(dim, dim) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index d5c038dc4..226368f76 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -111,7 +111,7 @@ fun DiffExpression.derivative(vararg orders: Pair) = derivative(map fun DiffExpression.derivative(name: String) = derivative(name to 1) /** - * A context for [DiffExpression] (not to be confused with [DerivativeStructure]) + * A coroutineContext for [DiffExpression] (not to be confused with [DerivativeStructure]) */ object DiffExpressionContext : ExpressionContext, Field { override fun variable(name: String, default: Double?) = diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt new file mode 100644 index 000000000..1cbecb206 --- /dev/null +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/MCScope.kt @@ -0,0 +1,58 @@ +package scientifik.kmath.prob + +import kotlinx.coroutines.* +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.coroutines.coroutineContext + +/** + * A scope for a monte-carlo simulation or multi-coroutine random number generation + */ +class MCScope private constructor(val coroutineContext: CoroutineContext) { + fun mcScope(context: CoroutineContext): MCScope = MCScope(context + coroutineContext[Random]!!.split()) + + companion object { + suspend fun init(generator: RandomGenerator): MCScope = + MCScope(coroutineContext + Random(generator)) + } +} + +inline class Random(val generator: RandomGenerator) : CoroutineContext.Element { + override val key: CoroutineContext.Key<*> get() = Companion + + fun split(): Random = Random(generator.fork()) + + companion object : CoroutineContext.Key +} + +val CoroutineContext.random get() = this[Random]?.generator + +val MCScope.random get() = coroutineContext.random!! + +/** + * Launches a supervised Monte-Carlo scope + */ +suspend fun mc(generator: RandomGenerator, block: suspend MCScope.() -> T): T = + supervisorScope { + MCScope.init(generator).block() + } + +suspend fun mc(seed: Long = -1, block: suspend MCScope.() -> T): T = + mc(SplitRandomWrapper(seed), block) + +inline fun MCScope.launch( + context: CoroutineContext = EmptyCoroutineContext, + start: CoroutineStart = CoroutineStart.DEFAULT, + crossinline block: suspend MCScope.() -> Unit +): Job = CoroutineScope(coroutineContext).launch(context + coroutineContext[Random]!!.split(), start) { + mcScope(coroutineContext).block() +} + +inline fun MCScope.async( + context: CoroutineContext = EmptyCoroutineContext, + start: CoroutineStart = CoroutineStart.DEFAULT, + crossinline block: suspend MCScope.() -> T +): Deferred = CoroutineScope(coroutineContext).async(context, start) { + mcScope(coroutineContext).block() +} + diff --git a/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/SplitRandomWrapper.kt b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/SplitRandomWrapper.kt new file mode 100644 index 000000000..ca488a262 --- /dev/null +++ b/kmath-prob/src/jvmMain/kotlin/scientifik/kmath/prob/SplitRandomWrapper.kt @@ -0,0 +1,18 @@ +package scientifik.kmath.prob + +import java.util.* + +class SplitRandomWrapper(val random: SplittableRandom) : RandomGenerator { + + constructor(seed: Long) : this(SplittableRandom(seed)) + + override fun nextDouble(): Double = random.nextDouble() + + override fun nextInt(): Int = random.nextInt() + + override fun nextLong(): Long = random.nextLong() + + override fun nextBlock(size: Int): ByteArray = ByteArray(size).also { random.nextBytes(it) } + + override fun fork(): RandomGenerator = SplitRandomWrapper(random.split()) +} \ No newline at end of file diff --git a/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt new file mode 100644 index 000000000..a740661b3 --- /dev/null +++ b/kmath-prob/src/jvmTest/kotlin/scientifik/kmath/prob/MCScopeTest.kt @@ -0,0 +1,28 @@ +package scientifik.kmath.prob + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import kotlin.test.Test +import kotlin.test.assertEquals + +class MCScopeTest { + @Test + fun testGenerator(){ + runBlocking(Dispatchers.Default) { + mc(1122){ + launch { + repeat(9){ + println("first:${random.nextInt()}") + } + assertEquals(690819834,random.nextInt()) + } + launch { + repeat(9){ + println("second:${random.nextInt()}") + } + assertEquals(691997530,random.nextInt()) + } + } + } + } +} \ No newline at end of file