forked from kscience/kmath
Fix orders in DerivativeStructures
This commit is contained in:
parent
30132964dd
commit
4450c0fcc7
@ -38,13 +38,13 @@ public class DerivativeStructureField(
|
||||
key.identity to DerivativeStructureSymbol(key, value)
|
||||
}
|
||||
|
||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value)
|
||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
|
||||
|
||||
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||
|
||||
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||
|
||||
public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||
//public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||
|
||||
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
||||
return derivative(mapOf(parameter to order))
|
||||
|
@ -17,9 +17,9 @@ public object CMFit {
|
||||
|
||||
/**
|
||||
* Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives
|
||||
* TODO move to core/separate module
|
||||
* TODO move to prob/stat
|
||||
*/
|
||||
public fun chiSquaredExpression(
|
||||
public fun chiSquared(
|
||||
x: Buffer<Double>,
|
||||
y: Buffer<Double>,
|
||||
yErr: Buffer<Double>,
|
||||
@ -35,7 +35,7 @@ public object CMFit {
|
||||
val yErrValue = yErr[it]
|
||||
val modifiedArgs = arguments + (xSymbol to xValue)
|
||||
val modelValue = model(modifiedArgs)
|
||||
((yValue - modelValue) / yErrValue).pow(2) / 2
|
||||
((yValue - modelValue) / yErrValue).pow(2)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -43,7 +43,7 @@ public object CMFit {
|
||||
/**
|
||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||
*/
|
||||
public fun chiSquaredExpression(
|
||||
public fun chiSquared(
|
||||
x: Buffer<Double>,
|
||||
y: Buffer<Double>,
|
||||
yErr: Buffer<Double>,
|
||||
@ -58,7 +58,7 @@ public object CMFit {
|
||||
val yValue = y[it]
|
||||
val yErrValue = yErr[it]
|
||||
val modelValue = model(const(xValue))
|
||||
sum += ((yValue - modelValue) / yErrValue).pow(2) / 2
|
||||
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||
}
|
||||
sum
|
||||
}
|
||||
@ -92,12 +92,13 @@ public fun DifferentiableExpression<Double>.optimize(
|
||||
}
|
||||
|
||||
public fun DifferentiableExpression<Double>.minimize(
|
||||
vararg symbols: Symbol,
|
||||
configuration: CMOptimizationProblem.() -> Unit,
|
||||
vararg startPoint: Pair<Symbol, Double>,
|
||||
configuration: CMOptimizationProblem.() -> Unit = {},
|
||||
): OptimizationResult<Double> {
|
||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
||||
require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||
val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration)
|
||||
problem.diffExpression(this)
|
||||
problem.initialGuess(startPoint.toMap())
|
||||
problem.goal(GoalType.MINIMIZE)
|
||||
return problem.optimize()
|
||||
}
|
@ -17,14 +17,13 @@ public operator fun PointValuePair.component2(): Double = value
|
||||
|
||||
public class CMOptimizationProblem(
|
||||
override val symbols: List<Symbol>,
|
||||
) : OptimizationProblem<Double>, SymbolIndexer {
|
||||
protected val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
||||
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||
|
||||
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||
DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
|
||||
|
||||
private fun addOptimizationData(data: OptimizationData) {
|
||||
public fun addOptimizationData(data: OptimizationData) {
|
||||
optimizationData[data::class] = data
|
||||
}
|
||||
|
||||
@ -32,6 +31,8 @@ public class CMOptimizationProblem(
|
||||
addOptimizationData(MaxEval.unlimited())
|
||||
}
|
||||
|
||||
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
||||
|
||||
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||
}
|
||||
@ -90,7 +91,7 @@ public class CMOptimizationProblem(
|
||||
override fun optimize(): OptimizationResult<Double> {
|
||||
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
||||
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
||||
return OptimizationResult(point.toMap(), value)
|
||||
return OptimizationResult(point.toMap(), value, setOf(this))
|
||||
}
|
||||
|
||||
public companion object {
|
||||
|
@ -4,14 +4,19 @@ import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
|
||||
public interface OptimizationResultFeature
|
||||
public interface OptimizationFeature
|
||||
|
||||
//TODO move to prob/stat
|
||||
|
||||
public class OptimizationResult<T>(
|
||||
public val point: Map<Symbol, T>,
|
||||
public val value: T,
|
||||
public val features: Set<OptimizationResultFeature> = emptySet(),
|
||||
)
|
||||
public val features: Set<OptimizationFeature> = emptySet(),
|
||||
){
|
||||
override fun toString(): String {
|
||||
return "OptimizationResult(point=$point, value=$value)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A configuration builder for optimization problem
|
||||
|
@ -2,8 +2,9 @@ package kscience.kmath.commons.random
|
||||
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
|
||||
public class CMRandomGeneratorWrapper(public val factory: (IntArray) -> RandomGenerator) :
|
||||
org.apache.commons.math3.random.RandomGenerator {
|
||||
public class CMRandomGeneratorWrapper(
|
||||
public val factory: (IntArray) -> RandomGenerator,
|
||||
) : org.apache.commons.math3.random.RandomGenerator {
|
||||
private var generator: RandomGenerator = factory(intArrayOf())
|
||||
|
||||
public override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||
|
@ -2,14 +2,19 @@ package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
||||
import kscience.kmath.expressions.symbol
|
||||
import kscience.kmath.prob.Distribution
|
||||
import kscience.kmath.prob.RandomGenerator
|
||||
import kscience.kmath.prob.normal
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import org.junit.jupiter.api.Test
|
||||
import kotlin.math.pow
|
||||
|
||||
internal class OptimizeTest {
|
||||
val x by symbol
|
||||
val y by symbol
|
||||
|
||||
val normal = DerivativeStructureExpression {
|
||||
exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2)
|
||||
exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -32,4 +37,29 @@ internal class OptimizeTest {
|
||||
println(result.point)
|
||||
println(result.value)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFit() {
|
||||
val a by symbol
|
||||
val b by symbol
|
||||
val c by symbol
|
||||
|
||||
val sigma = 1.0
|
||||
val generator = Distribution.normal(0.0, sigma)
|
||||
val chain = generator.sample(RandomGenerator.default(1126))
|
||||
val x = (1..100).map { it.toDouble() }
|
||||
val y = x.map { it ->
|
||||
it.pow(2) + it + 1 + chain.nextDouble()
|
||||
}
|
||||
val yErr = x.map { sigma }
|
||||
with(CMFit) {
|
||||
val chi2 = chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
||||
bind(a) * x.pow(2) + bind(b) * x + bind(c)
|
||||
}
|
||||
|
||||
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
||||
println(result)
|
||||
println("Chi2/dof = ${result.value / (x.size - 3)}")
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user