Reimplement random-forking chain
This commit is contained in:
parent
2becee7f59
commit
91513a1629
@ -124,7 +124,7 @@ public val <T> MatrixScope<T>.QR: QRDecompositionAttribute<T>
|
|||||||
|
|
||||||
public interface CholeskyDecomposition<T> {
|
public interface CholeskyDecomposition<T> {
|
||||||
/**
|
/**
|
||||||
* The triangular matrix in this decomposition. It may have either [UpperTriangular] or [LowerTriangular].
|
* The lower triangular matrix in this decomposition. It should have [LowerTriangular].
|
||||||
*/
|
*/
|
||||||
public val l: Matrix<T>
|
public val l: Matrix<T>
|
||||||
}
|
}
|
||||||
|
@ -5,10 +5,7 @@ plugins {
|
|||||||
val ejmlVerision = "0.43.1"
|
val ejmlVerision = "0.43.1"
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api("org.ejml:ejml-ddense:$ejmlVerision")
|
api("org.ejml:ejml-all:$ejmlVerision")
|
||||||
api("org.ejml:ejml-fdense:$ejmlVerision")
|
|
||||||
api("org.ejml:ejml-dsparse:$ejmlVerision")
|
|
||||||
api("org.ejml:ejml-fsparse:$ejmlVerision")
|
|
||||||
api(projects.kmathCore)
|
api(projects.kmathCore)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ public data class Float64Circle2D(
|
|||||||
override val radius: Float64,
|
override val radius: Float64,
|
||||||
) : Circle2D<Double>
|
) : Circle2D<Double>
|
||||||
|
|
||||||
public fun Circle2D(center: Vector2D<Float64>, radius: Double): Circle2D<Double> = Float64Circle2D(
|
public fun Circle2D(center: Vector2D<Float64>, radius: Double): Float64Circle2D = Float64Circle2D(
|
||||||
center as? Float64Vector2D ?: Float64Vector2D(center.x, center.y),
|
center as? Float64Vector2D ?: Float64Vector2D(center.x, center.y),
|
||||||
radius
|
radius
|
||||||
)
|
)
|
||||||
|
@ -17,8 +17,9 @@ public interface NamedDistribution<T> : Distribution<Map<String, T>>
|
|||||||
/**
|
/**
|
||||||
* A multivariate distribution that has independent distributions for separate axis.
|
* A multivariate distribution that has independent distributions for separate axis.
|
||||||
*/
|
*/
|
||||||
public class FactorizedDistribution<T>(public val distributions: Collection<NamedDistribution<T>>) :
|
public class FactorizedDistribution<T>(
|
||||||
NamedDistribution<T> {
|
public val distributions: Collection<NamedDistribution<T>>
|
||||||
|
) : NamedDistribution<T> {
|
||||||
override fun probability(arg: Map<String, T>): Double =
|
override fun probability(arg: Map<String, T>): Double =
|
||||||
distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) }
|
distributions.fold(1.0) { acc, dist -> acc * dist.probability(arg) }
|
||||||
|
|
||||||
@ -28,8 +29,10 @@ public class FactorizedDistribution<T>(public val distributions: Collection<Name
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public class NamedDistributionWrapper<T : Any>(public val name: String, public val distribution: Distribution<T>) :
|
public class NamedDistributionWrapper<T : Any>(
|
||||||
NamedDistribution<T> {
|
public val name: String,
|
||||||
|
public val distribution: Distribution<T>
|
||||||
|
) : NamedDistribution<T> {
|
||||||
override fun probability(arg: Map<String, T>): Double = distribution.probability(
|
override fun probability(arg: Map<String, T>): Double = distribution.probability(
|
||||||
arg[name] ?: error("Argument with name $name not found in input parameters")
|
arg[name] ?: error("Argument with name $name not found in input parameters")
|
||||||
)
|
)
|
||||||
|
@ -9,11 +9,24 @@ import space.kscience.kmath.chains.BlockingDoubleChain
|
|||||||
import space.kscience.kmath.operations.Float64Field.pow
|
import space.kscience.kmath.operations.Float64Field.pow
|
||||||
import space.kscience.kmath.random.RandomGenerator
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
import space.kscience.kmath.samplers.GaussianSampler
|
import space.kscience.kmath.samplers.GaussianSampler
|
||||||
import space.kscience.kmath.samplers.InternalErf
|
import space.kscience.kmath.samplers.InternalGamma
|
||||||
import space.kscience.kmath.samplers.NormalizedGaussianSampler
|
import space.kscience.kmath.samplers.NormalizedGaussianSampler
|
||||||
import space.kscience.kmath.samplers.ZigguratNormalizedGaussianSampler
|
import space.kscience.kmath.samplers.ZigguratNormalizedGaussianSampler
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Based on Commons Math implementation.
|
||||||
|
* See [https://commons.apache.org/proper/commons-math/javadocs/api-3.3/org/apache/commons/math3/special/Erf.html].
|
||||||
|
*/
|
||||||
|
internal object InternalErf {
|
||||||
|
fun erfc(x: Double): Double {
|
||||||
|
if (abs(x) > 40) return if (x > 0) 0.0 else 2.0
|
||||||
|
val ret = InternalGamma.regularizedGammaQ(0.5, x * x, 10000)
|
||||||
|
return if (x < 0) 2 - ret else ret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implements [Distribution1D] for the normal (gaussian) distribution.
|
* Implements [Distribution1D] for the normal (gaussian) distribution.
|
||||||
*/
|
*/
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2024 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.samplers
|
|
||||||
|
|
||||||
import kotlin.math.abs
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Based on Commons Math implementation.
|
|
||||||
* See [https://commons.apache.org/proper/commons-math/javadocs/api-3.3/org/apache/commons/math3/special/Erf.html].
|
|
||||||
*/
|
|
||||||
internal object InternalErf {
|
|
||||||
fun erfc(x: Double): Double {
|
|
||||||
if (abs(x) > 40) return if (x > 0) 0.0 else 2.0
|
|
||||||
val ret = InternalGamma.regularizedGammaQ(0.5, x * x, 10000)
|
|
||||||
return if (x < 0) 2 - ret else ret
|
|
||||||
}
|
|
||||||
}
|
|
@ -21,7 +21,7 @@ import space.kscience.kmath.structures.Float64
|
|||||||
*/
|
*/
|
||||||
public class MetropolisHastingsSampler<T>(
|
public class MetropolisHastingsSampler<T>(
|
||||||
public val algebra: Group<T>,
|
public val algebra: Group<T>,
|
||||||
public val startPoint: T,
|
public val startPoint: suspend (RandomGenerator) ->T,
|
||||||
public val stepSampler: Sampler<T>,
|
public val stepSampler: Sampler<T>,
|
||||||
public val targetPdf: suspend (T) -> Float64,
|
public val targetPdf: suspend (T) -> Float64,
|
||||||
) : Sampler<T> {
|
) : Sampler<T> {
|
||||||
@ -30,7 +30,7 @@ public class MetropolisHastingsSampler<T>(
|
|||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
|
override fun sample(generator: RandomGenerator): Chain<T> = StatefulChain<Chain<T>, T>(
|
||||||
state = stepSampler.sample(generator),
|
state = stepSampler.sample(generator),
|
||||||
seed = { startPoint },
|
seed = { startPoint(generator) },
|
||||||
forkState = Chain<T>::fork
|
forkState = Chain<T>::fork
|
||||||
) { previousPoint: T ->
|
) { previousPoint: T ->
|
||||||
val proposalPoint = with(algebra) { previousPoint + next() }
|
val proposalPoint = with(algebra) { previousPoint + next() }
|
||||||
@ -59,7 +59,7 @@ public class MetropolisHastingsSampler<T>(
|
|||||||
targetPdf: suspend (Float64) -> Float64,
|
targetPdf: suspend (Float64) -> Float64,
|
||||||
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
|
): MetropolisHastingsSampler<Double> = MetropolisHastingsSampler(
|
||||||
algebra = Float64.algebra,
|
algebra = Float64.algebra,
|
||||||
startPoint = startPoint,
|
startPoint = {startPoint},
|
||||||
stepSampler = stepSampler,
|
stepSampler = stepSampler,
|
||||||
targetPdf = targetPdf
|
targetPdf = targetPdf
|
||||||
)
|
)
|
||||||
|
@ -7,13 +7,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.samplers
|
package space.kscience.kmath.samplers
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineScope
|
import kotlinx.coroutines.*
|
||||||
import kotlinx.coroutines.Deferred
|
|
||||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
|
||||||
import kotlinx.coroutines.async
|
|
||||||
import kotlinx.coroutines.channels.Channel
|
import kotlinx.coroutines.channels.Channel
|
||||||
import kotlinx.coroutines.isActive
|
|
||||||
import kotlinx.coroutines.launch
|
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.chains.Chain
|
import space.kscience.kmath.chains.Chain
|
||||||
import space.kscience.kmath.operations.Group
|
import space.kscience.kmath.operations.Group
|
||||||
@ -28,6 +23,7 @@ public data class RandomForkingSample<T>(
|
|||||||
val value: Deferred<T>,
|
val value: Deferred<T>,
|
||||||
val generation: Int,
|
val generation: Int,
|
||||||
val energy: Float64,
|
val energy: Float64,
|
||||||
|
val generator: RandomGenerator,
|
||||||
val stepChain: Chain<T>
|
val stepChain: Chain<T>
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -38,11 +34,11 @@ public data class RandomForkingSample<T>(
|
|||||||
public class RandomForkingSampler<T : Any>(
|
public class RandomForkingSampler<T : Any>(
|
||||||
private val scope: CoroutineScope,
|
private val scope: CoroutineScope,
|
||||||
private val initialValue: suspend (RandomGenerator) -> T,
|
private val initialValue: suspend (RandomGenerator) -> T,
|
||||||
private val makeStep: suspend RandomGenerator.(T) -> List<T>
|
private val makeStep: suspend (T) -> List<T>
|
||||||
) : Sampler<T?> {
|
) : Sampler<T?> {
|
||||||
|
|
||||||
override fun sample(generator: RandomGenerator): Chain<T?> =
|
override fun sample(generator: RandomGenerator): Chain<T?> =
|
||||||
buildChain(scope, initial = { initialValue(generator) }) { generator.makeStep(it) }
|
buildChain(scope, initial = { initialValue(generator) }) { makeStep(it) }
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
private suspend fun <T> Channel<T>.receiveEvents(
|
private suspend fun <T> Channel<T>.receiveEvents(
|
||||||
@ -95,13 +91,14 @@ public class RandomForkingSampler<T : Any>(
|
|||||||
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
|
stepScaleRule: suspend (Float64) -> Float64 = { 1.0 },
|
||||||
targetPdf: suspend (T) -> Float64,
|
targetPdf: suspend (T) -> Float64,
|
||||||
): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> =
|
): RandomForkingSampler<RandomForkingSample<T>> where A : Group<T>, A : ScaleOperations<T> =
|
||||||
RandomForkingSampler<RandomForkingSample<T>>(
|
RandomForkingSampler(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
initialValue = { generator ->
|
initialValue = { generator ->
|
||||||
RandomForkingSample<T>(
|
RandomForkingSample<T>(
|
||||||
value = scope.async { startPoint(generator) },
|
value = scope.async { startPoint(generator) },
|
||||||
generation = 0,
|
generation = 0,
|
||||||
energy = initialEnergy,
|
energy = initialEnergy,
|
||||||
|
generator = generator,
|
||||||
stepChain = stepSampler.sample(generator)
|
stepChain = stepSampler.sample(generator)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -111,14 +108,14 @@ public class RandomForkingSampler<T : Any>(
|
|||||||
RandomForkingSample<T>(
|
RandomForkingSample<T>(
|
||||||
value = scope.async<T> {
|
value = scope.async<T> {
|
||||||
val proposalPoint = with(algebra) {
|
val proposalPoint = with(algebra) {
|
||||||
value + previousSample.stepChain.next() * stepScaleRule(previousSample.energy)
|
value + previousSample.stepChain.next() * stepScaleRule(energy)
|
||||||
}
|
}
|
||||||
val ratio = targetPdf(proposalPoint) / targetPdf(value)
|
val ratio = targetPdf(proposalPoint) / targetPdf(value)
|
||||||
|
|
||||||
if (ratio >= 1.0) {
|
if (ratio >= 1.0) {
|
||||||
proposalPoint
|
proposalPoint
|
||||||
} else {
|
} else {
|
||||||
val acceptanceProbability = nextDouble()
|
val acceptanceProbability = previousSample.generator.nextDouble()
|
||||||
if (acceptanceProbability <= ratio) {
|
if (acceptanceProbability <= ratio) {
|
||||||
proposalPoint
|
proposalPoint
|
||||||
} else {
|
} else {
|
||||||
@ -127,7 +124,8 @@ public class RandomForkingSampler<T : Any>(
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
generation = previousSample.generation + 1,
|
generation = previousSample.generation + 1,
|
||||||
energy = 0.0,
|
energy = energy,
|
||||||
|
generator = previousSample.generator,
|
||||||
stepChain = previousSample.stepChain
|
stepChain = previousSample.stepChain
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -85,4 +85,7 @@ class TestMetropolisHastingsSampler {
|
|||||||
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
assertEquals(setup.mean * sqrt(PI / 2), Float64Field.mean(sampledValues), 1e-2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user