forked from kscience/kmath
Fix orders in DerivativeStructures
This commit is contained in:
parent
30132964dd
commit
4450c0fcc7
@ -38,13 +38,13 @@ public class DerivativeStructureField(
|
|||||||
key.identity to DerivativeStructureSymbol(key, value)
|
key.identity to DerivativeStructureSymbol(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value)
|
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
|
||||||
|
|
||||||
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||||
|
|
||||||
public fun Number.const(): DerivativeStructure = const(toDouble())
|
//public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||||
|
|
||||||
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
||||||
return derivative(mapOf(parameter to order))
|
return derivative(mapOf(parameter to order))
|
||||||
|
@ -17,9 +17,9 @@ public object CMFit {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives
|
* Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives
|
||||||
* TODO move to core/separate module
|
* TODO move to prob/stat
|
||||||
*/
|
*/
|
||||||
public fun chiSquaredExpression(
|
public fun chiSquared(
|
||||||
x: Buffer<Double>,
|
x: Buffer<Double>,
|
||||||
y: Buffer<Double>,
|
y: Buffer<Double>,
|
||||||
yErr: Buffer<Double>,
|
yErr: Buffer<Double>,
|
||||||
@ -35,7 +35,7 @@ public object CMFit {
|
|||||||
val yErrValue = yErr[it]
|
val yErrValue = yErr[it]
|
||||||
val modifiedArgs = arguments + (xSymbol to xValue)
|
val modifiedArgs = arguments + (xSymbol to xValue)
|
||||||
val modelValue = model(modifiedArgs)
|
val modelValue = model(modifiedArgs)
|
||||||
((yValue - modelValue) / yErrValue).pow(2) / 2
|
((yValue - modelValue) / yErrValue).pow(2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -43,7 +43,7 @@ public object CMFit {
|
|||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
public fun chiSquaredExpression(
|
public fun chiSquared(
|
||||||
x: Buffer<Double>,
|
x: Buffer<Double>,
|
||||||
y: Buffer<Double>,
|
y: Buffer<Double>,
|
||||||
yErr: Buffer<Double>,
|
yErr: Buffer<Double>,
|
||||||
@ -58,7 +58,7 @@ public object CMFit {
|
|||||||
val yValue = y[it]
|
val yValue = y[it]
|
||||||
val yErrValue = yErr[it]
|
val yErrValue = yErr[it]
|
||||||
val modelValue = model(const(xValue))
|
val modelValue = model(const(xValue))
|
||||||
sum += ((yValue - modelValue) / yErrValue).pow(2) / 2
|
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||||
}
|
}
|
||||||
sum
|
sum
|
||||||
}
|
}
|
||||||
@ -92,12 +92,13 @@ public fun DifferentiableExpression<Double>.optimize(
|
|||||||
}
|
}
|
||||||
|
|
||||||
public fun DifferentiableExpression<Double>.minimize(
|
public fun DifferentiableExpression<Double>.minimize(
|
||||||
vararg symbols: Symbol,
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit = {},
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> {
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration)
|
||||||
problem.diffExpression(this)
|
problem.diffExpression(this)
|
||||||
|
problem.initialGuess(startPoint.toMap())
|
||||||
problem.goal(GoalType.MINIMIZE)
|
problem.goal(GoalType.MINIMIZE)
|
||||||
return problem.optimize()
|
return problem.optimize()
|
||||||
}
|
}
|
@ -17,14 +17,13 @@ public operator fun PointValuePair.component2(): Double = value
|
|||||||
|
|
||||||
public class CMOptimizationProblem(
|
public class CMOptimizationProblem(
|
||||||
override val symbols: List<Symbol>,
|
override val symbols: List<Symbol>,
|
||||||
) : OptimizationProblem<Double>, SymbolIndexer {
|
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
||||||
protected val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||||
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||||
|
|
||||||
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||||
DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
|
DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
|
||||||
|
|
||||||
private fun addOptimizationData(data: OptimizationData) {
|
public fun addOptimizationData(data: OptimizationData) {
|
||||||
optimizationData[data::class] = data
|
optimizationData[data::class] = data
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,6 +31,8 @@ public class CMOptimizationProblem(
|
|||||||
addOptimizationData(MaxEval.unlimited())
|
addOptimizationData(MaxEval.unlimited())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
||||||
|
|
||||||
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||||
}
|
}
|
||||||
@ -90,7 +91,7 @@ public class CMOptimizationProblem(
|
|||||||
override fun optimize(): OptimizationResult<Double> {
|
override fun optimize(): OptimizationResult<Double> {
|
||||||
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
||||||
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
||||||
return OptimizationResult(point.toMap(), value)
|
return OptimizationResult(point.toMap(), value, setOf(this))
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
|
@ -4,14 +4,19 @@ import kscience.kmath.expressions.DifferentiableExpression
|
|||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.expressions.Symbol
|
import kscience.kmath.expressions.Symbol
|
||||||
|
|
||||||
public interface OptimizationResultFeature
|
public interface OptimizationFeature
|
||||||
|
|
||||||
|
//TODO move to prob/stat
|
||||||
|
|
||||||
public class OptimizationResult<T>(
|
public class OptimizationResult<T>(
|
||||||
public val point: Map<Symbol, T>,
|
public val point: Map<Symbol, T>,
|
||||||
public val value: T,
|
public val value: T,
|
||||||
public val features: Set<OptimizationResultFeature> = emptySet(),
|
public val features: Set<OptimizationFeature> = emptySet(),
|
||||||
)
|
){
|
||||||
|
override fun toString(): String {
|
||||||
|
return "OptimizationResult(point=$point, value=$value)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A configuration builder for optimization problem
|
* A configuration builder for optimization problem
|
||||||
|
@ -2,8 +2,9 @@ package kscience.kmath.commons.random
|
|||||||
|
|
||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.prob.RandomGenerator
|
||||||
|
|
||||||
public class CMRandomGeneratorWrapper(public val factory: (IntArray) -> RandomGenerator) :
|
public class CMRandomGeneratorWrapper(
|
||||||
org.apache.commons.math3.random.RandomGenerator {
|
public val factory: (IntArray) -> RandomGenerator,
|
||||||
|
) : org.apache.commons.math3.random.RandomGenerator {
|
||||||
private var generator: RandomGenerator = factory(intArrayOf())
|
private var generator: RandomGenerator = factory(intArrayOf())
|
||||||
|
|
||||||
public override fun nextBoolean(): Boolean = generator.nextBoolean()
|
public override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||||
|
@ -2,14 +2,19 @@ package kscience.kmath.commons.optimization
|
|||||||
|
|
||||||
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
||||||
import kscience.kmath.expressions.symbol
|
import kscience.kmath.expressions.symbol
|
||||||
|
import kscience.kmath.prob.Distribution
|
||||||
|
import kscience.kmath.prob.RandomGenerator
|
||||||
|
import kscience.kmath.prob.normal
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
|
import kotlin.math.pow
|
||||||
|
|
||||||
internal class OptimizeTest {
|
internal class OptimizeTest {
|
||||||
val x by symbol
|
val x by symbol
|
||||||
val y by symbol
|
val y by symbol
|
||||||
|
|
||||||
val normal = DerivativeStructureExpression {
|
val normal = DerivativeStructureExpression {
|
||||||
exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2)
|
exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -32,4 +37,29 @@ internal class OptimizeTest {
|
|||||||
println(result.point)
|
println(result.point)
|
||||||
println(result.value)
|
println(result.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testFit() {
|
||||||
|
val a by symbol
|
||||||
|
val b by symbol
|
||||||
|
val c by symbol
|
||||||
|
|
||||||
|
val sigma = 1.0
|
||||||
|
val generator = Distribution.normal(0.0, sigma)
|
||||||
|
val chain = generator.sample(RandomGenerator.default(1126))
|
||||||
|
val x = (1..100).map { it.toDouble() }
|
||||||
|
val y = x.map { it ->
|
||||||
|
it.pow(2) + it + 1 + chain.nextDouble()
|
||||||
|
}
|
||||||
|
val yErr = x.map { sigma }
|
||||||
|
with(CMFit) {
|
||||||
|
val chi2 = chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
||||||
|
bind(a) * x.pow(2) + bind(b) * x + bind(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
||||||
|
println(result)
|
||||||
|
println("Chi2/dof = ${result.value / (x.size - 3)}")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user