forked from kscience/kmath
Parallel flow
This commit is contained in:
parent
08e14b15c5
commit
d138ce3889
@ -1,13 +1,32 @@
|
|||||||
package scientifik.kmath
|
package scientifik.kmath
|
||||||
|
|
||||||
import kotlinx.coroutines.*
|
import kotlinx.coroutines.*
|
||||||
|
import kotlinx.coroutines.channels.ReceiveChannel
|
||||||
import kotlinx.coroutines.channels.produce
|
import kotlinx.coroutines.channels.produce
|
||||||
import kotlinx.coroutines.flow.*
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
import kotlinx.coroutines.flow.FlowCollector
|
||||||
|
import kotlinx.coroutines.flow.collect
|
||||||
|
import kotlinx.coroutines.flow.map
|
||||||
|
|
||||||
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
|
val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An imitator of [Deferred] which holds a suspended function block and dispatcher
|
||||||
|
*/
|
||||||
|
class LazyDeferred<T>(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) {
|
||||||
|
private var deferred: Deferred<T>? = null
|
||||||
|
|
||||||
|
fun CoroutineScope.start() {
|
||||||
|
if(deferred==null) {
|
||||||
|
deferred = async(dispatcher, block = block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
suspend fun await(): T = deferred?.await() ?: error("Coroutine not started")
|
||||||
|
}
|
||||||
|
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
inline class AsyncFlow<T>(val deferredFlow: Flow<Deferred<T>>) : Flow<T> {
|
inline class AsyncFlow<T>(val deferredFlow: Flow<LazyDeferred<T>>) : Flow<T> {
|
||||||
override suspend fun collect(collector: FlowCollector<T>) {
|
override suspend fun collect(collector: FlowCollector<T>) {
|
||||||
deferredFlow.collect {
|
deferredFlow.collect {
|
||||||
collector.emit((it.await()))
|
collector.emit((it.await()))
|
||||||
@ -18,32 +37,32 @@ inline class AsyncFlow<T>(val deferredFlow: Flow<Deferred<T>>) : Flow<T> {
|
|||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun <T, R> Flow<T>.async(
|
fun <T, R> Flow<T>.async(
|
||||||
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
dispatcher: CoroutineDispatcher = Dispatchers.Default,
|
||||||
block: suspend (T) -> R
|
block: suspend CoroutineScope.(T) -> R
|
||||||
): AsyncFlow<R> {
|
): AsyncFlow<R> {
|
||||||
val flow = map {
|
val flow = map {
|
||||||
coroutineScope {
|
LazyDeferred(dispatcher) { block(it) }
|
||||||
async(dispatcher, start = CoroutineStart.LAZY) { block(it) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return AsyncFlow(flow)
|
return AsyncFlow(flow)
|
||||||
}
|
}
|
||||||
|
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = deferredFlow.map { input ->
|
fun <T, R> AsyncFlow<T>.map(action: (T) -> R) = deferredFlow.map { input ->
|
||||||
coroutineScope {
|
//TODO add actual composition
|
||||||
async(start = CoroutineStart.LAZY) { action(input.await()) }
|
LazyDeferred(input.dispatcher) {
|
||||||
|
input.run { start() }
|
||||||
|
action(input.await())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>){
|
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<T>) {
|
||||||
require(concurrency >= 0) { "Buffer size should be positive, but was $concurrency" }
|
require(concurrency >= 0) { "Buffer size should be positive, but was $concurrency" }
|
||||||
coroutineScope {
|
coroutineScope {
|
||||||
//Starting up to N deferred coroutines ahead of time
|
//Starting up to N deferred coroutines ahead of time
|
||||||
val channel = produce(capacity = concurrency) {
|
val channel = produce(capacity = concurrency) {
|
||||||
deferredFlow.collect { value ->
|
deferredFlow.collect { value ->
|
||||||
value.start()
|
value.run { start() }
|
||||||
send(value)
|
send(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -66,8 +85,9 @@ suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, collector: FlowCollector<
|
|||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
@FlowPreview
|
@FlowPreview
|
||||||
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit{
|
suspend fun <T> AsyncFlow<T>.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit {
|
||||||
collect(concurrency, object : FlowCollector<T> {
|
collect(concurrency, object : FlowCollector<T> {
|
||||||
override suspend fun emit(value: T) = action(value)
|
override suspend fun emit(value: T) = action(value)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,11 +11,10 @@ import scientifik.kmath.collect
|
|||||||
@FlowPreview
|
@FlowPreview
|
||||||
class BufferFlowTest {
|
class BufferFlowTest {
|
||||||
|
|
||||||
|
@Test(timeout = 2000)
|
||||||
@Test
|
|
||||||
fun mapParallel() {
|
fun mapParallel() {
|
||||||
runBlocking {
|
runBlocking {
|
||||||
(1..20).asFlow().async(Dispatchers.IO) {
|
(1..20).asFlow().async(Dispatchers.Default) {
|
||||||
println("Started $it")
|
println("Started $it")
|
||||||
@Suppress("BlockingMethodInNonBlockingContext")
|
@Suppress("BlockingMethodInNonBlockingContext")
|
||||||
Thread.sleep(200)
|
Thread.sleep(200)
|
||||||
|
Loading…
Reference in New Issue
Block a user