Compare commits
2 Commits
57d1cd8c87
...
6619db3f45
Author | SHA1 | Date | |
---|---|---|---|
6619db3f45 | |||
48d0ee8126 |
@ -15,8 +15,6 @@ repositories {
|
|||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
val multikVersion: String by rootProject.extra
|
|
||||||
|
|
||||||
kotlin {
|
kotlin {
|
||||||
jvm()
|
jvm()
|
||||||
|
|
||||||
@ -45,7 +43,7 @@ kotlin {
|
|||||||
implementation(project(":kmath-for-real"))
|
implementation(project(":kmath-for-real"))
|
||||||
implementation(project(":kmath-tensors"))
|
implementation(project(":kmath-tensors"))
|
||||||
implementation(project(":kmath-multik"))
|
implementation(project(":kmath-multik"))
|
||||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
implementation(libs.multik.default)
|
||||||
implementation(spclibs.kotlinx.benchmark.runtime)
|
implementation(spclibs.kotlinx.benchmark.runtime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ import space.kscience.gradle.useSPCTeam
|
|||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("space.kscience.gradle.project")
|
id("space.kscience.gradle.project")
|
||||||
id("org.jetbrains.kotlinx.kover") version "0.7.6"
|
alias(spclibs.plugins.kotlinx.kover)
|
||||||
}
|
}
|
||||||
|
|
||||||
val attributesVersion by extra("0.2.0")
|
val attributesVersion by extra("0.2.0")
|
||||||
@ -70,5 +70,3 @@ ksciencePublish {
|
|||||||
}
|
}
|
||||||
|
|
||||||
apiValidation.nonPublicMarkers.add("space.kscience.kmath.UnstableKMathAPI")
|
apiValidation.nonPublicMarkers.add("space.kscience.kmath.UnstableKMathAPI")
|
||||||
|
|
||||||
val multikVersion by extra("0.2.3")
|
|
||||||
|
@ -10,8 +10,6 @@ repositories {
|
|||||||
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
|
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
|
||||||
}
|
}
|
||||||
|
|
||||||
val multikVersion: String by rootProject.extra
|
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation(project(":kmath-ast"))
|
implementation(project(":kmath-ast"))
|
||||||
implementation(project(":kmath-kotlingrad"))
|
implementation(project(":kmath-kotlingrad"))
|
||||||
@ -33,7 +31,7 @@ dependencies {
|
|||||||
implementation(project(":kmath-jafama"))
|
implementation(project(":kmath-jafama"))
|
||||||
//multik
|
//multik
|
||||||
implementation(project(":kmath-multik"))
|
implementation(project(":kmath-multik"))
|
||||||
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
implementation(libs.multik.default)
|
||||||
|
|
||||||
//datetime
|
//datetime
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.0")
|
implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.4.0")
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
[versions]
|
[versions]
|
||||||
|
|
||||||
commons-rng = "1.6"
|
commons-rng = "1.6"
|
||||||
|
multik = "0.2.3"
|
||||||
|
|
||||||
[libraries]
|
[libraries]
|
||||||
|
|
||||||
commons-rng-simple = {module ="org.apache.commons:commons-rng-simple", version.ref = "commons-rng"}
|
commons-rng-simple = { module = "org.apache.commons:commons-rng-simple", version.ref = "commons-rng" }
|
||||||
commons-rng-sampling = {module ="org.apache.commons:commons-rng-sampling", version.ref = "commons-rng"}
|
commons-rng-sampling = { module = "org.apache.commons:commons-rng-sampling", version.ref = "commons-rng" }
|
||||||
|
|
||||||
|
multik-core = { module = "org.jetbrains.kotlinx:multik-core", version.ref = "multik" }
|
||||||
|
multik-default = { module = "org.jetbrains.kotlinx:multik-default", version.ref = "multik" }
|
@ -4,8 +4,6 @@ plugins {
|
|||||||
|
|
||||||
description = "JetBrains Multik connector"
|
description = "JetBrains Multik connector"
|
||||||
|
|
||||||
val multikVersion: String by rootProject.extra
|
|
||||||
|
|
||||||
kscience {
|
kscience {
|
||||||
jvm()
|
jvm()
|
||||||
js()
|
js()
|
||||||
@ -16,12 +14,12 @@ kotlin {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(projects.kmathTensors)
|
api(projects.kmathTensors)
|
||||||
api("org.jetbrains.kotlinx:multik-core:$multikVersion")
|
api(libs.multik.core)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
commonTest {
|
commonTest {
|
||||||
dependencies {
|
dependencies {
|
||||||
api("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
api(libs.multik.default)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,71 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2024 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
@file:OptIn(ExperimentalCoroutinesApi::class)
|
||||||
|
|
||||||
|
package space.kscience.kmath.samplers
|
||||||
|
|
||||||
|
import kotlinx.coroutines.CoroutineScope
|
||||||
|
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||||
|
import kotlinx.coroutines.channels.Channel
|
||||||
|
import kotlinx.coroutines.isActive
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import space.kscience.kmath.chains.Chain
|
||||||
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
|
import space.kscience.kmath.stat.Sampler
|
||||||
|
import kotlin.coroutines.coroutineContext
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A sampler that creates a chain that could be split at each computation
|
||||||
|
*/
|
||||||
|
public class RandomForkingSampler<T: Any>(
|
||||||
|
private val scope: CoroutineScope,
|
||||||
|
private val initialValue: T,
|
||||||
|
private val makeStep: suspend RandomGenerator.(T) -> List<T>
|
||||||
|
) : Sampler<T?> {
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<T?> = buildChain(scope, initialValue) { generator.makeStep(it) }
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
private suspend fun <T> Channel<T>.receiveEvents(
|
||||||
|
initial: T,
|
||||||
|
makeStep: suspend (T) -> List<T>
|
||||||
|
) {
|
||||||
|
send(initial)
|
||||||
|
//inner dispatch queue
|
||||||
|
val innerChannel = Channel<T>(50)
|
||||||
|
innerChannel.send(initial)
|
||||||
|
while (coroutineContext.isActive && !innerChannel.isEmpty) {
|
||||||
|
val current = innerChannel.receive()
|
||||||
|
//add event immediately, but it does not mean that the value is computed immediately as well
|
||||||
|
makeStep(current).forEach {
|
||||||
|
innerChannel.send(it)
|
||||||
|
send(it)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
innerChannel.close()
|
||||||
|
close()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T: Any> buildChain(
|
||||||
|
scope: CoroutineScope,
|
||||||
|
initial: T,
|
||||||
|
makeStep: suspend (T) -> List<T>
|
||||||
|
): Chain<T?> {
|
||||||
|
val channel = Channel<T>(Channel.RENDEZVOUS)
|
||||||
|
scope.launch {
|
||||||
|
channel.receiveEvents(initial, makeStep)
|
||||||
|
}
|
||||||
|
|
||||||
|
return object : Chain<T?> {
|
||||||
|
override suspend fun next(): T? = channel.receiveCatching().getOrNull()
|
||||||
|
|
||||||
|
override suspend fun fork(): Chain<T?> = buildChain(scope, channel.receive(), makeStep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -24,8 +24,8 @@ class TestMetropolisHastingsSampler {
|
|||||||
|
|
||||||
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
data class TestSetup(val mean: Double, val startPoint: Double, val sigma: Double = 0.5)
|
||||||
|
|
||||||
val sample = 1e6.toInt()
|
private val sample = 1e6.toInt()
|
||||||
val burnIn = sample / 5
|
private val burnIn = sample / 5
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun samplingNormalTest() = runTest {
|
fun samplingNormalTest() = runTest {
|
||||||
@ -66,6 +66,7 @@ class TestMetropolisHastingsSampler {
|
|||||||
@Test
|
@Test
|
||||||
fun samplingRayleighTest() = runTest {
|
fun samplingRayleighTest() = runTest {
|
||||||
val generator = RandomGenerator.default(1)
|
val generator = RandomGenerator.default(1)
|
||||||
|
|
||||||
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
fun rayleighDist(sigma: Double, arg: Double): Double = if (arg < 0.0) {
|
||||||
0.0
|
0.0
|
||||||
} else {
|
} else {
|
||||||
@ -83,17 +84,5 @@ class TestMetropolisHastingsSampler {
|
|||||||
|
|
||||||
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// fun rayleighDist1(arg: Double) = rayleighDist(arg, 1.0)
|
|
||||||
// var sampler = MetropolisHastingsSampler(::rayleighDist1, initialPoint = 2.0, proposalStd = 1.0)
|
|
||||||
// var sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(1_000_000)
|
|
||||||
//
|
|
||||||
// assertEquals(1.25, Float64Field.mean(sampledValues), 1e-2)
|
|
||||||
//
|
|
||||||
// fun rayleighDist2(arg: Double) = rayleighDist(arg, 2.0)
|
|
||||||
// sampler = MetropolisHastingsSampler(::rayleighDist2, proposalStd = 1.0)
|
|
||||||
// sampledValues = sampler.sample(DefaultGenerator()).nextBufferBlocking(10_000_000)
|
|
||||||
//
|
|
||||||
// assertEquals(2.5, Float64Field.mean(sampledValues), 1e-2)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user