Cleanup
parent 1fbe12149d
commit 57781678e5
@@ -36,7 +36,7 @@ public class CMOptimizationProblem(
         addOptimizationData(InitialGuess(map.toDoubleArray()))
     }

-    public fun expression(expression: Expression<Double>): Unit {
+    public override fun expression(expression: Expression<Double>): Unit {
         val objectiveFunction = ObjectiveFunction {
             val args = it.toMap()
             expression(args)
@@ -44,7 +44,7 @@ public class CMOptimizationProblem(
         addOptimizationData(objectiveFunction)
     }

-    public fun derivatives(expression: DifferentiableExpression<Double>): Unit {
+    public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
         expression(expression)
         val gradientFunction = ObjectiveFunctionGradient {
             val args = it.toMap()
@@ -83,6 +83,10 @@ public class CMOptimizationProblem(
         optimizatorBuilder = block
     }

+    override fun update(result: OptimizationResult<Double>) {
+        initialGuess(result.point)
+    }
+
     override fun optimize(): OptimizationResult<Double> {
         val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
         val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
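For orientation, a minimal usage sketch of the update/optimize cycle introduced above. The helper refine() and the import paths are assumptions for illustration, not part of this commit; it relies only on update() storing result.point as the new initial guess, as the added override shows.

import kscience.kmath.commons.optimization.CMOptimizationProblem
import kscience.kmath.commons.optimization.OptimizationResult

// Illustrative helper, not part of this commit. Assumes the problem was already
// configured with symbols, an objective and an initial guess, as in the hunks above.
// update() feeds the first optimum back in, so the second pass starts from that point.
fun refine(problem: CMOptimizationProblem): OptimizationResult<Double> {
    val first = problem.optimize()
    problem.update(first) // internally calls initialGuess(first.point)
    return problem.optimize()
}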
@@ -1,17 +1,32 @@
 package kscience.kmath.commons.optimization

 import kscience.kmath.expressions.DifferentiableExpression
 import kscience.kmath.expressions.Expression
 import kscience.kmath.expressions.Symbol
 import kotlin.reflect.KClass

-public typealias ParameterSpacePoint<T> = Map<Symbol, T>
+public interface OptimizationResultFeature

 public class OptimizationResult<T>(
-    public val point: ParameterSpacePoint<T>,
+    public val point: Map<Symbol, T>,
     public val value: T,
-    public val extra: Map<KClass<*>, Any> = emptyMap()
+    public val features: Set<OptimizationResultFeature> = emptySet(),
 )
+
+/**
+ * A configuration builder for optimization problem
+ */
+public interface OptimizationProblem<T : Any> {
+    /**
+     * Set an objective function expression
+     */
+    public fun expression(expression: Expression<Double>): Unit
+
+    /**
+     *
+     */
+    public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
+    public fun update(result: OptimizationResult<T>)
+    public fun optimize(): OptimizationResult<T>
+}
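A short sketch of how the new features set can carry the metadata the old extra map held. Only OptimizationResult and OptimizationResultFeature come from the code above; IterationCount and describe() are hypothetical names introduced here for illustration.

import kscience.kmath.commons.optimization.OptimizationResult
import kscience.kmath.commons.optimization.OptimizationResultFeature

// Hypothetical feature type: any class implementing OptimizationResultFeature
// can be attached to a result through the new features set.
data class IterationCount(val iterations: Int) : OptimizationResultFeature

// Features are looked up by type, much like the old Map<KClass<*>, Any>,
// but the feature value itself is statically typed.
fun describe(result: OptimizationResult<Double>): String {
    val iterations = result.features.filterIsInstance<IterationCount>().firstOrNull()?.iterations
    return "value=${result.value} at ${result.point}" +
        (if (iterations != null) ", after $iterations iterations" else "")
}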
@@ -13,7 +13,7 @@ public fun Expression<Double>.optimize(
     configuration: CMOptimizationProblem.() -> Unit,
 ): OptimizationResult<Double> {
     require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
-    val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
+    val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
     problem.expression(this)
     return problem.optimize()
 }
@@ -26,7 +26,7 @@ public fun DifferentiableExpression<Double>.optimize(
     configuration: CMOptimizationProblem.() -> Unit,
 ): OptimizationResult<Double> {
     require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
-    val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
-    problem.derivatives(this)
+    val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
+    problem.diffExpression(this)
     return problem.optimize()
 }
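A hedged usage sketch of the Expression<Double>.optimize entry point touched above, complementing the differentiable case exercised by the test below. fitWithoutGradients, the starting point, and the import path of the extension are illustrative assumptions; initialGuess is the configuration call used in the test.

import kscience.kmath.commons.optimization.OptimizationResult
import kscience.kmath.commons.optimization.optimize
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol

// Sketch only: optimize an arbitrary Expression<Double> over two symbols.
// The configuration block is applied once before the problem is solved.
fun fitWithoutGradients(objective: Expression<Double>, x: Symbol, y: Symbol): OptimizationResult<Double> =
    objective.optimize(x, y) {
        initialGuess(x to 0.0, y to 0.0) // illustrative starting point
    }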
@@ -9,16 +9,14 @@ internal class OptimizeTest {
     val y by symbol

     val normal = DerivativeStructureExpression {
-        val x = bind(x)
-        val y = bind(y)
-        exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2)
+        exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2)
     }

     @Test
     fun testOptimization() {
         val result = normal.optimize(x, y) {
             initialGuess(x to 1.0, y to 1.0)
-            //no need to select optimizer. Gradient optimizer is used by default
+            //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function
         }
         println(result.point)
         println(result.value)
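The objective exp(-x^2 / 2) + exp(-y^2 / 2) peaks at the origin with value 2, so the printlns above could in principle be tightened into assertions. The helper below is a sketch under the assumption that the default goal here is maximization (the hunk itself does not state the goal type); checkOptimum and its tolerance are illustrative.

import kotlin.math.abs
import kotlin.test.assertTrue
import kscience.kmath.commons.optimization.OptimizationResult

// Illustrative check, not part of this commit: verify the known optimum of
// exp(-x^2 / 2) + exp(-y^2 / 2), i.e. value 2 at the origin, within a loose tolerance.
fun checkOptimum(result: OptimizationResult<Double>, tolerance: Double = 1e-3) {
    assertTrue(abs(result.value - 2.0) < tolerance)
    result.point.values.forEach { coordinate -> assertTrue(abs(coordinate) < tolerance) }
}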