[WIP] optimization refactoring

This commit is contained in:
Alexander Nozik 2021-05-25 21:11:57 +03:00
parent f84f7f8e45
commit c8621ee5b7
5 changed files with 56 additions and 26 deletions

View File

@ -35,7 +35,7 @@ public class CMOptimizerData(public val data: List<OptimizationData>) : Optimiza
@OptIn(UnstableKMathAPI::class)
public class CMOptimization : Optimizer<FunctionOptimization<Double>> {
override suspend fun process(
override suspend fun optimize(
problem: FunctionOptimization<Double>,
): FunctionOptimization<Double> {
val startPoint = problem.getFeature<OptimizationStartPoint<Double>>()?.point

View File

@ -67,12 +67,12 @@ public fun <T> FunctionOptimization<T>.withFeatures(
/**
* Optimize differentiable expression using specific [optimizer] form given [startingPoint]
*/
public suspend fun <T : Any> DifferentiableExpression<T, Expression<T>>.optimizeWith(
public suspend fun <T : Any> DifferentiableExpression<T>.optimizeWith(
optimizer: Optimizer<FunctionOptimization<T>>,
startingPoint: Map<Symbol, T>,
vararg features: OptimizationFeature,
): FunctionOptimization<T> {
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
return optimizer.process(problem)
return optimizer.optimize(problem)
}

View File

@ -40,7 +40,3 @@ public class OptimizationParameters(public val symbols: List<Symbol>): Optimizat
}
public interface Optimizer<P : OptimizationProblem> {
public suspend fun process(problem: P): P
}

View File

@ -0,0 +1,10 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.optimization
public interface Optimizer<P : OptimizationProblem> {
public suspend fun optimize(problem: P): P
}

View File

@ -12,39 +12,63 @@ import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.FeatureSet
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Field
import space.kscience.kmath.operations.invoke
import kotlin.math.pow
public fun interface PointToCurveDistance<T> {
public fun distance(problem: XYOptimization<T>, index: Int): DifferentiableExpression<T, Expression<T>>
public interface PointToCurveDistance : OptimizationFeature {
public fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double>
public companion object {
public fun <T> byY(
algebra: Field<T>,
): PointToCurveDistance<T> = PointToCurveDistance { problem, index ->
algebra {
public val byY: PointToCurveDistance = object : PointToCurveDistance {
override fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double> {
val x = problem.data.x[index]
val y = problem.data.y[index]
val model = problem.model(args + (Symbol.x to x))
model - y
return object : DifferentiableExpression<Double> {
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? =
problem.model.derivativeOrNull(symbols)
override fun invoke(arguments: Map<Symbol, Double>): Double =
problem.model(arguments + (Symbol.x to x)) - y
}
}
override fun toString(): String = "PointToCurveDistanceByY"
}
// val default = PointToCurveDistance<Double>{args, data, index ->
//
// }
}
}
public class XYOptimization<T>(
public class XYOptimization(
override val features: FeatureSet<OptimizationFeature>,
public val data: XYColumnarData<T, T, T>,
public val model: DifferentiableExpression<T, Expression<T>>,
public val data: XYColumnarData<Double, Double, Double>,
public val model: DifferentiableExpression<Double>,
) : OptimizationProblem
public suspend fun Optimizer<FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYOptimization): XYOptimization {
val distanceBuilder = problem.getFeature() ?: PointToCurveDistance.byY
val likelihood: DifferentiableExpression<Double> = object : DifferentiableExpression<Double> {
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? {
TODO("Not yet implemented")
}
override fun invoke(arguments: Map<Symbol, Double>): Double {
var res = 0.0
for (index in 0 until problem.data.size) {
val d = distanceBuilder.distance(problem, index).invoke(arguments)
val sigma: Double = TODO()
res -= (d / sigma).pow(2)
}
return res
}
}
val functionOptimization = FunctionOptimization(problem.features, likelihood)
val result = optimize(functionOptimization)
return XYOptimization(result.features, problem.data, problem.model)
}
//
//@UnstableKMathAPI
//public interface XYFit<T> : OptimizationProblem {