[WIP] optimization refactoring
This commit is contained in:
parent
f84f7f8e45
commit
c8621ee5b7
@ -35,7 +35,7 @@ public class CMOptimizerData(public val data: List<OptimizationData>) : Optimiza
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class CMOptimization : Optimizer<FunctionOptimization<Double>> {
|
||||
|
||||
override suspend fun process(
|
||||
override suspend fun optimize(
|
||||
problem: FunctionOptimization<Double>,
|
||||
): FunctionOptimization<Double> {
|
||||
val startPoint = problem.getFeature<OptimizationStartPoint<Double>>()?.point
|
||||
|
@ -67,12 +67,12 @@ public fun <T> FunctionOptimization<T>.withFeatures(
|
||||
/**
|
||||
* Optimize differentiable expression using specific [optimizer] form given [startingPoint]
|
||||
*/
|
||||
public suspend fun <T : Any> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
||||
public suspend fun <T : Any> DifferentiableExpression<T>.optimizeWith(
|
||||
optimizer: Optimizer<FunctionOptimization<T>>,
|
||||
startingPoint: Map<Symbol, T>,
|
||||
vararg features: OptimizationFeature,
|
||||
): FunctionOptimization<T> {
|
||||
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
|
||||
return optimizer.process(problem)
|
||||
return optimizer.optimize(problem)
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,3 @@ public class OptimizationParameters(public val symbols: List<Symbol>): Optimizat
|
||||
}
|
||||
|
||||
|
||||
public interface Optimizer<P : OptimizationProblem> {
|
||||
public suspend fun process(problem: P): P
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,10 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.optimization
|
||||
|
||||
public interface Optimizer<P : OptimizationProblem> {
|
||||
public suspend fun optimize(problem: P): P
|
||||
}
|
@ -12,39 +12,63 @@ import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.misc.FeatureSet
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.Field
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.math.pow
|
||||
|
||||
public fun interface PointToCurveDistance<T> {
|
||||
public fun distance(problem: XYOptimization<T>, index: Int): DifferentiableExpression<T, Expression<T>>
|
||||
public interface PointToCurveDistance : OptimizationFeature {
|
||||
public fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double>
|
||||
|
||||
public companion object {
|
||||
public fun <T> byY(
|
||||
algebra: Field<T>,
|
||||
): PointToCurveDistance<T> = PointToCurveDistance { problem, index ->
|
||||
algebra {
|
||||
public val byY: PointToCurveDistance = object : PointToCurveDistance {
|
||||
override fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double> {
|
||||
|
||||
val x = problem.data.x[index]
|
||||
val y = problem.data.y[index]
|
||||
|
||||
val model = problem.model(args + (Symbol.x to x))
|
||||
model - y
|
||||
return object : DifferentiableExpression<Double> {
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? =
|
||||
problem.model.derivativeOrNull(symbols)
|
||||
|
||||
override fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||
problem.model(arguments + (Symbol.x to x)) - y
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String = "PointToCurveDistanceByY"
|
||||
|
||||
}
|
||||
|
||||
|
||||
// val default = PointToCurveDistance<Double>{args, data, index ->
|
||||
//
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public class XYOptimization<T>(
|
||||
public class XYOptimization(
|
||||
override val features: FeatureSet<OptimizationFeature>,
|
||||
public val data: XYColumnarData<T, T, T>,
|
||||
public val model: DifferentiableExpression<T, Expression<T>>,
|
||||
public val data: XYColumnarData<Double, Double, Double>,
|
||||
public val model: DifferentiableExpression<Double>,
|
||||
) : OptimizationProblem
|
||||
|
||||
|
||||
public suspend fun Optimizer<FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYOptimization): XYOptimization {
|
||||
val distanceBuilder = problem.getFeature() ?: PointToCurveDistance.byY
|
||||
val likelihood: DifferentiableExpression<Double> = object : DifferentiableExpression<Double> {
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun invoke(arguments: Map<Symbol, Double>): Double {
|
||||
var res = 0.0
|
||||
for (index in 0 until problem.data.size) {
|
||||
val d = distanceBuilder.distance(problem, index).invoke(arguments)
|
||||
val sigma: Double = TODO()
|
||||
res -= (d / sigma).pow(2)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
}
|
||||
val functionOptimization = FunctionOptimization(problem.features, likelihood)
|
||||
val result = optimize(functionOptimization)
|
||||
return XYOptimization(result.features, problem.data, problem.model)
|
||||
}
|
||||
|
||||
//
|
||||
//@UnstableKMathAPI
|
||||
//public interface XYFit<T> : OptimizationProblem {
|
||||
|
Loading…
Reference in New Issue
Block a user