Factorized distributions/named priors
This commit is contained in:
parent
92b4e6f28c
commit
e5e9367c43
@ -1,10 +1,21 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
import scientifik.kmath.structures.NDElement
|
import scientifik.kmath.structures.NDElement
|
||||||
|
import scientifik.kmath.structures.NDField
|
||||||
import scientifik.kmath.structures.complex
|
import scientifik.kmath.structures.complex
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val element = NDElement.complex(2, 2) { index: IntArray ->
|
val element = NDElement.complex(2, 2) { index: IntArray ->
|
||||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val compute = NDField.complex(8).run {
|
||||||
|
val a = produce { (it) -> i * it - it.toDouble() }
|
||||||
|
val b = 3
|
||||||
|
val c = Complex(1.0, 1.0)
|
||||||
|
|
||||||
|
(a pow b) + c
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -56,6 +56,8 @@ object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
|||||||
* Complex number class
|
* Complex number class
|
||||||
*/
|
*/
|
||||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
||||||
|
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
||||||
|
|
||||||
override fun unwrap(): Complex = this
|
override fun unwrap(): Complex = this
|
||||||
|
|
||||||
override fun Complex.wrap(): Complex = this
|
override fun Complex.wrap(): Complex = this
|
||||||
|
@ -32,6 +32,8 @@ fun <T : MathElement<out TrigonometricOperations<T>>> ctg(arg: T): T = arg.conte
|
|||||||
interface PowerOperations<T> {
|
interface PowerOperations<T> {
|
||||||
fun power(arg: T, pow: Number): T
|
fun power(arg: T, pow: Number): T
|
||||||
fun sqrt(arg: T) = power(arg, 0.5)
|
fun sqrt(arg: T) = power(arg, 0.5)
|
||||||
|
|
||||||
|
infix fun T.pow(pow: Number) = power(this, pow)
|
||||||
}
|
}
|
||||||
|
|
||||||
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
infix fun <T : MathElement<out PowerOperations<T>>> T.pow(power: Double): T = context.power(this, power)
|
||||||
|
@ -20,7 +20,7 @@ interface Distribution<T : Any> : Sampler<T> {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a chain of samples from this distribution.
|
* Create a chain of samples from this distribution.
|
||||||
* The chain is not guaranteed to be stateless.
|
* The chain is not guaranteed to be stateless, but different sample chains should be independent.
|
||||||
*/
|
*/
|
||||||
override fun sample(generator: RandomGenerator): Chain<T>
|
override fun sample(generator: RandomGenerator): Chain<T>
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ interface Distribution<T : Any> : Sampler<T> {
|
|||||||
|
|
||||||
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
||||||
/**
|
/**
|
||||||
* Cumulative distribution for ordered parameter
|
* Cumulative distribution for ordered parameter (CDF)
|
||||||
*/
|
*/
|
||||||
fun cumulative(arg: T): Double
|
fun cumulative(arg: T): Double
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
package scientifik.kmath.prob
|
||||||
|
|
||||||
|
import scientifik.kmath.chains.Chain
|
||||||
|
import scientifik.kmath.chains.SimpleChain
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A multivariate distribution which takes a map of parameters
|
||||||
|
*/
|
||||||
|
interface NamedDistribution<T> : Distribution<Map<String, T>>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A multivariate distribution that has independent distributions for separate axis
|
||||||
|
*/
|
||||||
|
class FactorizedDistribution<T>(val distributions: Collection<NamedDistribution<T>>) : NamedDistribution<T> {
|
||||||
|
|
||||||
|
override fun probability(arg: Map<String, T>): Double {
|
||||||
|
return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sample(generator: RandomGenerator): Chain<Map<String, T>> {
|
||||||
|
val chains = distributions.map { it.sample(generator) }
|
||||||
|
return SimpleChain<Map<String, T>> {
|
||||||
|
chains.fold(emptyMap()) { acc, chain -> acc + chain.next() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user