diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt index f7c136ed2..b5ea59d6b 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt @@ -36,7 +36,7 @@ public class CMOptimizationProblem( addOptimizationData(InitialGuess(map.toDoubleArray())) } - public fun expression(expression: Expression): Unit { + public override fun expression(expression: Expression): Unit { val objectiveFunction = ObjectiveFunction { val args = it.toMap() expression(args) @@ -44,7 +44,7 @@ public class CMOptimizationProblem( addOptimizationData(objectiveFunction) } - public fun derivatives(expression: DifferentiableExpression): Unit { + public override fun diffExpression(expression: DifferentiableExpression): Unit { expression(expression) val gradientFunction = ObjectiveFunctionGradient { val args = it.toMap() @@ -83,6 +83,10 @@ public class CMOptimizationProblem( optimizatorBuilder = block } + override fun update(result: OptimizationResult) { + initialGuess(result.point) + } + override fun optimize(): OptimizationResult { val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined") val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray()) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt index 56291e09c..e52450be1 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/OptimizationProblem.kt @@ -1,17 +1,32 @@ package kscience.kmath.commons.optimization +import kscience.kmath.expressions.DifferentiableExpression +import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Symbol -import kotlin.reflect.KClass -public typealias ParameterSpacePoint = Map +public interface OptimizationResultFeature + public class OptimizationResult( - public val point: ParameterSpacePoint, + public val point: Map, public val value: T, - public val extra: Map, Any> = emptyMap() + public val features: Set = emptySet(), ) +/** + * A configuration builder for optimization problem + */ public interface OptimizationProblem { + /** + * Set an objective function expression + */ + public fun expression(expression: Expression): Unit + + /** + * + */ + public fun diffExpression(expression: DifferentiableExpression): Unit + public fun update(result: OptimizationResult) public fun optimize(): OptimizationResult } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt index a49949b93..c4bd5704e 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/optimize.kt @@ -13,7 +13,7 @@ public fun Expression.optimize( configuration: CMOptimizationProblem.() -> Unit, ): OptimizationResult { require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) problem.expression(this) return problem.optimize() } @@ -26,7 +26,7 @@ public fun DifferentiableExpression.optimize( configuration: CMOptimizationProblem.() -> Unit, ): OptimizationResult { require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } - val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) - problem.derivatives(this) + val problem = CMOptimizationProblem(symbols.toList()).apply(configuration) + problem.diffExpression(this) return problem.optimize() } \ No newline at end of file diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index 65d61dcd1..bd7870573 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -9,16 +9,14 @@ internal class OptimizeTest { val y by symbol val normal = DerivativeStructureExpression { - val x = bind(x) - val y = bind(y) - exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2) + exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2) } @Test fun testOptimization() { val result = normal.optimize(x, y) { initialGuess(x to 1.0, y to 1.0) - //no need to select optimizer. Gradient optimizer is used by default + //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function } println(result.point) println(result.value)