Cleanup
This commit is contained in:
parent
1fbe12149d
commit
57781678e5
@ -36,7 +36,7 @@ public class CMOptimizationProblem(
|
|||||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun expression(expression: Expression<Double>): Unit {
|
public override fun expression(expression: Expression<Double>): Unit {
|
||||||
val objectiveFunction = ObjectiveFunction {
|
val objectiveFunction = ObjectiveFunction {
|
||||||
val args = it.toMap()
|
val args = it.toMap()
|
||||||
expression(args)
|
expression(args)
|
||||||
@ -44,7 +44,7 @@ public class CMOptimizationProblem(
|
|||||||
addOptimizationData(objectiveFunction)
|
addOptimizationData(objectiveFunction)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun derivatives(expression: DifferentiableExpression<Double>): Unit {
|
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
|
||||||
expression(expression)
|
expression(expression)
|
||||||
val gradientFunction = ObjectiveFunctionGradient {
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
val args = it.toMap()
|
val args = it.toMap()
|
||||||
@ -83,6 +83,10 @@ public class CMOptimizationProblem(
|
|||||||
optimizatorBuilder = block
|
optimizatorBuilder = block
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun update(result: OptimizationResult<Double>) {
|
||||||
|
initialGuess(result.point)
|
||||||
|
}
|
||||||
|
|
||||||
override fun optimize(): OptimizationResult<Double> {
|
override fun optimize(): OptimizationResult<Double> {
|
||||||
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
||||||
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
||||||
|
@ -1,17 +1,32 @@
|
|||||||
package kscience.kmath.commons.optimization
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.expressions.Symbol
|
import kscience.kmath.expressions.Symbol
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
public typealias ParameterSpacePoint<T> = Map<Symbol, T>
|
public interface OptimizationResultFeature
|
||||||
|
|
||||||
|
|
||||||
public class OptimizationResult<T>(
|
public class OptimizationResult<T>(
|
||||||
public val point: ParameterSpacePoint<T>,
|
public val point: Map<Symbol, T>,
|
||||||
public val value: T,
|
public val value: T,
|
||||||
public val extra: Map<KClass<*>, Any> = emptyMap()
|
public val features: Set<OptimizationResultFeature> = emptySet(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A configuration builder for optimization problem
|
||||||
|
*/
|
||||||
public interface OptimizationProblem<T : Any> {
|
public interface OptimizationProblem<T : Any> {
|
||||||
|
/**
|
||||||
|
* Set an objective function expression
|
||||||
|
*/
|
||||||
|
public fun expression(expression: Expression<Double>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
|
||||||
|
public fun update(result: OptimizationResult<T>)
|
||||||
public fun optimize(): OptimizationResult<T>
|
public fun optimize(): OptimizationResult<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ public fun Expression<Double>.optimize(
|
|||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> {
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
|
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
||||||
problem.expression(this)
|
problem.expression(this)
|
||||||
return problem.optimize()
|
return problem.optimize()
|
||||||
}
|
}
|
||||||
@ -26,7 +26,7 @@ public fun DifferentiableExpression<Double>.optimize(
|
|||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> {
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
|
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
||||||
problem.derivatives(this)
|
problem.diffExpression(this)
|
||||||
return problem.optimize()
|
return problem.optimize()
|
||||||
}
|
}
|
@ -9,16 +9,14 @@ internal class OptimizeTest {
|
|||||||
val y by symbol
|
val y by symbol
|
||||||
|
|
||||||
val normal = DerivativeStructureExpression {
|
val normal = DerivativeStructureExpression {
|
||||||
val x = bind(x)
|
exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2)
|
||||||
val y = bind(y)
|
|
||||||
exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testOptimization() {
|
fun testOptimization() {
|
||||||
val result = normal.optimize(x, y) {
|
val result = normal.optimize(x, y) {
|
||||||
initialGuess(x to 1.0, y to 1.0)
|
initialGuess(x to 1.0, y to 1.0)
|
||||||
//no need to select optimizer. Gradient optimizer is used by default
|
//no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function
|
||||||
}
|
}
|
||||||
println(result.point)
|
println(result.point)
|
||||||
println(result.value)
|
println(result.value)
|
||||||
|
Loading…
Reference in New Issue
Block a user