diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/CoroutinesExtra.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/CoroutinesExtra.kt index a0e11390d..8a48893eb 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/CoroutinesExtra.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/CoroutinesExtra.kt @@ -1,13 +1,32 @@ package scientifik.kmath import kotlinx.coroutines.* +import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.produce -import kotlinx.coroutines.flow.* +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.map val Dispatchers.Math: CoroutineDispatcher get() = Dispatchers.Default +/** + * An imitator of [Deferred] which holds a suspended function block and dispatcher + */ +class LazyDeferred(val dispatcher: CoroutineDispatcher, val block: suspend CoroutineScope.() -> T) { + private var deferred: Deferred? = null + + fun CoroutineScope.start() { + if(deferred==null) { + deferred = async(dispatcher, block = block) + } + } + + suspend fun await(): T = deferred?.await() ?: error("Coroutine not started") +} + @FlowPreview -inline class AsyncFlow(val deferredFlow: Flow>) : Flow { +inline class AsyncFlow(val deferredFlow: Flow>) : Flow { override suspend fun collect(collector: FlowCollector) { deferredFlow.collect { collector.emit((it.await())) @@ -18,32 +37,32 @@ inline class AsyncFlow(val deferredFlow: Flow>) : Flow { @FlowPreview fun Flow.async( dispatcher: CoroutineDispatcher = Dispatchers.Default, - block: suspend (T) -> R + block: suspend CoroutineScope.(T) -> R ): AsyncFlow { val flow = map { - coroutineScope { - async(dispatcher, start = CoroutineStart.LAZY) { block(it) } - } + LazyDeferred(dispatcher) { block(it) } } return AsyncFlow(flow) } @FlowPreview fun AsyncFlow.map(action: (T) -> R) = deferredFlow.map { input -> - coroutineScope { - async(start = CoroutineStart.LAZY) { action(input.await()) } + //TODO add actual composition + LazyDeferred(input.dispatcher) { + input.run { start() } + action(input.await()) } } @ExperimentalCoroutinesApi @FlowPreview -suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector){ +suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector) { require(concurrency >= 0) { "Buffer size should be positive, but was $concurrency" } coroutineScope { //Starting up to N deferred coroutines ahead of time val channel = produce(capacity = concurrency) { deferredFlow.collect { value -> - value.start() + value.run { start() } send(value) } } @@ -66,8 +85,9 @@ suspend fun AsyncFlow.collect(concurrency: Int, collector: FlowCollector< @ExperimentalCoroutinesApi @FlowPreview -suspend fun AsyncFlow.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit{ +suspend fun AsyncFlow.collect(concurrency: Int, action: suspend (value: T) -> Unit): Unit { collect(concurrency, object : FlowCollector { override suspend fun emit(value: T) = action(value) }) } + diff --git a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt b/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt index e3a38ff43..0102d615f 100644 --- a/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt +++ b/kmath-coroutines/src/jvmTest/kotlin/scientifik/kmath/streaming/BufferFlowTest.kt @@ -11,11 +11,10 @@ import scientifik.kmath.collect @FlowPreview class BufferFlowTest { - - @Test + @Test(timeout = 2000) fun mapParallel() { runBlocking { - (1..20).asFlow().async(Dispatchers.IO) { + (1..20).asFlow().async(Dispatchers.Default) { println("Started $it") @Suppress("BlockingMethodInNonBlockingContext") Thread.sleep(200)