diff --git a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt index 831a167bf..401dde780 100644 --- a/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt +++ b/examples/src/main/kotlin/scientifik/kmath/commons/prob/DistributionDemo.kt @@ -2,17 +2,17 @@ package scientifik.kmath.commons.prob import kotlinx.coroutines.runBlocking import scientifik.kmath.chains.Chain -import scientifik.kmath.chains.StatefulChain +import scientifik.kmath.chains.mapWithState import scientifik.kmath.prob.Distribution import scientifik.kmath.prob.RandomGenerator data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) -fun Chain.mean(): Chain = StatefulChain(AveragingChainState(), 0.0) { - val next = this@mean.next() +fun Chain.mean(): Chain = mapWithState(AveragingChainState(),{it.copy()}){chain-> + val next = chain.next() num++ value += next - return@StatefulChain value / num + return@mapWithState value / num } @@ -23,7 +23,7 @@ fun main() { runBlocking { repeat(10001) { counter -> val mean = chain.next() - if(counter % 1000 ==0){ + if (counter % 1000 == 0) { println("[$counter] Average value is $mean") } } diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 4adc2a9a7..58841e7b8 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -17,6 +17,7 @@ package scientifik.kmath.chains import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.updateAndGet import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.flow.Flow @@ -26,11 +27,6 @@ import kotlinx.coroutines.flow.Flow * @param R - the chain element type */ interface Chain { - /** - * Last cached value of the chain. Returns null if [next] was not called - */ - fun peek(): R? - /** * Generate next value, changing state if needed */ @@ -53,84 +49,59 @@ val Chain.flow: Flow fun Iterator.asChain(): Chain = SimpleChain { next() } fun Sequence.asChain(): Chain = iterator().asChain() -/** - * Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed - * since mapped chain consumes tokens. Accepts regular transformation function - */ -fun Chain.map(func: (T) -> R): Chain { - val parent = this; - return object : Chain { - override fun peek(): R? = parent.peek()?.let(func) - - override suspend fun next(): R { - return func(parent.next()) - } - - override fun fork(): Chain { - return parent.fork().map(func) - } - } -} - /** * A simple chain of independent tokens */ class SimpleChain(private val gen: suspend () -> R) : Chain { - private val atomicValue = atomic(null) - override fun peek(): R? = atomicValue.value - - override suspend fun next(): R = gen().also { atomicValue.lazySet(it) } - + override suspend fun next(): R = gen() override fun fork(): Chain = this } -//TODO force forks on mapping operations? - /** * A stateless Markov chain */ -class MarkovChain(private val seed: () -> R, private val gen: suspend (R) -> R) : - Chain { +class MarkovChain(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain { - constructor(seed: R, gen: suspend (R) -> R) : this({ seed }, gen) + constructor(seedValue: R, gen: suspend (R) -> R) : this({ seedValue }, gen) - private val atomicValue = atomic(null) - override fun peek(): R = atomicValue.value ?: seed() + private val value = atomic(null) override suspend fun next(): R { - val newValue = gen(peek()) - atomicValue.lazySet(newValue) - return peek() + return value.updateAndGet { prev -> gen(prev ?: seed()) }!! } override fun fork(): Chain { - return MarkovChain(peek(), gen) + return MarkovChain(seed = { value.value ?: seed() }, gen = gen) } } /** * A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state * @param S - the state of the chain + * @param forkState - the function to copy current state without modifying it */ class StatefulChain( private val state: S, private val seed: S.() -> R, + private val forkState: ((S) -> S), private val gen: suspend S.(R) -> R ) : Chain { - constructor(state: S, seed: R, gen: suspend S.(R) -> R) : this(state, { seed }, gen) + constructor(state: S, seedValue: R, forkState: ((S) -> S), gen: suspend S.(R) -> R) : this( + state, + { seedValue }, + forkState, + gen + ) private val atomicValue = atomic(null) - override fun peek(): R = atomicValue.value ?: seed(state) override suspend fun next(): R { - val newValue = gen(state, peek()) - atomicValue.lazySet(newValue) - return peek() + return atomicValue.updateAndGet { prev -> state.gen(prev ?: state.seed()) }!! } override fun fork(): Chain { - throw RuntimeException("Fork not supported for stateful chain") + return StatefulChain(forkState(state), seed, forkState, gen) } } @@ -138,11 +109,32 @@ class StatefulChain( * A chain that repeats the same value */ class ConstantChain(val value: T) : Chain { - override fun peek(): T? = value - override suspend fun next(): T = value override fun fork(): Chain { return this } -} \ No newline at end of file +} + +/** + * Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed + * since mapped chain consumes tokens. Accepts regular transformation function + */ +fun Chain.pipe(func: suspend (T) -> R): Chain = object : Chain { + override suspend fun next(): R = func(this@pipe.next()) + override fun fork(): Chain = this@pipe.fork().pipe(func) +} + +/** + * Map the whole chain + */ +fun Chain.map(mapper: suspend (Chain) -> R): Chain = object : Chain { + override suspend fun next(): R = mapper(this@map) + override fun fork(): Chain = this@map.fork().map(mapper) +} + +fun Chain.mapWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain) -> R): Chain = + object : Chain { + override suspend fun next(): R = state.mapper(this@mapWithState) + override fun fork(): Chain = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper) + } \ No newline at end of file diff --git a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt index b634c5ae2..013ea2922 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/scientifik/kmath/chains/ChainExt.kt @@ -17,24 +17,4 @@ operator fun Chain.iterator() = object : Iterator { */ fun Chain.asSequence(): Sequence = object : Sequence { override fun iterator(): Iterator = this@asSequence.iterator() -} - - -/** - * Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed - * since mapped chain consumes tokens. Accepts suspending transformation function. - */ -fun Chain.map(func: suspend (T) -> R): Chain { - val parent = this; - return object : Chain { - override fun peek(): R? = runBlocking { parent.peek()?.let { func(it) } } - - override suspend fun next(): R { - return func(parent.next()) - } - - override fun fork(): Chain { - return parent.fork().map(func) - } - } } \ No newline at end of file