Added chain mappers
This commit is contained in:
parent
1b647b6014
commit
fe05fc5a14
@ -2,17 +2,17 @@ package scientifik.kmath.commons.prob
|
|||||||
|
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import scientifik.kmath.chains.Chain
|
import scientifik.kmath.chains.Chain
|
||||||
import scientifik.kmath.chains.StatefulChain
|
import scientifik.kmath.chains.mapWithState
|
||||||
import scientifik.kmath.prob.Distribution
|
import scientifik.kmath.prob.Distribution
|
||||||
import scientifik.kmath.prob.RandomGenerator
|
import scientifik.kmath.prob.RandomGenerator
|
||||||
|
|
||||||
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||||
|
|
||||||
fun Chain<Double>.mean(): Chain<Double> = StatefulChain(AveragingChainState(), 0.0) {
|
fun Chain<Double>.mean(): Chain<Double> = mapWithState(AveragingChainState(),{it.copy()}){chain->
|
||||||
val next = this@mean.next()
|
val next = chain.next()
|
||||||
num++
|
num++
|
||||||
value += next
|
value += next
|
||||||
return@StatefulChain value / num
|
return@mapWithState value / num
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
package scientifik.kmath.chains
|
package scientifik.kmath.chains
|
||||||
|
|
||||||
import kotlinx.atomicfu.atomic
|
import kotlinx.atomicfu.atomic
|
||||||
|
import kotlinx.atomicfu.updateAndGet
|
||||||
import kotlinx.coroutines.FlowPreview
|
import kotlinx.coroutines.FlowPreview
|
||||||
import kotlinx.coroutines.flow.Flow
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
|
||||||
@ -26,11 +27,6 @@ import kotlinx.coroutines.flow.Flow
|
|||||||
* @param R - the chain element type
|
* @param R - the chain element type
|
||||||
*/
|
*/
|
||||||
interface Chain<out R> {
|
interface Chain<out R> {
|
||||||
/**
|
|
||||||
* Last cached value of the chain. Returns null if [next] was not called
|
|
||||||
*/
|
|
||||||
fun peek(): R?
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate next value, changing state if needed
|
* Generate next value, changing state if needed
|
||||||
*/
|
*/
|
||||||
@ -53,84 +49,59 @@ val <R> Chain<R>.flow: Flow<R>
|
|||||||
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
|
||||||
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
|
||||||
|
|
||||||
/**
|
|
||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
|
||||||
* since mapped chain consumes tokens. Accepts regular transformation function
|
|
||||||
*/
|
|
||||||
fun <T, R> Chain<T>.map(func: (T) -> R): Chain<R> {
|
|
||||||
val parent = this;
|
|
||||||
return object : Chain<R> {
|
|
||||||
override fun peek(): R? = parent.peek()?.let(func)
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
|
||||||
return func(parent.next())
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
|
||||||
return parent.fork().map(func)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple chain of independent tokens
|
* A simple chain of independent tokens
|
||||||
*/
|
*/
|
||||||
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
|
||||||
private val atomicValue = atomic<R?>(null)
|
override suspend fun next(): R = gen()
|
||||||
override fun peek(): R? = atomicValue.value
|
|
||||||
|
|
||||||
override suspend fun next(): R = gen().also { atomicValue.lazySet(it) }
|
|
||||||
|
|
||||||
override fun fork(): Chain<R> = this
|
override fun fork(): Chain<R> = this
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO force forks on mapping operations?
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A stateless Markov chain
|
* A stateless Markov chain
|
||||||
*/
|
*/
|
||||||
class MarkovChain<out R : Any>(private val seed: () -> R, private val gen: suspend (R) -> R) :
|
class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
|
||||||
Chain<R> {
|
|
||||||
|
|
||||||
constructor(seed: R, gen: suspend (R) -> R) : this({ seed }, gen)
|
constructor(seedValue: R, gen: suspend (R) -> R) : this({ seedValue }, gen)
|
||||||
|
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val value = atomic<R?>(null)
|
||||||
override fun peek(): R = atomicValue.value ?: seed()
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
val newValue = gen(peek())
|
return value.updateAndGet { prev -> gen(prev ?: seed()) }!!
|
||||||
atomicValue.lazySet(newValue)
|
|
||||||
return peek()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> {
|
||||||
return MarkovChain(peek(), gen)
|
return MarkovChain(seed = { value.value ?: seed() }, gen = gen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state
|
* A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state
|
||||||
* @param S - the state of the chain
|
* @param S - the state of the chain
|
||||||
|
* @param forkState - the function to copy current state without modifying it
|
||||||
*/
|
*/
|
||||||
class StatefulChain<S, out R>(
|
class StatefulChain<S, out R>(
|
||||||
private val state: S,
|
private val state: S,
|
||||||
private val seed: S.() -> R,
|
private val seed: S.() -> R,
|
||||||
|
private val forkState: ((S) -> S),
|
||||||
private val gen: suspend S.(R) -> R
|
private val gen: suspend S.(R) -> R
|
||||||
) : Chain<R> {
|
) : Chain<R> {
|
||||||
|
|
||||||
constructor(state: S, seed: R, gen: suspend S.(R) -> R) : this(state, { seed }, gen)
|
constructor(state: S, seedValue: R, forkState: ((S) -> S), gen: suspend S.(R) -> R) : this(
|
||||||
|
state,
|
||||||
|
{ seedValue },
|
||||||
|
forkState,
|
||||||
|
gen
|
||||||
|
)
|
||||||
|
|
||||||
private val atomicValue = atomic<R?>(null)
|
private val atomicValue = atomic<R?>(null)
|
||||||
override fun peek(): R = atomicValue.value ?: seed(state)
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
override suspend fun next(): R {
|
||||||
val newValue = gen(state, peek())
|
return atomicValue.updateAndGet { prev -> state.gen(prev ?: state.seed()) }!!
|
||||||
atomicValue.lazySet(newValue)
|
|
||||||
return peek()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
override fun fork(): Chain<R> {
|
||||||
throw RuntimeException("Fork not supported for stateful chain")
|
return StatefulChain(forkState(state), seed, forkState, gen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,11 +109,32 @@ class StatefulChain<S, out R>(
|
|||||||
* A chain that repeats the same value
|
* A chain that repeats the same value
|
||||||
*/
|
*/
|
||||||
class ConstantChain<out T>(val value: T) : Chain<T> {
|
class ConstantChain<out T>(val value: T) : Chain<T> {
|
||||||
override fun peek(): T? = value
|
|
||||||
|
|
||||||
override suspend fun next(): T = value
|
override suspend fun next(): T = value
|
||||||
|
|
||||||
override fun fork(): Chain<T> {
|
override fun fork(): Chain<T> {
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
||||||
|
* since mapped chain consumes tokens. Accepts regular transformation function
|
||||||
|
*/
|
||||||
|
fun <T, R> Chain<T>.pipe(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
||||||
|
override suspend fun next(): R = func(this@pipe.next())
|
||||||
|
override fun fork(): Chain<R> = this@pipe.fork().pipe(func)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map the whole chain
|
||||||
|
*/
|
||||||
|
fun <T, R> Chain<T>.map(mapper: suspend (Chain<T>) -> R): Chain<R> = object : Chain<R> {
|
||||||
|
override suspend fun next(): R = mapper(this@map)
|
||||||
|
override fun fork(): Chain<R> = this@map.fork().map(mapper)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
|
||||||
|
object : Chain<R> {
|
||||||
|
override suspend fun next(): R = state.mapper(this@mapWithState)
|
||||||
|
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
|
||||||
|
}
|
@ -18,23 +18,3 @@ operator fun <R> Chain<R>.iterator() = object : Iterator<R> {
|
|||||||
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
|
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
|
||||||
override fun iterator(): Iterator<R> = this@asSequence.iterator()
|
override fun iterator(): Iterator<R> = this@asSequence.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
|
|
||||||
* since mapped chain consumes tokens. Accepts suspending transformation function.
|
|
||||||
*/
|
|
||||||
fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> {
|
|
||||||
val parent = this;
|
|
||||||
return object : Chain<R> {
|
|
||||||
override fun peek(): R? = runBlocking { parent.peek()?.let { func(it) } }
|
|
||||||
|
|
||||||
override suspend fun next(): R {
|
|
||||||
return func(parent.next())
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun fork(): Chain<R> {
|
|
||||||
return parent.fork().map(func)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user