From 6619db3f455da670a3cf363141b8b57c6dade365 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 9 Aug 2024 10:22:37 +0300 Subject: [PATCH] Reimplement random-forging chain --- benchmarks/build.gradle.kts | 4 +- build.gradle.kts | 4 +- examples/build.gradle.kts | 4 +- gradle/libs.versions.toml | 9 ++- kmath-multik/build.gradle.kts | 6 +- .../kmath/samplers/RandomForkingEvent.kt | 71 +++++++++++++++++++ 6 files changed, 82 insertions(+), 16 deletions(-) create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index 9e0d095d6..b0446946a 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -15,8 +15,6 @@ repositories { mavenCentral() } -val multikVersion: String by rootProject.extra - kotlin { jvm() @@ -45,7 +43,7 @@ kotlin { implementation(project(":kmath-for-real")) implementation(project(":kmath-tensors")) implementation(project(":kmath-multik")) - implementation("org.jetbrains.kotlinx:multik-default:$multikVersion") + implementation(libs.multik.default) implementation(spclibs.kotlinx.benchmark.runtime) } } diff --git a/build.gradle.kts b/build.gradle.kts index 57ff022f2..1b4b4b3ab 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -3,7 +3,7 @@ import space.kscience.gradle.useSPCTeam plugins { id("space.kscience.gradle.project") - id("org.jetbrains.kotlinx.kover") version "0.7.6" + alias(spclibs.plugins.kotlinx.kover) } val attributesVersion by extra("0.2.0") @@ -70,5 +70,3 @@ ksciencePublish { } apiValidation.nonPublicMarkers.add("space.kscience.kmath.UnstableKMathAPI") - -val multikVersion by extra("0.2.3") diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index b74d6edca..74ff891d5 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -10,8 +10,6 @@ repositories { maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers") } -val multikVersion: String by rootProject.extra - dependencies { implementation(project(":kmath-ast")) implementation(project(":kmath-kotlingrad")) @@ -33,7 +31,7 @@ dependencies { implementation(project(":kmath-jafama")) //multik implementation(project(":kmath-multik")) - implementation("org.jetbrains.kotlinx:multik-default:$multikVersion") + implementation(libs.multik.default) //datetime implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.0") diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 9796ed176..e1add1777 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,9 +1,12 @@ [versions] commons-rng = "1.6" - +multik = "0.2.3" [libraries] -commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"} -commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"} \ No newline at end of file +commons-rng-simple = { module = "org.apache.commons:commons-rng-simple", version.ref = "commons-rng" } +commons-rng-sampling = { module = "org.apache.commons:commons-rng-sampling", version.ref = "commons-rng" } + +multik-core = { module = "org.jetbrains.kotlinx:multik-core", version.ref = "multik" } +multik-default = { module = "org.jetbrains.kotlinx:multik-default", version.ref = "multik" } \ No newline at end of file diff --git a/kmath-multik/build.gradle.kts b/kmath-multik/build.gradle.kts index e3f27effe..941ac403d 100644 --- a/kmath-multik/build.gradle.kts +++ b/kmath-multik/build.gradle.kts @@ -4,8 +4,6 @@ plugins { description = "JetBrains Multik connector" -val multikVersion: String by rootProject.extra - kscience { jvm() js() @@ -16,12 +14,12 @@ kotlin { commonMain { dependencies { api(projects.kmathTensors) - api("org.jetbrains.kotlinx:multik-core:$multikVersion") + api(libs.multik.core) } } commonTest { dependencies { - api("org.jetbrains.kotlinx:multik-default:$multikVersion") + api(libs.multik.default) } } } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt new file mode 100644 index 000000000..b4d56cf73 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2018-2024 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +@file:OptIn(ExperimentalCoroutinesApi::class) + +package space.kscience.kmath.samplers + +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import space.kscience.kmath.chains.Chain +import space.kscience.kmath.random.RandomGenerator +import space.kscience.kmath.stat.Sampler +import kotlin.coroutines.coroutineContext + + +/** + * A sampler that creates a chain that could be split at each computation + */ +public class RandomForkingSampler( + private val scope: CoroutineScope, + private val initialValue: T, + private val makeStep: suspend RandomGenerator.(T) -> List +) : Sampler { + + override fun sample(generator: RandomGenerator): Chain = buildChain(scope, initialValue) { generator.makeStep(it) } + + public companion object { + private suspend fun Channel.receiveEvents( + initial: T, + makeStep: suspend (T) -> List + ) { + send(initial) + //inner dispatch queue + val innerChannel = Channel(50) + innerChannel.send(initial) + while (coroutineContext.isActive && !innerChannel.isEmpty) { + val current = innerChannel.receive() + //add event immediately, but it does not mean that the value is computed immediately as well + makeStep(current).forEach { + innerChannel.send(it) + send(it) + } + } + innerChannel.close() + close() + } + + + public fun buildChain( + scope: CoroutineScope, + initial: T, + makeStep: suspend (T) -> List + ): Chain { + val channel = Channel(Channel.RENDEZVOUS) + scope.launch { + channel.receiveEvents(initial, makeStep) + } + + return object : Chain { + override suspend fun next(): T? = channel.receiveCatching().getOrNull() + + override suspend fun fork(): Chain = buildChain(scope, channel.receive(), makeStep) + } + } + } +}