[WIP] optimization refactoring

This commit is contained in:
Alexander Nozik 2021-05-25 21:11:57 +03:00
parent f84f7f8e45
commit c8621ee5b7
5 changed files with 56 additions and 26 deletions

View File

@ -35,7 +35,7 @@ public class CMOptimizerData(public val data: List<OptimizationData>) : Optimiza
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class CMOptimization : Optimizer<FunctionOptimization<Double>> { public class CMOptimization : Optimizer<FunctionOptimization<Double>> {
override suspend fun process( override suspend fun optimize(
problem: FunctionOptimization<Double>, problem: FunctionOptimization<Double>,
): FunctionOptimization<Double> { ): FunctionOptimization<Double> {
val startPoint = problem.getFeature<OptimizationStartPoint<Double>>()?.point val startPoint = problem.getFeature<OptimizationStartPoint<Double>>()?.point

View File

@ -67,12 +67,12 @@ public fun <T> FunctionOptimization<T>.withFeatures(
/** /**
* Optimize differentiable expression using specific [optimizer] form given [startingPoint] * Optimize differentiable expression using specific [optimizer] form given [startingPoint]
*/ */
public suspend fun <T : Any> DifferentiableExpression<T, Expression<T>>.optimizeWith( public suspend fun <T : Any> DifferentiableExpression<T>.optimizeWith(
optimizer: Optimizer<FunctionOptimization<T>>, optimizer: Optimizer<FunctionOptimization<T>>,
startingPoint: Map<Symbol, T>, startingPoint: Map<Symbol, T>,
vararg features: OptimizationFeature, vararg features: OptimizationFeature,
): FunctionOptimization<T> { ): FunctionOptimization<T> {
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this) val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
return optimizer.process(problem) return optimizer.optimize(problem)
} }

View File

@ -40,7 +40,3 @@ public class OptimizationParameters(public val symbols: List<Symbol>): Optimizat
} }
public interface Optimizer<P : OptimizationProblem> {
public suspend fun process(problem: P): P
}

View File

@ -0,0 +1,10 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.optimization
public interface Optimizer<P : OptimizationProblem> {
public suspend fun optimize(problem: P): P
}

View File

@ -12,39 +12,63 @@ import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.FeatureSet import space.kscience.kmath.misc.FeatureSet
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Field import kotlin.math.pow
import space.kscience.kmath.operations.invoke
public fun interface PointToCurveDistance<T> { public interface PointToCurveDistance : OptimizationFeature {
public fun distance(problem: XYOptimization<T>, index: Int): DifferentiableExpression<T, Expression<T>> public fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double>
public companion object { public companion object {
public fun <T> byY( public val byY: PointToCurveDistance = object : PointToCurveDistance {
algebra: Field<T>, override fun distance(problem: XYOptimization, index: Int): DifferentiableExpression<Double> {
): PointToCurveDistance<T> = PointToCurveDistance { problem, index ->
algebra {
val x = problem.data.x[index] val x = problem.data.x[index]
val y = problem.data.y[index] val y = problem.data.y[index]
return object : DifferentiableExpression<Double> {
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? =
problem.model.derivativeOrNull(symbols)
val model = problem.model(args + (Symbol.x to x)) override fun invoke(arguments: Map<Symbol, Double>): Double =
model - y problem.model(arguments + (Symbol.x to x)) - y
}
} }
override fun toString(): String = "PointToCurveDistanceByY"
} }
// val default = PointToCurveDistance<Double>{args, data, index ->
//
// }
} }
} }
public class XYOptimization(
public class XYOptimization<T>(
override val features: FeatureSet<OptimizationFeature>, override val features: FeatureSet<OptimizationFeature>,
public val data: XYColumnarData<T, T, T>, public val data: XYColumnarData<Double, Double, Double>,
public val model: DifferentiableExpression<T, Expression<T>>, public val model: DifferentiableExpression<Double>,
) : OptimizationProblem ) : OptimizationProblem
public suspend fun Optimizer<FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYOptimization): XYOptimization {
val distanceBuilder = problem.getFeature() ?: PointToCurveDistance.byY
val likelihood: DifferentiableExpression<Double> = object : DifferentiableExpression<Double> {
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double>? {
TODO("Not yet implemented")
}
override fun invoke(arguments: Map<Symbol, Double>): Double {
var res = 0.0
for (index in 0 until problem.data.size) {
val d = distanceBuilder.distance(problem, index).invoke(arguments)
val sigma: Double = TODO()
res -= (d / sigma).pow(2)
}
return res
}
}
val functionOptimization = FunctionOptimization(problem.features, likelihood)
val result = optimize(functionOptimization)
return XYOptimization(result.features, problem.data, problem.model)
}
// //
//@UnstableKMathAPI //@UnstableKMathAPI
//public interface XYFit<T> : OptimizationProblem { //public interface XYFit<T> : OptimizationProblem {