Added chain mappers

This commit is contained in:
Alexander Nozik 2019-06-09 10:06:48 +03:00
parent 1b647b6014
commit fe05fc5a14
3 changed files with 46 additions and 74 deletions

View File

@ -2,17 +2,17 @@ package scientifik.kmath.commons.prob
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import scientifik.kmath.chains.Chain import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.StatefulChain import scientifik.kmath.chains.mapWithState
import scientifik.kmath.prob.Distribution import scientifik.kmath.prob.Distribution
import scientifik.kmath.prob.RandomGenerator import scientifik.kmath.prob.RandomGenerator
data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
fun Chain<Double>.mean(): Chain<Double> = StatefulChain(AveragingChainState(), 0.0) { fun Chain<Double>.mean(): Chain<Double> = mapWithState(AveragingChainState(),{it.copy()}){chain->
val next = this@mean.next() val next = chain.next()
num++ num++
value += next value += next
return@StatefulChain value / num return@mapWithState value / num
} }
@ -23,7 +23,7 @@ fun main() {
runBlocking { runBlocking {
repeat(10001) { counter -> repeat(10001) { counter ->
val mean = chain.next() val mean = chain.next()
if(counter % 1000 ==0){ if (counter % 1000 == 0) {
println("[$counter] Average value is $mean") println("[$counter] Average value is $mean")
} }
} }

View File

@ -17,6 +17,7 @@
package scientifik.kmath.chains package scientifik.kmath.chains
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.updateAndGet
import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
@ -26,11 +27,6 @@ import kotlinx.coroutines.flow.Flow
* @param R - the chain element type * @param R - the chain element type
*/ */
interface Chain<out R> { interface Chain<out R> {
/**
* Last cached value of the chain. Returns null if [next] was not called
*/
fun peek(): R?
/** /**
* Generate next value, changing state if needed * Generate next value, changing state if needed
*/ */
@ -53,84 +49,59 @@ val <R> Chain<R>.flow: Flow<R>
fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() } fun <T> Iterator<T>.asChain(): Chain<T> = SimpleChain { next() }
fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain() fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
/**
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
* since mapped chain consumes tokens. Accepts regular transformation function
*/
fun <T, R> Chain<T>.map(func: (T) -> R): Chain<R> {
val parent = this;
return object : Chain<R> {
override fun peek(): R? = parent.peek()?.let(func)
override suspend fun next(): R {
return func(parent.next())
}
override fun fork(): Chain<R> {
return parent.fork().map(func)
}
}
}
/** /**
* A simple chain of independent tokens * A simple chain of independent tokens
*/ */
class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> { class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
private val atomicValue = atomic<R?>(null) override suspend fun next(): R = gen()
override fun peek(): R? = atomicValue.value
override suspend fun next(): R = gen().also { atomicValue.lazySet(it) }
override fun fork(): Chain<R> = this override fun fork(): Chain<R> = this
} }
//TODO force forks on mapping operations?
/** /**
* A stateless Markov chain * A stateless Markov chain
*/ */
class MarkovChain<out R : Any>(private val seed: () -> R, private val gen: suspend (R) -> R) : class MarkovChain<out R : Any>(private val seed: suspend () -> R, private val gen: suspend (R) -> R) : Chain<R> {
Chain<R> {
constructor(seed: R, gen: suspend (R) -> R) : this({ seed }, gen) constructor(seedValue: R, gen: suspend (R) -> R) : this({ seedValue }, gen)
private val atomicValue = atomic<R?>(null) private val value = atomic<R?>(null)
override fun peek(): R = atomicValue.value ?: seed()
override suspend fun next(): R { override suspend fun next(): R {
val newValue = gen(peek()) return value.updateAndGet { prev -> gen(prev ?: seed()) }!!
atomicValue.lazySet(newValue)
return peek()
} }
override fun fork(): Chain<R> { override fun fork(): Chain<R> {
return MarkovChain(peek(), gen) return MarkovChain(seed = { value.value ?: seed() }, gen = gen)
} }
} }
/** /**
* A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state * A chain with possibly mutable state. The state must not be changed outside the chain. Two chins should never share the state
* @param S - the state of the chain * @param S - the state of the chain
* @param forkState - the function to copy current state without modifying it
*/ */
class StatefulChain<S, out R>( class StatefulChain<S, out R>(
private val state: S, private val state: S,
private val seed: S.() -> R, private val seed: S.() -> R,
private val forkState: ((S) -> S),
private val gen: suspend S.(R) -> R private val gen: suspend S.(R) -> R
) : Chain<R> { ) : Chain<R> {
constructor(state: S, seed: R, gen: suspend S.(R) -> R) : this(state, { seed }, gen) constructor(state: S, seedValue: R, forkState: ((S) -> S), gen: suspend S.(R) -> R) : this(
state,
{ seedValue },
forkState,
gen
)
private val atomicValue = atomic<R?>(null) private val atomicValue = atomic<R?>(null)
override fun peek(): R = atomicValue.value ?: seed(state)
override suspend fun next(): R { override suspend fun next(): R {
val newValue = gen(state, peek()) return atomicValue.updateAndGet { prev -> state.gen(prev ?: state.seed()) }!!
atomicValue.lazySet(newValue)
return peek()
} }
override fun fork(): Chain<R> { override fun fork(): Chain<R> {
throw RuntimeException("Fork not supported for stateful chain") return StatefulChain(forkState(state), seed, forkState, gen)
} }
} }
@ -138,11 +109,32 @@ class StatefulChain<S, out R>(
* A chain that repeats the same value * A chain that repeats the same value
*/ */
class ConstantChain<out T>(val value: T) : Chain<T> { class ConstantChain<out T>(val value: T) : Chain<T> {
override fun peek(): T? = value
override suspend fun next(): T = value override suspend fun next(): T = value
override fun fork(): Chain<T> { override fun fork(): Chain<T> {
return this return this
} }
} }
/**
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
* since mapped chain consumes tokens. Accepts regular transformation function
*/
fun <T, R> Chain<T>.pipe(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
override suspend fun next(): R = func(this@pipe.next())
override fun fork(): Chain<R> = this@pipe.fork().pipe(func)
}
/**
* Map the whole chain
*/
fun <T, R> Chain<T>.map(mapper: suspend (Chain<T>) -> R): Chain<R> = object : Chain<R> {
override suspend fun next(): R = mapper(this@map)
override fun fork(): Chain<R> = this@map.fork().map(mapper)
}
fun <T, S, R> Chain<T>.mapWithState(state: S, stateFork: (S) -> S, mapper: suspend S.(Chain<T>) -> R): Chain<R> =
object : Chain<R> {
override suspend fun next(): R = state.mapper(this@mapWithState)
override fun fork(): Chain<R> = this@mapWithState.fork().mapWithState(stateFork(state), stateFork, mapper)
}

View File

@ -17,24 +17,4 @@ operator fun <R> Chain<R>.iterator() = object : Iterator<R> {
*/ */
fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> { fun <R> Chain<R>.asSequence(): Sequence<R> = object : Sequence<R> {
override fun iterator(): Iterator<R> = this@asSequence.iterator() override fun iterator(): Iterator<R> = this@asSequence.iterator()
}
/**
* Map the chain result using suspended transformation. Initial chain result can no longer be safely consumed
* since mapped chain consumes tokens. Accepts suspending transformation function.
*/
fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> {
val parent = this;
return object : Chain<R> {
override fun peek(): R? = runBlocking { parent.peek()?.let { func(it) } }
override suspend fun next(): R {
return func(parent.next())
}
override fun fork(): Chain<R> {
return parent.fork().map(func)
}
}
} }