Factorized distributions/named priors

This commit is contained in:
Alexander Nozik 2019-09-02 13:07:42 +03:00
parent 92b4e6f28c
commit e5e9367c43
5 changed files with 44 additions and 3 deletions

View File

@ -1,10 +1,21 @@
package scientifik.kmath.operations
import scientifik.kmath.structures.NDElement
import scientifik.kmath.structures.NDField
import scientifik.kmath.structures.complex
fun main() {
val element = NDElement.complex(2, 2) { index: IntArray ->
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
}
val compute = NDField.complex(8).run {
val a = produce { (it) -> i * it - it.toDouble() }
val b = 3
val c = Complex(1.0, 1.0)
(a pow b) + c
}
}

View File

@ -56,6 +56,8 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
* Complex number class
*/
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
override fun unwrap(): Complex = this
override fun Complex.wrap(): Complex = this

View File

@ -32,6 +32,8 @@ fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.conte
interface PowerOperations<T> {
fun power(arg: T, pow: Number): T
fun sqrt(arg: T) = power(arg, 0.5)
infix fun T.pow(pow: Number) = power(this, pow)
}
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)

View File

@ -20,7 +20,7 @@ interface Distribution<T : Any> : Sampler<T> {
/**
* Create a chain of samples from this distribution.
* The chain is not guaranteed to be stateless.
* The chain is not guaranteed to be stateless, but different sample chains should be independent.
*/
override fun sample(generator: RandomGenerator): Chain<T>
@ -32,7 +32,7 @@ interface Distribution<T : Any> : Sampler<T> {
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
/**
* Cumulative distribution for ordered parameter
* Cumulative distribution for ordered parameter (CDF)
*/
fun cumulative(arg: T): Double
}

View File

@ -0,0 +1,26 @@
package scientifik.kmath.prob
import scientifik.kmath.chains.Chain
import scientifik.kmath.chains.SimpleChain
/**
* A multivariate distribution which takes a map of parameters
*/
interface NamedDistribution<T> : Distribution<Map<String, T>>
/**
* A multivariate distribution that has independent distributions for separate axis
*/
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
override fun probability(arg: Map<String, T>): Double {
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
}
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
val chains = distributions.map { it.sample(generator) }
return SimpleChain<Map<String, T>> {
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
}
}
}