forked from kscience/kmath
Generification of autodiff and chi2
This commit is contained in:
parent
9a147d033e
commit
1c1580c8e6
@ -7,6 +7,6 @@ dependencies {
|
|||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
api(project(":kmath-prob"))
|
api(project(":kmath-prob"))
|
||||||
// api(project(":kmath-functions"))
|
api(project(":kmath-functions"))
|
||||||
api("org.apache.commons:commons-math3:3.6.1")
|
api("org.apache.commons:commons-math3:3.6.1")
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ internal class OptimizeTest {
|
|||||||
|
|
||||||
val sigma = 1.0
|
val sigma = 1.0
|
||||||
val generator = Distribution.normal(0.0, sigma)
|
val generator = Distribution.normal(0.0, sigma)
|
||||||
val chain = generator.sample(RandomGenerator.default(1126))
|
val chain = generator.sample(RandomGenerator.default(112667))
|
||||||
val x = (1..100).map { it.toDouble() }
|
val x = (1..100).map { it.toDouble() }
|
||||||
val y = x.map { it ->
|
val y = x.map { it ->
|
||||||
it.pow(2) + it + 1 + chain.nextDouble()
|
it.pow(2) + it + 1 + chain.nextDouble()
|
||||||
@ -54,7 +54,8 @@ internal class OptimizeTest {
|
|||||||
val yErr = x.map { sigma }
|
val yErr = x.map { sigma }
|
||||||
with(CMFit) {
|
with(CMFit) {
|
||||||
val chi2 = chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
val chi2 = chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
||||||
bind(a) * x.pow(2) + bind(b) * x + bind(c)
|
val cWithDefault = bindOrNull(c)?: one
|
||||||
|
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
||||||
|
@ -8,13 +8,13 @@ import kotlin.contracts.contract
|
|||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
// TODO make `inline`, when KT-41771 gets fixed
|
|
||||||
/**
|
/**
|
||||||
* Polynomial coefficients without fixation on specific context they are applied to
|
* Polynomial coefficients without fixation on specific context they are applied to
|
||||||
* @param coefficients constant is the leftmost coefficient
|
* @param coefficients constant is the leftmost coefficient
|
||||||
*/
|
*/
|
||||||
public inline class Polynomial<T : Any>(public val coefficients: List<T>)
|
public inline class Polynomial<T : Any>(public val coefficients: List<T>)
|
||||||
|
|
||||||
|
@Suppress("FunctionName")
|
||||||
public fun <T : Any> Polynomial(vararg coefficients: T): Polynomial<T> = Polynomial(coefficients.toList())
|
public fun <T : Any> Polynomial(vararg coefficients: T): Polynomial<T> = Polynomial(coefficients.toList())
|
||||||
|
|
||||||
public fun Polynomial<Double>.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) }
|
public fun Polynomial<Double>.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) }
|
||||||
@ -33,14 +33,6 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring
|
|||||||
res
|
res
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represent a polynomial as a context-dependent function
|
|
||||||
*/
|
|
||||||
public fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, C, T> =
|
|
||||||
object : MathFunction<T, C, T> {
|
|
||||||
override fun C.invoke(arg: T): T = value(this, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent the polynomial as a regular context-less function
|
* Represent the polynomial as a regular context-less function
|
||||||
*/
|
*/
|
||||||
@ -49,7 +41,7 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.asFunction(ring: C): (T) -> T =
|
|||||||
/**
|
/**
|
||||||
* An algebra for polynomials
|
* An algebra for polynomials
|
||||||
*/
|
*/
|
||||||
public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<Polynomial<T>> {
|
public class PolynomialSpace<T : Any, C : Ring<T>>(private val ring: C) : Space<Polynomial<T>> {
|
||||||
public override val zero: Polynomial<T> = Polynomial(emptyList())
|
public override val zero: Polynomial<T> = Polynomial(emptyList())
|
||||||
|
|
||||||
public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
package kscience.kmath.functions
|
|
||||||
|
|
||||||
import kscience.kmath.operations.Algebra
|
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
// TODO make fun interface when KT-41770 is fixed
|
|
||||||
/**
|
|
||||||
* A regular function that could be called only inside specific algebra context
|
|
||||||
* @param T source type
|
|
||||||
* @param C source algebra constraint
|
|
||||||
* @param R result type
|
|
||||||
*/
|
|
||||||
public /*fun*/ interface MathFunction<T, C : Algebra<T>, R> {
|
|
||||||
public operator fun C.invoke(arg: T): R
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <R> MathFunction<Double, RealField, R>.invoke(arg: Double): R = RealField.invoke(arg)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A suspendable function defined in algebraic context
|
|
||||||
*/
|
|
||||||
// TODO make fun interface, when the new JVM IR is enabled
|
|
||||||
public interface SuspendableMathFunction<T, C : Algebra<T>, R> {
|
|
||||||
public suspend operator fun C.invoke(arg: T): R
|
|
||||||
}
|
|
||||||
|
|
||||||
public suspend fun <R> SuspendableMathFunction<Double, RealField, R>.invoke(arg: Double): R = RealField.invoke(arg)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A parametric function with parameter
|
|
||||||
*/
|
|
||||||
public fun interface ParametricFunction<T, P, C : Algebra<T>> {
|
|
||||||
public operator fun C.invoke(arg: T, parameter: P): T
|
|
||||||
}
|
|
36
kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt
Normal file
36
kmath-prob/src/commonMain/kotlin/kscience/kmath/prob/Fit.kt
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package kscience.kmath.prob
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.AutoDiffProcessor
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.ExpressionAlgebra
|
||||||
|
import kscience.kmath.operations.ExtendedField
|
||||||
|
import kscience.kmath.structures.Buffer
|
||||||
|
import kscience.kmath.structures.indices
|
||||||
|
|
||||||
|
public object Fit {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
|
*/
|
||||||
|
public fun <T : Any, I : Any, A> chiSquared(
|
||||||
|
autoDiff: AutoDiffProcessor<T, I, A>,
|
||||||
|
x: Buffer<T>,
|
||||||
|
y: Buffer<T>,
|
||||||
|
yErr: Buffer<T>,
|
||||||
|
model: A.(I) -> I,
|
||||||
|
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
return autoDiff.process {
|
||||||
|
var sum = zero
|
||||||
|
x.indices.forEach {
|
||||||
|
val xValue = const(x[it])
|
||||||
|
val yValue = const(y[it])
|
||||||
|
val yErrValue = const(yErr[it])
|
||||||
|
val modelValue = model(xValue)
|
||||||
|
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||||
|
}
|
||||||
|
sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,91 @@
|
|||||||
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.Expression
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
|
|
||||||
|
public interface OptimizationFeature
|
||||||
|
|
||||||
|
public class OptimizationResult<T>(
|
||||||
|
public val point: Map<Symbol, T>,
|
||||||
|
public val value: T,
|
||||||
|
public val features: Set<OptimizationFeature> = emptySet(),
|
||||||
|
) {
|
||||||
|
override fun toString(): String {
|
||||||
|
return "OptimizationResult(point=$point, value=$value)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun <T> OptimizationResult<T>.plus(
|
||||||
|
feature: OptimizationFeature,
|
||||||
|
): OptimizationResult<T> = OptimizationResult(point, value, features + feature)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A configuration builder for optimization problem
|
||||||
|
*/
|
||||||
|
public interface OptimizationProblem<T : Any> {
|
||||||
|
/**
|
||||||
|
* Define the initial guess for the optimization problem
|
||||||
|
*/
|
||||||
|
public fun initialGuess(map: Map<Symbol, T>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set an objective function expression
|
||||||
|
*/
|
||||||
|
public fun expression(expression: Expression<T>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set a differentiable expression as objective function as function and gradient provider
|
||||||
|
*/
|
||||||
|
public fun diffExpression(expression: DifferentiableExpression<T>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the problem from previous optimization run
|
||||||
|
*/
|
||||||
|
public fun update(result: OptimizationResult<T>)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Make an optimization run
|
||||||
|
*/
|
||||||
|
public fun optimize(): OptimizationResult<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
||||||
|
public fun build(symbols: List<Symbol>): P
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
|
||||||
|
symbols: List<Symbol>,
|
||||||
|
block: P.() -> Unit,
|
||||||
|
): P = build(symbols).apply(block)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
|
||||||
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: F.() -> Unit,
|
||||||
|
): OptimizationResult<T> {
|
||||||
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
|
val problem = factory(symbols.toList(),configuration)
|
||||||
|
problem.expression(this)
|
||||||
|
return problem.optimize()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith(
|
||||||
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: F.() -> Unit,
|
||||||
|
): OptimizationResult<T> {
|
||||||
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
|
val problem = factory(symbols.toList(), configuration)
|
||||||
|
problem.diffExpression(this)
|
||||||
|
return problem.optimize()
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user