Reimplement random-forking chain

This commit is contained in:
Alexander Nozik 2024-08-14 19:20:05 +03:00
parent 2becee7f59
commit 91513a1629
9 changed files with 40 additions and 46 deletions

View File

@ -124,7 +124,7 @@ public val <T> MatrixScope<T>.QR: QRDecompositionAttribute<T>
public interface CholeskyDecomposition<T> { public interface CholeskyDecomposition<T> {
/** /**
* The triangular matrix in this decomposition. It may have either [UpperTriangular] or [LowerTriangular]. * The lower triangular matrix in this decomposition. It should have [LowerTriangular].
*/ */
public val l: Matrix<T> public val l: Matrix<T>
} }

View File

@ -5,10 +5,7 @@ plugins {
val ejmlVerision = "0.43.1" val ejmlVerision = "0.43.1"
dependencies { dependencies {
api("org.ejml:ejml-ddense:$ejmlVerision") api("org.ejml:ejml-all:$ejmlVerision")
api("org.ejml:ejml-fdense:$ejmlVerision")
api("org.ejml:ejml-dsparse:$ejmlVerision")
api("org.ejml:ejml-fsparse:$ejmlVerision")
api(projects.kmathCore) api(projects.kmathCore)
} }

View File

@ -27,7 +27,7 @@ public data class Float64Circle2D(
override val radius: Float64, override val radius: Float64,
) : Circle2D<Double> ) : Circle2D<Double>
public fun Circle2D(center: Vector2D<Float64>, radius: Double): Circle2D<Double> = Float64Circle2D( public fun Circle2D(center: Vector2D<Float64>, radius: Double): Float64Circle2D = Float64Circle2D(
center as? Float64Vector2D ?: Float64Vector2D(center.x, center.y), center as? Float64Vector2D ?: Float64Vector2D(center.x, center.y),
radius radius
) )

View File

@ -17,8 +17,9 @@ public interface NamedDistribution<T> : Distribution<Map<String, T>>
/** /**
* A multivariate distribution that has independent distributions for separate axis. * A multivariate distribution that has independent distributions for separate axis.
*/ */
public class FactorizedDistribution<T>(public val distributions: Collection<NamedDistribution<T>>) : public class FactorizedDistribution<T>(
NamedDistribution<T> { public val distributions: Collection<NamedDistribution<T>>
) : NamedDistribution<T> {
override fun probability(arg: Map<String, T>): Double = override fun probability(arg: Map<String, T>): Double =
distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) } distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) }
@ -28,8 +29,10 @@ public class FactorizedDistribution<T>(public val distributions: Collection<Name
} }
} }
public class NamedDistributionWrapper<T : Any>(public val name: String, public val distribution: Distribution<T>) : public class NamedDistributionWrapper<T : Any>(
NamedDistribution<T> { public val name: String,
public val distribution: Distribution<T>
) : NamedDistribution<T> {
override fun probability(arg: Map<String, T>): Double = distribution.probability( override fun probability(arg: Map<String, T>): Double = distribution.probability(
arg[name] ?: error("Argument with name $name not found in input parameters") arg[name] ?: error("Argument with name $name not found in input parameters")
) )

View File

@ -9,11 +9,24 @@ import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.operations.Float64Field.pow import space.kscience.kmath.operations.Float64Field.pow
import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.samplers.GaussianSampler import space.kscience.kmath.samplers.GaussianSampler
import space.kscience.kmath.samplers.InternalErf import space.kscience.kmath.samplers.InternalGamma
import space.kscience.kmath.samplers.NormalizedGaussianSampler import space.kscience.kmath.samplers.NormalizedGaussianSampler
import space.kscience.kmath.samplers.ZigguratNormalizedGaussianSampler import space.kscience.kmath.samplers.ZigguratNormalizedGaussianSampler
import kotlin.math.* import kotlin.math.*
/**
* Based on Commons Math implementation.
* See [https://commons.apache.org/proper/commons-math/javadocs/api-3.3/org/apache/commons/math3/special/Erf.html].
*/
internal object InternalErf {
fun erfc(x: Double): Double {
if (abs(x) > 40) return if (x > 0) 0.0 else 2.0
val ret = InternalGamma.regularizedGammaQ(0.5, x * x, 10000)
return if (x < 0) 2 - ret else ret
}
}
/** /**
* Implements [Distribution1D] for the normal (gaussian) distribution. * Implements [Distribution1D] for the normal (gaussian) distribution.
*/ */

View File

@ -1,20 +0,0 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.samplers
import kotlin.math.abs
/**
* Based on Commons Math implementation.
* See [https://commons.apache.org/proper/commons-math/javadocs/api-3.3/org/apache/commons/math3/special/Erf.html].
*/
internal object InternalErf {
fun erfc(x: Double): Double {
if (abs(x) > 40) return if (x > 0) 0.0 else 2.0
val ret = InternalGamma.regularizedGammaQ(0.5, x * x, 10000)
return if (x < 0) 2 - ret else ret
}
}

View File

@ -21,7 +21,7 @@ import space.kscience.kmath.structures.Float64
*/ */
public class MetropolisHastingsSampler<T>( public class MetropolisHastingsSampler<T>(
public val algebra: Group<T>, public val algebra: Group<T>,
public val startPoint: T, public val startPoint: suspend (RandomGenerator) ->T,
public val stepSampler: Sampler<T>, public val stepSampler: Sampler<T>,
public val targetPdf: suspend (T) -> Float64, public val targetPdf: suspend (T) -> Float64,
) : Sampler<T> { ) : Sampler<T> {
@ -30,7 +30,7 @@ public class MetropolisHastingsSampler<T>(
override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>( override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
state = stepSampler.sample(generator), state = stepSampler.sample(generator),
seed = { startPoint }, seed = { startPoint(generator) },
forkState = Chain<T>::fork forkState = Chain<T>::fork
) { previousPoint: T -> ) { previousPoint: T ->
val proposalPoint = with(algebra) { previousPoint + next() } val proposalPoint = with(algebra) { previousPoint + next() }
@ -59,7 +59,7 @@ public class MetropolisHastingsSampler<T>(
targetPdf: suspend (Float64) -> Float64, targetPdf: suspend (Float64) -> Float64,
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler( ): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
algebra = Float64.algebra, algebra = Float64.algebra,
startPoint = startPoint, startPoint = {startPoint},
stepSampler = stepSampler, stepSampler = stepSampler,
targetPdf = targetPdf targetPdf = targetPdf
) )

View File

@ -7,13 +7,8 @@
package space.kscience.kmath.samplers package space.kscience.kmath.samplers
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.*
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.Chain
import space.kscience.kmath.operations.Group import space.kscience.kmath.operations.Group
@ -28,6 +23,7 @@ public data class RandomForkingSample<T>(
val value: Deferred<T>, val value: Deferred<T>,
val generation: Int, val generation: Int,
val energy: Float64, val energy: Float64,
val generator: RandomGenerator,
val stepChain: Chain<T> val stepChain: Chain<T>
) )
@ -38,11 +34,11 @@ public data class RandomForkingSample<T>(
public class RandomForkingSampler<T : Any>( public class RandomForkingSampler<T : Any>(
private val scope: CoroutineScope, private val scope: CoroutineScope,
private val initialValue: suspend (RandomGenerator) -> T, private val initialValue: suspend (RandomGenerator) -> T,
private val makeStep: suspend RandomGenerator.(T) -> List<T> private val makeStep: suspend (T) -> List<T>
) : Sampler<T?> { ) : Sampler<T?> {
override fun sample(generator: RandomGenerator): Chain<T?> = override fun sample(generator: RandomGenerator): Chain<T?> =
buildChain(scope, initial = { initialValue(generator) }) { generator.makeStep(it) } buildChain(scope, initial = { initialValue(generator) }) { makeStep(it) }
public companion object { public companion object {
private suspend fun <T> Channel<T>.receiveEvents( private suspend fun <T> Channel<T>.receiveEvents(
@ -95,13 +91,14 @@ public class RandomForkingSampler<T : Any>(
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 }, stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
targetPdf: suspend (T) -> Float64, targetPdf: suspend (T) -> Float64,
): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> = ): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> =
RandomForkingSampler<RandomForkingSample<T>>( RandomForkingSampler(
scope = scope, scope = scope,
initialValue = { generator -> initialValue = { generator ->
RandomForkingSample<T>( RandomForkingSample<T>(
value = scope.async { startPoint(generator) }, value = scope.async { startPoint(generator) },
generation = 0, generation = 0,
energy = initialEnergy, energy = initialEnergy,
generator = generator,
stepChain = stepSampler.sample(generator) stepChain = stepSampler.sample(generator)
) )
} }
@ -111,14 +108,14 @@ public class RandomForkingSampler<T : Any>(
RandomForkingSample<T>( RandomForkingSample<T>(
value = scope.async<T> { value = scope.async<T> {
val proposalPoint = with(algebra) { val proposalPoint = with(algebra) {
value + previousSample.stepChain.next() * stepScaleRule(previousSample.energy) value + previousSample.stepChain.next() * stepScaleRule(energy)
} }
val ratio = targetPdf(proposalPoint) / targetPdf(value) val ratio = targetPdf(proposalPoint) / targetPdf(value)
if (ratio >= 1.0) { if (ratio >= 1.0) {
proposalPoint proposalPoint
} else { } else {
val acceptanceProbability = nextDouble() val acceptanceProbability = previousSample.generator.nextDouble()
if (acceptanceProbability <= ratio) { if (acceptanceProbability <= ratio) {
proposalPoint proposalPoint
} else { } else {
@ -127,7 +124,8 @@ public class RandomForkingSampler<T : Any>(
} }
}, },
generation = previousSample.generation + 1, generation = previousSample.generation + 1,
energy = 0.0, energy = energy,
generator = previousSample.generator,
stepChain = previousSample.stepChain stepChain = previousSample.stepChain
) )
} }

View File

@ -85,4 +85,7 @@ class TestMetropolisHastingsSampler {
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2) assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
} }
} }
} }