forked from kscience/kmath
Added chain mappers
This commit is contained in:
parent
1b647b6014
commit
fe05fc5a14
@ -2,17 +2,17 @@ package scientifik.kmath.commons.prob
|
||||
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.StatefulChain
|
||||
import scientifik.kmath.chains.mapWithState
|
||||
import scientifik.kmath.prob.Distribution
|
||||
import scientifik.kmath.prob.RandomGenerator
|
||||
|
||||
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||
|
||||
fun Chain<Double>.mean(): Chain<Double> = StatefulChain(AveragingChainState(), 0.0) {
|
||||
val next = this@mean.next()
|
||||
fun Chain<Double>.mean(): Chain<Double> = mapWithState(AveragingChainState(),{it.copy()}){chain->
|
||||
val next = chain.next()
|
||||
num++
|
||||
value += next
|
||||
return@StatefulChain value / num
|
||||
return@mapWithState value / num
|
||||
}
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ fun main() {
|
||||
runBlocking {
|
||||
repeat(10001) { counter ->
|
||||
val mean = chain.next()
|
||||
if(counter % 1000 ==0){
|
||||
if (counter % 1000 == 0) {
|
||||
println("[$counter] Average value is $mean")
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
package scientifik.kmath.chains
|
||||
|
||||
import kotlinx.atomicfu.atomic
|
||||
import kotlinx.atomicfu.updateAndGet
|
||||
import kotlinx.coroutines.FlowPreview
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
@ -26,11 +27,6 @@ import kotlinx.coroutines.flow.Flow
|
||||
* @param R - the chain element type
|
||||
*/
|
||||
interface Chain<out R> {
|
||||
/**
|
||||
* Last cached value of the chain. Returns null if [next] was not called
|
||||
*/
|
||||
fun peek(): R?
|
||||
|
||||
/**
|
||||
* Generate next value, changing state if needed
|
||||
*/
|
||||
@ -53,84 +49,59 @@ val <R> Chain<R>.flow: Flow<R>
|
||||
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
||||
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
||||
|
||||
/**
|
||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||
* since mapped chain consumes tokens. Accepts regular transformation function
|
||||
*/
|
||||
fun <T, R> Chain<T>.map(func: (T) -> R): Chain<R> {
|
||||
val parent = this;
|
||||
return object : Chain<R> {
|
||||
override fun peek(): R? = parent.peek()?.let(func)
|
||||
|
||||
override suspend fun next(): R {
|
||||
return func(parent.next())
|
||||
}
|
||||
|
||||
override fun fork(): Chain<R> {
|
||||
return parent.fork().map(func)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A simple chain of independent tokens
|
||||
*/
|
||||
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
||||
private val atomicValue = atomic<R?>(null)
|
||||
override fun peek(): R? = atomicValue.value
|
||||
|
||||
override suspend fun next(): R = gen().also { atomicValue.lazySet(it) }
|
||||
|
||||
override suspend fun next(): R = gen()
|
||||
override fun fork(): Chain<R> = this
|
||||
}
|
||||
|
||||
//TODO force forks on mapping operations?
|
||||
|
||||
/**
|
||||
* A stateless Markov chain
|
||||
*/
|
||||
class MarkovChain<out R : Any>(private val seed: () -> R, private val gen: suspend (R) -> R) :
|
||||
Chain<R> {
|
||||
class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||
|
||||
constructor(seed: R, gen: suspend (R) -> R) : this({ seed }, gen)
|
||||
constructor(seedValue: R, gen: suspend (R) -> R) : this({ seedValue }, gen)
|
||||
|
||||
private val atomicValue = atomic<R?>(null)
|
||||
override fun peek(): R = atomicValue.value ?: seed()
|
||||
private val value = atomic<R?>(null)
|
||||
|
||||
override suspend fun next(): R {
|
||||
val newValue = gen(peek())
|
||||
atomicValue.lazySet(newValue)
|
||||
return peek()
|
||||
return value.updateAndGet { prev -> gen(prev ?: seed()) }!!
|
||||
}
|
||||
|
||||
override fun fork(): Chain<R> {
|
||||
return MarkovChain(peek(), gen)
|
||||
return MarkovChain(seed = { value.value ?: seed() }, gen = gen)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state
|
||||
* @param S - the state of the chain
|
||||
* @param forkState - the function to copy current state without modifying it
|
||||
*/
|
||||
class StatefulChain<S, out R>(
|
||||
private val state: S,
|
||||
private val seed: S.() -> R,
|
||||
private val forkState: ((S) -> S),
|
||||
private val gen: suspend S.(R) -> R
|
||||
) : Chain<R> {
|
||||
|
||||
constructor(state: S, seed: R, gen: suspend S.(R) -> R) : this(state, { seed }, gen)
|
||||
constructor(state: S, seedValue: R, forkState: ((S) -> S), gen: suspend S.(R) -> R) : this(
|
||||
state,
|
||||
{ seedValue },
|
||||
forkState,
|
||||
gen
|
||||
)
|
||||
|
||||
private val atomicValue = atomic<R?>(null)
|
||||
override fun peek(): R = atomicValue.value ?: seed(state)
|
||||
|
||||
override suspend fun next(): R {
|
||||
val newValue = gen(state, peek())
|
||||
atomicValue.lazySet(newValue)
|
||||
return peek()
|
||||
return atomicValue.updateAndGet { prev -> state.gen(prev ?: state.seed()) }!!
|
||||
}
|
||||
|
||||
override fun fork(): Chain<R> {
|
||||
throw RuntimeException("Fork not supported for stateful chain")
|
||||
return StatefulChain(forkState(state), seed, forkState, gen)
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,11 +109,32 @@ class StatefulChain<S, out R>(
|
||||
* A chain that repeats the same value
|
||||
*/
|
||||
class ConstantChain<out T>(val value: T) : Chain<T> {
|
||||
override fun peek(): T? = value
|
||||
|
||||
override suspend fun next(): T = value
|
||||
|
||||
override fun fork(): Chain<T> {
|
||||
return this
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||
* since mapped chain consumes tokens. Accepts regular transformation function
|
||||
*/
|
||||
fun <T, R> Chain<T>.pipe(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
||||
override suspend fun next(): R = func(this@pipe.next())
|
||||
override fun fork(): Chain<R> = this@pipe.fork().pipe(func)
|
||||
}
|
||||
|
||||
/**
|
||||
* Map the whole chain
|
||||
*/
|
||||
fun <T, R> Chain<T>.map(mapper: suspend (Chain<T>) -> R): Chain<R> = object : Chain<R> {
|
||||
override suspend fun next(): R = mapper(this@map)
|
||||
override fun fork(): Chain<R> = this@map.fork().map(mapper)
|
||||
}
|
||||
|
||||
fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
||||
object : Chain<R> {
|
||||
override suspend fun next(): R = state.mapper(this@mapWithState)
|
||||
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
|
||||
}
|
@ -18,23 +18,3 @@ operator fun <R> Chain<R>.iterator() = object : Iterator<R> {
|
||||
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
|
||||
override fun iterator(): Iterator<R> = this@asSequence.iterator()
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||
* since mapped chain consumes tokens. Accepts suspending transformation function.
|
||||
*/
|
||||
fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> {
|
||||
val parent = this;
|
||||
return object : Chain<R> {
|
||||
override fun peek(): R? = runBlocking { parent.peek()?.let { func(it) } }
|
||||
|
||||
override suspend fun next(): R {
|
||||
return func(parent.next())
|
||||
}
|
||||
|
||||
override fun fork(): Chain<R> {
|
||||
return parent.fork().map(func)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user