From 91513a1629b110ca0632eaa6fa110234c38ab527 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 14 Aug 2024 19:20:05 +0300 Subject: [PATCH] Reimplement random-forking chain --- .../kscience/kmath/linear/matrixAttributes.kt | 2 +- kmath-ejml/build.gradle.kts | 5 +---- .../kmath/geometry/euclidean2d/Circle2D.kt | 2 +- .../distributions/FactorizedDistribution.kt | 11 ++++++---- .../kmath/distributions/NormalDistribution.kt | 15 ++++++++++++- .../kscience/kmath/samplers/InternalErf.kt | 20 ----------------- .../samplers/MetropolisHastingsSampler.kt | 6 ++--- .../kmath/samplers/RandomForkingEvent.kt | 22 +++++++++---------- .../samplers/TestMetropolisHastingsSampler.kt | 3 +++ 9 files changed, 40 insertions(+), 46 deletions(-) delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/InternalErf.kt diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt index dadb2f3d7..fdb7e318b 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt @@ -124,7 +124,7 @@ public val MatrixScope.QR: QRDecompositionAttribute public interface CholeskyDecomposition { /** - * The triangular matrix in this decomposition. It may have either [UpperTriangular] or [LowerTriangular]. + * The lower triangular matrix in this decomposition. It should have [LowerTriangular]. */ public val l: Matrix } diff --git a/kmath-ejml/build.gradle.kts b/kmath-ejml/build.gradle.kts index 54b8f7038..50416eaf3 100644 --- a/kmath-ejml/build.gradle.kts +++ b/kmath-ejml/build.gradle.kts @@ -5,10 +5,7 @@ plugins { val ejmlVerision = "0.43.1" dependencies { - api("org.ejml:ejml-ddense:$ejmlVerision") - api("org.ejml:ejml-fdense:$ejmlVerision") - api("org.ejml:ejml-dsparse:$ejmlVerision") - api("org.ejml:ejml-fsparse:$ejmlVerision") + api("org.ejml:ejml-all:$ejmlVerision") api(projects.kmathCore) } diff --git a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/euclidean2d/Circle2D.kt b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/euclidean2d/Circle2D.kt index fcffc7714..ca9a37dcb 100644 --- a/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/euclidean2d/Circle2D.kt +++ b/kmath-geometry/src/commonMain/kotlin/space/kscience/kmath/geometry/euclidean2d/Circle2D.kt @@ -27,7 +27,7 @@ public data class Float64Circle2D( override val radius: Float64, ) : Circle2D -public fun Circle2D(center: Vector2D, radius: Double): Circle2D = Float64Circle2D( +public fun Circle2D(center: Vector2D, radius: Double): Float64Circle2D = Float64Circle2D( center as? Float64Vector2D ?: Float64Vector2D(center.x, center.y), radius ) diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt index 8f5a378bb..3e0aa1c06 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt @@ -17,8 +17,9 @@ public interface NamedDistribution : Distribution> /** * A multivariate distribution that has independent distributions for separate axis. */ -public class FactorizedDistribution(public val distributions: Collection>) : - NamedDistribution { +public class FactorizedDistribution( + public val distributions: Collection> +) : NamedDistribution { override fun probability(arg: Map): Double = distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) } @@ -28,8 +29,10 @@ public class FactorizedDistribution(public val distributions: Collection(public val name: String, public val distribution: Distribution) : - NamedDistribution { +public class NamedDistributionWrapper( + public val name: String, + public val distribution: Distribution +) : NamedDistribution { override fun probability(arg: Map): Double = distribution.probability( arg[name] ?: error("Argument with name $name not found in input parameters") ) diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt index f425c6285..89f1dd8d7 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt @@ -9,11 +9,24 @@ import space.kscience.kmath.chains.BlockingDoubleChain import space.kscience.kmath.operations.Float64Field.pow import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.samplers.GaussianSampler -import space.kscience.kmath.samplers.InternalErf +import space.kscience.kmath.samplers.InternalGamma import space.kscience.kmath.samplers.NormalizedGaussianSampler import space.kscience.kmath.samplers.ZigguratNormalizedGaussianSampler import kotlin.math.* + +/** + * Based on Commons Math implementation. + * See [https://commons.apache.org/proper/commons-math/javadocs/api-3.3/org/apache/commons/math3/special/Erf.html]. + */ +internal object InternalErf { + fun erfc(x: Double): Double { + if (abs(x) > 40) return if (x > 0) 0.0 else 2.0 + val ret = InternalGamma.regularizedGammaQ(0.5, x * x, 10000) + return if (x < 0) 2 - ret else ret + } +} + /** * Implements [Distribution1D] for the normal (gaussian) distribution. */ diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/InternalErf.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/InternalErf.kt deleted file mode 100644 index c28b4be50..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/InternalErf.kt +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright 2018-2024 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.samplers - -import kotlin.math.abs - -/** - * Based on Commons Math implementation. - * See [https://commons.apache.org/proper/commons-math/javadocs/api-3.3/org/apache/commons/math3/special/Erf.html]. - */ -internal object InternalErf { - fun erfc(x: Double): Double { - if (abs(x) > 40) return if (x > 0) 0.0 else 2.0 - val ret = InternalGamma.regularizedGammaQ(0.5, x * x, 10000) - return if (x < 0) 2 - ret else ret - } -} \ No newline at end of file diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt index 4e00bbad2..a6bba910b 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MetropolisHastingsSampler.kt @@ -21,7 +21,7 @@ import space.kscience.kmath.structures.Float64 */ public class MetropolisHastingsSampler( public val algebra: Group, - public val startPoint: T, + public val startPoint: suspend (RandomGenerator) ->T, public val stepSampler: Sampler, public val targetPdf: suspend (T) -> Float64, ) : Sampler { @@ -30,7 +30,7 @@ public class MetropolisHastingsSampler( override fun sample(generator: RandomGenerator): Chain = StatefulChain, T>( state = stepSampler.sample(generator), - seed = { startPoint }, + seed = { startPoint(generator) }, forkState = Chain::fork ) { previousPoint: T -> val proposalPoint = with(algebra) { previousPoint + next() } @@ -59,7 +59,7 @@ public class MetropolisHastingsSampler( targetPdf: suspend (Float64) -> Float64, ): MetropolisHastingsSampler = MetropolisHastingsSampler( algebra = Float64.algebra, - startPoint = startPoint, + startPoint = {startPoint}, stepSampler = stepSampler, targetPdf = targetPdf ) diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt index 9260b67fe..a9331276c 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/RandomForkingEvent.kt @@ -7,13 +7,8 @@ package space.kscience.kmath.samplers -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Deferred -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.async +import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.isActive -import kotlinx.coroutines.launch import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.chains.Chain import space.kscience.kmath.operations.Group @@ -28,6 +23,7 @@ public data class RandomForkingSample( val value: Deferred, val generation: Int, val energy: Float64, + val generator: RandomGenerator, val stepChain: Chain ) @@ -38,11 +34,11 @@ public data class RandomForkingSample( public class RandomForkingSampler( private val scope: CoroutineScope, private val initialValue: suspend (RandomGenerator) -> T, - private val makeStep: suspend RandomGenerator.(T) -> List + private val makeStep: suspend (T) -> List ) : Sampler { override fun sample(generator: RandomGenerator): Chain = - buildChain(scope, initial = { initialValue(generator) }) { generator.makeStep(it) } + buildChain(scope, initial = { initialValue(generator) }) { makeStep(it) } public companion object { private suspend fun Channel.receiveEvents( @@ -95,13 +91,14 @@ public class RandomForkingSampler( stepScaleRule: suspend (Float64) -> Float64 = { 1.0 }, targetPdf: suspend (T) -> Float64, ): RandomForkingSampler> where A : Group, A : ScaleOperations = - RandomForkingSampler>( + RandomForkingSampler( scope = scope, initialValue = { generator -> RandomForkingSample( value = scope.async { startPoint(generator) }, generation = 0, energy = initialEnergy, + generator = generator, stepChain = stepSampler.sample(generator) ) } @@ -111,14 +108,14 @@ public class RandomForkingSampler( RandomForkingSample( value = scope.async { val proposalPoint = with(algebra) { - value + previousSample.stepChain.next() * stepScaleRule(previousSample.energy) + value + previousSample.stepChain.next() * stepScaleRule(energy) } val ratio = targetPdf(proposalPoint) / targetPdf(value) if (ratio >= 1.0) { proposalPoint } else { - val acceptanceProbability = nextDouble() + val acceptanceProbability = previousSample.generator.nextDouble() if (acceptanceProbability <= ratio) { proposalPoint } else { @@ -127,7 +124,8 @@ public class RandomForkingSampler( } }, generation = previousSample.generation + 1, - energy = 0.0, + energy = energy, + generator = previousSample.generator, stepChain = previousSample.stepChain ) } diff --git a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt index 3011687ef..1261c75e5 100644 --- a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt +++ b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/samplers/TestMetropolisHastingsSampler.kt @@ -85,4 +85,7 @@ class TestMetropolisHastingsSampler { assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2) } } + + + } \ No newline at end of file