Compare commits
2 Commits
57d1cd8c87
...
6619db3f45
Author | SHA1 | Date | |
---|---|---|---|
6619db3f45 | |||
48d0ee8126 |
@ -15,8 +15,6 @@ repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
val multikVersion: String by rootProject.extra
|
||||
|
||||
kotlin {
|
||||
jvm()
|
||||
|
||||
@ -45,7 +43,7 @@ kotlin {
|
||||
implementation(project(":kmath-for-real"))
|
||||
implementation(project(":kmath-tensors"))
|
||||
implementation(project(":kmath-multik"))
|
||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||
implementation(libs.multik.default)
|
||||
implementation(spclibs.kotlinx.benchmark.runtime)
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,7 @@ import space.kscience.gradle.useSPCTeam
|
||||
|
||||
plugins {
|
||||
id("space.kscience.gradle.project")
|
||||
id("org.jetbrains.kotlinx.kover") version "0.7.6"
|
||||
alias(spclibs.plugins.kotlinx.kover)
|
||||
}
|
||||
|
||||
val attributesVersion by extra("0.2.0")
|
||||
@ -70,5 +70,3 @@ ksciencePublish {
|
||||
}
|
||||
|
||||
apiValidation.nonPublicMarkers.add("space.kscience.kmath.UnstableKMathAPI")
|
||||
|
||||
val multikVersion by extra("0.2.3")
|
||||
|
@ -10,8 +10,6 @@ repositories {
|
||||
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
|
||||
}
|
||||
|
||||
val multikVersion: String by rootProject.extra
|
||||
|
||||
dependencies {
|
||||
implementation(project(":kmath-ast"))
|
||||
implementation(project(":kmath-kotlingrad"))
|
||||
@ -33,7 +31,7 @@ dependencies {
|
||||
implementation(project(":kmath-jafama"))
|
||||
//multik
|
||||
implementation(project(":kmath-multik"))
|
||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||
implementation(libs.multik.default)
|
||||
|
||||
//datetime
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.0")
|
||||
|
@ -1,9 +1,12 @@
|
||||
[versions]
|
||||
|
||||
commons-rng = "1.6"
|
||||
|
||||
multik = "0.2.3"
|
||||
|
||||
[libraries]
|
||||
|
||||
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
|
||||
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}
|
||||
commons-rng-simple = { module = "org.apache.commons:commons-rng-simple", version.ref = "commons-rng" }
|
||||
commons-rng-sampling = { module = "org.apache.commons:commons-rng-sampling", version.ref = "commons-rng" }
|
||||
|
||||
multik-core = { module = "org.jetbrains.kotlinx:multik-core", version.ref = "multik" }
|
||||
multik-default = { module = "org.jetbrains.kotlinx:multik-default", version.ref = "multik" }
|
@ -4,8 +4,6 @@ plugins {
|
||||
|
||||
description = "JetBrains Multik connector"
|
||||
|
||||
val multikVersion: String by rootProject.extra
|
||||
|
||||
kscience {
|
||||
jvm()
|
||||
js()
|
||||
@ -16,12 +14,12 @@ kotlin {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(projects.kmathTensors)
|
||||
api("org.jetbrains.kotlinx:multik-core:$multikVersion")
|
||||
api(libs.multik.core)
|
||||
}
|
||||
}
|
||||
commonTest {
|
||||
dependencies {
|
||||
api("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||
api(libs.multik.default)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,71 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
@file:OptIn(ExperimentalCoroutinesApi::class)
|
||||
|
||||
package space.kscience.kmath.samplers
|
||||
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.channels.Channel
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import space.kscience.kmath.chains.Chain
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.stat.Sampler
|
||||
import kotlin.coroutines.coroutineContext
|
||||
|
||||
|
||||
/**
|
||||
* A sampler that creates a chain that could be split at each computation
|
||||
*/
|
||||
public class RandomForkingSampler<T: Any>(
|
||||
private val scope: CoroutineScope,
|
||||
private val initialValue: T,
|
||||
private val makeStep: suspend RandomGenerator.(T) -> List<T>
|
||||
) : Sampler<T?> {
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<T?> = buildChain(scope, initialValue) { generator.makeStep(it) }
|
||||
|
||||
public companion object {
|
||||
private suspend fun <T> Channel<T>.receiveEvents(
|
||||
initial: T,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
) {
|
||||
send(initial)
|
||||
//inner dispatch queue
|
||||
val innerChannel = Channel<T>(50)
|
||||
innerChannel.send(initial)
|
||||
while (coroutineContext.isActive && !innerChannel.isEmpty) {
|
||||
val current = innerChannel.receive()
|
||||
//add event immediately, but it does not mean that the value is computed immediately as well
|
||||
makeStep(current).forEach {
|
||||
innerChannel.send(it)
|
||||
send(it)
|
||||
}
|
||||
}
|
||||
innerChannel.close()
|
||||
close()
|
||||
}
|
||||
|
||||
|
||||
public fun <T: Any> buildChain(
|
||||
scope: CoroutineScope,
|
||||
initial: T,
|
||||
makeStep: suspend (T) -> List<T>
|
||||
): Chain<T?> {
|
||||
val channel = Channel<T>(Channel.RENDEZVOUS)
|
||||
scope.launch {
|
||||
channel.receiveEvents(initial, makeStep)
|
||||
}
|
||||
|
||||
return object : Chain<T?> {
|
||||
override suspend fun next(): T? = channel.receiveCatching().getOrNull()
|
||||
|
||||
override suspend fun fork(): Chain<T?> = buildChain(scope, channel.receive(), makeStep)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -24,8 +24,8 @@ class TestMetropolisHastingsSampler {
|
||||
|
||||
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
||||
|
||||
val sample = 1e6.toInt()
|
||||
val burnIn = sample / 5
|
||||
private val sample = 1e6.toInt()
|
||||
private val burnIn = sample / 5
|
||||
|
||||
@Test
|
||||
fun samplingNormalTest() = runTest {
|
||||
@ -66,6 +66,7 @@ class TestMetropolisHastingsSampler {
|
||||
@Test
|
||||
fun samplingRayleighTest() = runTest {
|
||||
val generator = RandomGenerator.default(1)
|
||||
|
||||
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
||||
0.0
|
||||
} else {
|
||||
@ -83,17 +84,5 @@ class TestMetropolisHastingsSampler {
|
||||
|
||||
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
||||
}
|
||||
//
|
||||
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
|
||||
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
|
||||
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
||||
//
|
||||
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
||||
//
|
||||
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
|
||||
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
||||
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
||||
//
|
||||
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user