From e5e9367c436ed2c6e801d37d5552751a889affb4 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 2 Sep 2019 13:07:42 +0300 Subject: [PATCH] Factorized distributions/named priors --- .../kmath/operations/ComplexDemo.kt | 11 ++++++++ .../scientifik/kmath/operations/Complex.kt | 2 ++ .../kmath/operations/OptionalOperations.kt | 4 ++- .../scientifik/kmath/prob/Distribution.kt | 4 +-- .../kmath/prob/FactorizedDistribution.kt | 26 +++++++++++++++++++ 5 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt diff --git a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt index 86b28303d..4841f9dd8 100644 --- a/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt +++ b/examples/src/main/kotlin/scientifik/kmath/operations/ComplexDemo.kt @@ -1,10 +1,21 @@ package scientifik.kmath.operations import scientifik.kmath.structures.NDElement +import scientifik.kmath.structures.NDField import scientifik.kmath.structures.complex fun main() { val element = NDElement.complex(2, 2) { index: IntArray -> Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) } + + + val compute = NDField.complex(8).run { + val a = produce { (it) -> i * it - it.toDouble() } + val b = 3 + val c = Complex(1.0, 1.0) + + (a pow b) + c + } + } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 79f9e3c35..6c529f55e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -56,6 +56,8 @@ object ComplexField : ExtendedFieldOperations, Field { * Complex number class */ data class Complex(val re: Double, val im: Double) : FieldElement, Comparable { + constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) + override fun unwrap(): Complex = this override fun Complex.wrap(): Complex = this diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index ffbbf69dd..2a77e36ef 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -32,6 +32,8 @@ fun >> ctg(arg: T): T = arg.conte interface PowerOperations { fun power(arg: T, pow: Number): T fun sqrt(arg: T) = power(arg, 0.5) + + infix fun T.pow(pow: Number) = power(this, pow) } infix fun >> T.pow(power: Double): T = context.power(this, power) @@ -48,7 +50,7 @@ interface ExponentialOperations { fun >> exp(arg: T): T = arg.context.exp(arg) fun >> ln(arg: T): T = arg.context.ln(arg) -interface Norm { +interface Norm { fun norm(arg: T): R } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt index 487e13876..a11799517 100644 --- a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/Distribution.kt @@ -20,7 +20,7 @@ interface Distribution : Sampler { /** * Create a chain of samples from this distribution. - * The chain is not guaranteed to be stateless. + * The chain is not guaranteed to be stateless, but different sample chains should be independent. */ override fun sample(generator: RandomGenerator): Chain @@ -32,7 +32,7 @@ interface Distribution : Sampler { interface UnivariateDistribution> : Distribution { /** - * Cumulative distribution for ordered parameter + * Cumulative distribution for ordered parameter (CDF) */ fun cumulative(arg: T): Double } diff --git a/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt new file mode 100644 index 000000000..d70db4f83 --- /dev/null +++ b/kmath-prob/src/commonMain/kotlin/scientifik/kmath/prob/FactorizedDistribution.kt @@ -0,0 +1,26 @@ +package scientifik.kmath.prob + +import scientifik.kmath.chains.Chain +import scientifik.kmath.chains.SimpleChain + +/** + * A multivariate distribution which takes a map of parameters + */ +interface NamedDistribution : Distribution> + +/** + * A multivariate distribution that has independent distributions for separate axis + */ +class FactorizedDistribution(val distributions: Collection>) : NamedDistribution { + + override fun probability(arg: Map): Double { + return distributions.fold(1.0) { acc, distr -> acc * distr.probability(arg) } + } + + override fun sample(generator: RandomGenerator): Chain> { + val chains = distributions.map { it.sample(generator) } + return SimpleChain> { + chains.fold(emptyMap()) { acc, chain -> acc + chain.next() } + } + } +} \ No newline at end of file