Compare commits

..

2 Commits

Author SHA1 Message Date
6619db3f45 Reimplement random-forging chain 2024-08-09 10:22:37 +03:00
48d0ee8126 Add Metropolis-Hastings sampler.
Minor fixes.
2024-08-04 21:26:51 +03:00
7 changed files with 85 additions and 30 deletions

View File

@ -15,8 +15,6 @@ repositories {
mavenCentral()
}
val multikVersion: String by rootProject.extra
kotlin {
jvm()
@ -45,7 +43,7 @@ kotlin {
implementation(project(":kmath-for-real"))
implementation(project(":kmath-tensors"))
implementation(project(":kmath-multik"))
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
implementation(libs.multik.default)
implementation(spclibs.kotlinx.benchmark.runtime)
}
}

View File

@ -3,7 +3,7 @@ import space.kscience.gradle.useSPCTeam
plugins {
id("space.kscience.gradle.project")
id("org.jetbrains.kotlinx.kover") version "0.7.6"
alias(spclibs.plugins.kotlinx.kover)
}
val attributesVersion by extra("0.2.0")
@ -70,5 +70,3 @@ ksciencePublish {
}
apiValidation.nonPublicMarkers.add("space.kscience.kmath.UnstableKMathAPI")
val multikVersion by extra("0.2.3")

View File

@ -10,8 +10,6 @@ repositories {
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
}
val multikVersion: String by rootProject.extra
dependencies {
implementation(project(":kmath-ast"))
implementation(project(":kmath-kotlingrad"))
@ -33,7 +31,7 @@ dependencies {
implementation(project(":kmath-jafama"))
//multik
implementation(project(":kmath-multik"))
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
implementation(libs.multik.default)
//datetime
implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.0")

View File

@ -1,9 +1,12 @@
[versions]
commons-rng = "1.6"
multik = "0.2.3"
[libraries]
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}
commons-rng-simple = { module = "org.apache.commons:commons-rng-simple", version.ref = "commons-rng" }
commons-rng-sampling = { module = "org.apache.commons:commons-rng-sampling", version.ref = "commons-rng" }
multik-core = { module = "org.jetbrains.kotlinx:multik-core", version.ref = "multik" }
multik-default = { module = "org.jetbrains.kotlinx:multik-default", version.ref = "multik" }

View File

@ -4,8 +4,6 @@ plugins {
description = "JetBrains Multik connector"
val multikVersion: String by rootProject.extra
kscience {
jvm()
js()
@ -16,12 +14,12 @@ kotlin {
commonMain {
dependencies {
api(projects.kmathTensors)
api("org.jetbrains.kotlinx:multik-core:$multikVersion")
api(libs.multik.core)
}
}
commonTest {
dependencies {
api("org.jetbrains.kotlinx:multik-default:$multikVersion")
api(libs.multik.default)
}
}
}

View File

@ -0,0 +1,71 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
@file:OptIn(ExperimentalCoroutinesApi::class)
package space.kscience.kmath.samplers
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.stat.Sampler
import kotlin.coroutines.coroutineContext
/**
* A sampler that creates a chain that could be split at each computation
*/
public class RandomForkingSampler<T: Any>(
private val scope: CoroutineScope,
private val initialValue: T,
private val makeStep: suspend RandomGenerator.(T) -> List<T>
) : Sampler<T?> {
override fun sample(generator: RandomGenerator): Chain<T?> = buildChain(scope, initialValue) { generator.makeStep(it) }
public companion object {
private suspend fun <T> Channel<T>.receiveEvents(
initial: T,
makeStep: suspend (T) -> List<T>
) {
send(initial)
//inner dispatch queue
val innerChannel = Channel<T>(50)
innerChannel.send(initial)
while (coroutineContext.isActive && !innerChannel.isEmpty) {
val current = innerChannel.receive()
//add event immediately, but it does not mean that the value is computed immediately as well
makeStep(current).forEach {
innerChannel.send(it)
send(it)
}
}
innerChannel.close()
close()
}
public fun <T: Any> buildChain(
scope: CoroutineScope,
initial: T,
makeStep: suspend (T) -> List<T>
): Chain<T?> {
val channel = Channel<T>(Channel.RENDEZVOUS)
scope.launch {
channel.receiveEvents(initial, makeStep)
}
return object : Chain<T?> {
override suspend fun next(): T? = channel.receiveCatching().getOrNull()
override suspend fun fork(): Chain<T?> = buildChain(scope, channel.receive(), makeStep)
}
}
}
}

View File

@ -24,8 +24,8 @@ class TestMetropolisHastingsSampler {
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
val sample = 1e6.toInt()
val burnIn = sample / 5
private val sample = 1e6.toInt()
private val burnIn = sample / 5
@Test
fun samplingNormalTest() = runTest {
@ -66,6 +66,7 @@ class TestMetropolisHastingsSampler {
@Test
fun samplingRayleighTest() = runTest {
val generator = RandomGenerator.default(1)
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
0.0
} else {
@ -83,17 +84,5 @@ class TestMetropolisHastingsSampler {
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
}
//
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
//
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
//
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
//
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
}
}