[WIP] optimization refactoring
This commit is contained in:
parent
f84f7f8e45
commit
c8621ee5b7
@ -35,7 +35,7 @@ public class CMOptimizerData(public val data: List<OptimizationData>) : Optimiza
|
|||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public class CMOptimization : Optimizer<FunctionOptimization<Double>> {
|
public class CMOptimization : Optimizer<FunctionOptimization<Double>> {
|
||||||
|
|
||||||
override suspend fun process(
|
override suspend fun optimize(
|
||||||
problem: FunctionOptimization<Double>,
|
problem: FunctionOptimization<Double>,
|
||||||
): FunctionOptimization<Double> {
|
): FunctionOptimization<Double> {
|
||||||
val startPoint = problem.getFeature<OptimizationStartPoint<Double>>()?.point
|
val startPoint = problem.getFeature<OptimizationStartPoint<Double>>()?.point
|
||||||
|
@ -67,12 +67,12 @@ public fun <T> FunctionOptimization<T>.withFeatures(
|
|||||||
/**
|
/**
|
||||||
* Optimize differentiable expression using specific [optimizer] form given [startingPoint]
|
* Optimize differentiable expression using specific [optimizer] form given [startingPoint]
|
||||||
*/
|
*/
|
||||||
public suspend fun <T : Any> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
public suspend fun <T : Any> DifferentiableExpression<T>.optimizeWith(
|
||||||
optimizer: Optimizer<FunctionOptimization<T>>,
|
optimizer: Optimizer<FunctionOptimization<T>>,
|
||||||
startingPoint: Map<Symbol, T>,
|
startingPoint: Map<Symbol, T>,
|
||||||
vararg features: OptimizationFeature,
|
vararg features: OptimizationFeature,
|
||||||
): FunctionOptimization<T> {
|
): FunctionOptimization<T> {
|
||||||
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
|
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
|
||||||
return optimizer.process(problem)
|
return optimizer.optimize(problem)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,7 +40,3 @@ public class OptimizationParameters(public val symbols: List<Symbol>): Optimizat
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public interface Optimizer<P : OptimizationProblem> {
|
|
||||||
public suspend fun process(problem: P): P
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -0,0 +1,10 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
|
public interface Optimizer<P : OptimizationProblem> {
|
||||||
|
public suspend fun optimize(problem: P): P
|
||||||
|
}
|
@ -12,39 +12,63 @@ import space.kscience.kmath.expressions.Expression
|
|||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.misc.FeatureSet
|
import space.kscience.kmath.misc.FeatureSet
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Field
|
import kotlin.math.pow
|
||||||
import space.kscience.kmath.operations.invoke
|
|
||||||
|
|
||||||
public fun interface PointToCurveDistance<T> {
|
public interface PointToCurveDistance : OptimizationFeature {
|
||||||
public fun distance(problem: XYOptimization<T>, index: Int): DifferentiableExpression<T, Expression<T>>
|
public fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double>
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun <T> byY(
|
public val byY: PointToCurveDistance = object : PointToCurveDistance {
|
||||||
algebra: Field<T>,
|
override fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double> {
|
||||||
): PointToCurveDistance<T> = PointToCurveDistance { problem, index ->
|
|
||||||
algebra {
|
|
||||||
val x = problem.data.x[index]
|
val x = problem.data.x[index]
|
||||||
val y = problem.data.y[index]
|
val y = problem.data.y[index]
|
||||||
|
return object : DifferentiableExpression<Double> {
|
||||||
val model = problem.model(args + (Symbol.x to x))
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? =
|
||||||
model - y
|
problem.model.derivativeOrNull(symbols)
|
||||||
|
|
||||||
|
override fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
|
problem.model(arguments + (Symbol.x to x)) - y
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun toString(): String = "PointToCurveDistanceByY"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// val default = PointToCurveDistance<Double>{args, data, index ->
|
|
||||||
//
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class XYOptimization(
|
||||||
public class XYOptimization<T>(
|
|
||||||
override val features: FeatureSet<OptimizationFeature>,
|
override val features: FeatureSet<OptimizationFeature>,
|
||||||
public val data: XYColumnarData<T, T, T>,
|
public val data: XYColumnarData<Double, Double, Double>,
|
||||||
public val model: DifferentiableExpression<T, Expression<T>>,
|
public val model: DifferentiableExpression<Double>,
|
||||||
) : OptimizationProblem
|
) : OptimizationProblem
|
||||||
|
|
||||||
|
|
||||||
|
public suspend fun Optimizer<FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYOptimization): XYOptimization {
|
||||||
|
val distanceBuilder = problem.getFeature() ?: PointToCurveDistance.byY
|
||||||
|
val likelihood: DifferentiableExpression<Double> = object : DifferentiableExpression<Double> {
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun invoke(arguments: Map<Symbol, Double>): Double {
|
||||||
|
var res = 0.0
|
||||||
|
for (index in 0 until problem.data.size) {
|
||||||
|
val d = distanceBuilder.distance(problem, index).invoke(arguments)
|
||||||
|
val sigma: Double = TODO()
|
||||||
|
res -= (d / sigma).pow(2)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
val functionOptimization = FunctionOptimization(problem.features, likelihood)
|
||||||
|
val result = optimize(functionOptimization)
|
||||||
|
return XYOptimization(result.features, problem.data, problem.model)
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
//@UnstableKMathAPI
|
//@UnstableKMathAPI
|
||||||
//public interface XYFit<T> : OptimizationProblem {
|
//public interface XYFit<T> : OptimizationProblem {
|
||||||
|
Loading…
Reference in New Issue
Block a user