Parallel flow

This commit is contained in:
Alexander Nozik 2019-04-28 13:56:19 +03:00
parent 08e14b15c5
commit d138ce3889
2 changed files with 33 additions and 14 deletions

View File

@ -1,13 +1,32 @@
package scientifik.kmath package scientifik.kmath
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.produce import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.flow.* import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.map
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
/**
* An imitator of [Deferred] which holds a suspended function block and dispatcher
*/
class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) {
private var deferred: Deferred<T>? = null
fun CoroutineScope.start() {
if(deferred==null) {
deferred = async(dispatcher, block = block)
}
}
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
}
@FlowPreview @FlowPreview
inline class AsyncFlow<T>(val deferredFlow: Flow<Deferred<T>>) : Flow<T> { inline class AsyncFlow<T>(val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
override suspend fun collect(collector: FlowCollector<T>) { override suspend fun collect(collector: FlowCollector<T>) {
deferredFlow.collect { deferredFlow.collect {
collector.emit((it.await())) collector.emit((it.await()))
@ -18,32 +37,32 @@ inline class AsyncFlow<T>(val deferredFlow: Flow<Deferred<T>>) : Flow<T> {
@FlowPreview @FlowPreview
fun <T, R> Flow<T>.async( fun <T, R> Flow<T>.async(
dispatcher: CoroutineDispatcher = Dispatchers.Default, dispatcher: CoroutineDispatcher = Dispatchers.Default,
block: suspend (T) -> R block: suspend CoroutineScope.(T) -> R
): AsyncFlow<R> { ): AsyncFlow<R> {
val flow = map { val flow = map {
coroutineScope { LazyDeferred(dispatcher) { block(it) }
async(dispatcher, start = CoroutineStart.LAZY) { block(it) }
}
} }
return AsyncFlow(flow) return AsyncFlow(flow)
} }
@FlowPreview @FlowPreview
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = deferredFlow.map { input -> fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = deferredFlow.map { input ->
coroutineScope { //TODO add actual composition
async(start = CoroutineStart.LAZY) { action(input.await()) } LazyDeferred(input.dispatcher) {
input.run { start() }
action(input.await())
} }
} }
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
@FlowPreview @FlowPreview
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>){ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
require(concurrency >= 0) { "Buffer size should be positive, but was $concurrency" } require(concurrency >= 0) { "Buffer size should be positive, but was $concurrency" }
coroutineScope { coroutineScope {
//Starting up to N deferred coroutines ahead of time //Starting up to N deferred coroutines ahead of time
val channel = produce(capacity = concurrency) { val channel = produce(capacity = concurrency) {
deferredFlow.collect { value -> deferredFlow.collect { value ->
value.start() value.run { start() }
send(value) send(value)
} }
} }
@ -66,8 +85,9 @@ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
@FlowPreview @FlowPreview
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit{ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit {
collect(concurrency, object : FlowCollector<T> { collect(concurrency, object : FlowCollector<T> {
override suspend fun emit(value: T) = action(value) override suspend fun emit(value: T) = action(value)
}) })
} }

View File

@ -11,11 +11,10 @@ import scientifik.kmath.collect
@FlowPreview @FlowPreview
class BufferFlowTest { class BufferFlowTest {
@Test(timeout = 2000)
@Test
fun mapParallel() { fun mapParallel() {
runBlocking { runBlocking {
(1..20).asFlow().async(Dispatchers.IO) { (1..20).asFlow().async(Dispatchers.Default) {
println("Started $it") println("Started $it")
@Suppress("BlockingMethodInNonBlockingContext") @Suppress("BlockingMethodInNonBlockingContext")
Thread.sleep(200) Thread.sleep(200)