Reimplement random-forging chain
This commit is contained in:
parent
48d0ee8126
commit
6619db3f45
@ -15,8 +15,6 @@ repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
val multikVersion: String by rootProject.extra
|
||||
|
||||
kotlin {
|
||||
jvm()
|
||||
|
||||
@ -45,7 +43,7 @@ kotlin {
|
||||
implementation(project(":kmath-for-real"))
|
||||
implementation(project(":kmath-tensors"))
|
||||
implementation(project(":kmath-multik"))
|
||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||
implementation(libs.multik.default)
|
||||
implementation(spclibs.kotlinx.benchmark.runtime)
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,7 @@ import space.kscience.gradle.useSPCTeam
|
||||
|
||||
plugins {
|
||||
id("space.kscience.gradle.project")
|
||||
id("org.jetbrains.kotlinx.kover") version "0.7.6"
|
||||
alias(spclibs.plugins.kotlinx.kover)
|
||||
}
|
||||
|
||||
val attributesVersion by extra("0.2.0")
|
||||
@ -70,5 +70,3 @@ ksciencePublish {
|
||||
}
|
||||
|
||||
apiValidation.nonPublicMarkers.add("space.kscience.kmath.UnstableKMathAPI")
|
||||
|
||||
val multikVersion by extra("0.2.3")
|
||||
|
@ -10,8 +10,6 @@ repositories {
|
||||
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
|
||||
}
|
||||
|
||||
val multikVersion: String by rootProject.extra
|
||||
|
||||
dependencies {
|
||||
implementation(project(":kmath-ast"))
|
||||
implementation(project(":kmath-kotlingrad"))
|
||||
@ -33,7 +31,7 @@ dependencies {
|
||||
implementation(project(":kmath-jafama"))
|
||||
//multik
|
||||
implementation(project(":kmath-multik"))
|
||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||
implementation(libs.multik.default)
|
||||
|
||||
//datetime
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.0")
|
||||
|
@ -1,9 +1,12 @@
|
||||
[versions]
|
||||
|
||||
commons-rng = "1.6"
|
||||
|
||||
multik = "0.2.3"
|
||||
|
||||
[libraries]
|
||||
|
||||
commons-rng-simple = { module = "org.apache.commons:commons-rng-simple", version.ref = "commons-rng" }
|
||||
commons-rng-sampling = { module = "org.apache.commons:commons-rng-sampling", version.ref = "commons-rng" }
|
||||
|
||||
multik-core = { module = "org.jetbrains.kotlinx:multik-core", version.ref = "multik" }
|
||||
multik-default = { module = "org.jetbrains.kotlinx:multik-default", version.ref = "multik" }
|
@ -4,8 +4,6 @@ plugins {
|
||||
|
||||
description = "JetBrains Multik connector"
|
||||
|
||||
val multikVersion: String by rootProject.extra
|
||||
|
||||
kscience {
|
||||
jvm()
|
||||
js()
|
||||
@ -16,12 +14,12 @@ kotlin {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(projects.kmathTensors)
|
||||
api("org.jetbrains.kotlinx:multik-core:$multikVersion")
|
||||
api(libs.multik.core)
|
||||
}
|
||||
}
|
||||
commonTest {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||
api(libs.multik.default)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,71 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
@file:OptIn(ExperimentalCoroutinesApi::class)
|
||||
|
||||
package space.kscience.kmath.samplers
|
||||
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import space.kscience.kmath.chains.Chain
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.stat.Sampler
|
||||
import kotlin.coroutines.coroutineContext
|
||||
|
||||
|
||||
/**
|
||||
* A sampler that creates a chain that could be split at each computation
|
||||
*/
|
||||
public class RandomForkingSampler<T: Any>(
|
||||
private val scope: CoroutineScope,
|
||||
private val initialValue: T,
|
||||
private val makeStep: suspend RandomGenerator.(T) -> List<T>
|
||||
) : Sampler<T?> {
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<T?> = buildChain(scope, initialValue) { generator.makeStep(it) }
|
||||
|
||||
public companion object {
|
||||
private suspend fun <T> Channel<T>.receiveEvents(
|
||||
initial: T,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
) {
|
||||
send(initial)
|
||||
//inner dispatch queue
|
||||
val innerChannel = Channel<T>(50)
|
||||
innerChannel.send(initial)
|
||||
while (coroutineContext.isActive && !innerChannel.isEmpty) {
|
||||
val current = innerChannel.receive()
|
||||
//add event immediately, but it does not mean that the value is computed immediately as well
|
||||
makeStep(current).forEach {
|
||||
innerChannel.send(it)
|
||||
send(it)
|
||||
}
|
||||
}
|
||||
innerChannel.close()
|
||||
close()
|
||||
}
|
||||
|
||||
|
||||
public fun <T: Any> buildChain(
|
||||
scope: CoroutineScope,
|
||||
initial: T,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
): Chain<T?> {
|
||||
val channel = Channel<T>(Channel.RENDEZVOUS)
|
||||
scope.launch {
|
||||
channel.receiveEvents(initial, makeStep)
|
||||
}
|
||||
|
||||
return object : Chain<T?> {
|
||||
override suspend fun next(): T? = channel.receiveCatching().getOrNull()
|
||||
|
||||
override suspend fun fork(): Chain<T?> = buildChain(scope, channel.receive(), makeStep)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user