forked from kscience/kmath
Factorized distributions/named priors
This commit is contained in:
parent
92b4e6f28c
commit
e5e9367c43
@ -1,10 +1,21 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import scientifik.kmath.structures.NDElement
|
||||
import scientifik.kmath.structures.NDField
|
||||
import scientifik.kmath.structures.complex
|
||||
|
||||
fun main() {
|
||||
val element = NDElement.complex(2, 2) { index: IntArray ->
|
||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||
}
|
||||
|
||||
|
||||
val compute = NDField.complex(8).run {
|
||||
val a = produce { (it) -> i * it - it.toDouble() }
|
||||
val b = 3
|
||||
val c = Complex(1.0, 1.0)
|
||||
|
||||
(a pow b) + c
|
||||
}
|
||||
|
||||
}
|
@ -56,6 +56,8 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
||||
* Complex number class
|
||||
*/
|
||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
||||
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
||||
|
||||
override fun unwrap(): Complex = this
|
||||
|
||||
override fun Complex.wrap(): Complex = this
|
||||
|
@ -32,6 +32,8 @@ fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.conte
|
||||
interface PowerOperations<T> {
|
||||
fun power(arg: T, pow: Number): T
|
||||
fun sqrt(arg: T) = power(arg, 0.5)
|
||||
|
||||
infix fun T.pow(pow: Number) = power(this, pow)
|
||||
}
|
||||
|
||||
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
||||
@ -48,7 +50,7 @@ interface ExponentialOperations<T> {
|
||||
fun <T : MathElement<out ExponentialOperations<T>>> exp(arg: T): T = arg.context.exp(arg)
|
||||
fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.ln(arg)
|
||||
|
||||
interface Norm<in T: Any, out R> {
|
||||
interface Norm<in T : Any, out R> {
|
||||
fun norm(arg: T): R
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,7 @@ interface Distribution<T : Any> : Sampler<T> {
|
||||
|
||||
/**
|
||||
* Create a chain of samples from this distribution.
|
||||
* The chain is not guaranteed to be stateless.
|
||||
* The chain is not guaranteed to be stateless, but different sample chains should be independent.
|
||||
*/
|
||||
override fun sample(generator: RandomGenerator): Chain<T>
|
||||
|
||||
@ -32,7 +32,7 @@ interface Distribution<T : Any> : Sampler<T> {
|
||||
|
||||
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
||||
/**
|
||||
* Cumulative distribution for ordered parameter
|
||||
* Cumulative distribution for ordered parameter (CDF)
|
||||
*/
|
||||
fun cumulative(arg: T): Double
|
||||
}
|
||||
|
@ -0,0 +1,26 @@
|
||||
package scientifik.kmath.prob
|
||||
|
||||
import scientifik.kmath.chains.Chain
|
||||
import scientifik.kmath.chains.SimpleChain
|
||||
|
||||
/**
|
||||
* A multivariate distribution which takes a map of parameters
|
||||
*/
|
||||
interface NamedDistribution<T> : Distribution<Map<String, T>>
|
||||
|
||||
/**
|
||||
* A multivariate distribution that has independent distributions for separate axis
|
||||
*/
|
||||
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
|
||||
|
||||
override fun probability(arg: Map<String, T>): Double {
|
||||
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
||||
}
|
||||
|
||||
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
||||
val chains = distributions.map { it.sample(generator) }
|
||||
return SimpleChain<Map<String, T>> {
|
||||
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user