GaussianSampler inherits Blocking Sampler

This commit is contained in:
Alexander Nozik 2021-04-07 14:14:22 +03:00
parent 5bdc02d18c
commit 31460df721
2 changed files with 5 additions and 4 deletions

View File

@ -3,7 +3,6 @@ package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.chains.map import space.kscience.kmath.chains.map
import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
/** /**
* Sampling from a Gaussian distribution with given mean and standard deviation. * Sampling from a Gaussian distribution with given mean and standard deviation.
@ -18,7 +17,7 @@ public class GaussianSampler(
public val mean: Double, public val mean: Double,
public val standardDeviation: Double, public val standardDeviation: Double,
private val normalized: NormalizedGaussianSampler = BoxMullerSampler private val normalized: NormalizedGaussianSampler = BoxMullerSampler
) : Sampler<Double> { ) : BlockingDoubleSampler {
init { init {
require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" } require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" }

View File

@ -18,10 +18,12 @@ internal class CommonsDistributionsTest {
} }
@Test @Test
fun testNormalDistributionBlocking() = runBlocking { fun testNormalDistributionBlocking() {
val distribution = GaussianSampler(7.0, 2.0) val distribution = GaussianSampler(7.0, 2.0)
val generator = RandomGenerator.default(1) val generator = RandomGenerator.default(1)
val sample = distribution.sample(generator).nextBufferBlocking(1000) val sample = distribution.sample(generator).nextBufferBlocking(1000)
Assertions.assertEquals(7.0, Mean.double(sample), 0.2) runBlocking {
Assertions.assertEquals(7.0, Mean.double(sample), 0.2)
}
} }
} }