Feature/diff api #154

Merged
altavir merged 20 commits from feature/diff-api into dev 2020-10-28 13:25:24 +03:00
4 changed files with 30 additions and 13 deletions
Showing only changes of commit 57781678e5 - Show all commits

View File

@ -36,7 +36,7 @@ public class CMOptimizationProblem(
addOptimizationData(InitialGuess(map.toDoubleArray())) addOptimizationData(InitialGuess(map.toDoubleArray()))
} }
public fun expression(expression: Expression<Double>): Unit { public override fun expression(expression: Expression<Double>): Unit {
val objectiveFunction = ObjectiveFunction { val objectiveFunction = ObjectiveFunction {
val args = it.toMap() val args = it.toMap()
expression(args) expression(args)
@ -44,7 +44,7 @@ public class CMOptimizationProblem(
addOptimizationData(objectiveFunction) addOptimizationData(objectiveFunction)
} }
public fun derivatives(expression: DifferentiableExpression<Double>): Unit { public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
expression(expression) expression(expression)
val gradientFunction = ObjectiveFunctionGradient { val gradientFunction = ObjectiveFunctionGradient {
val args = it.toMap() val args = it.toMap()
@ -83,6 +83,10 @@ public class CMOptimizationProblem(
optimizatorBuilder = block optimizatorBuilder = block
} }
override fun update(result: OptimizationResult<Double>) {
initialGuess(result.point)
}
override fun optimize(): OptimizationResult<Double> { override fun optimize(): OptimizationResult<Double> {
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined") val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray()) val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())

View File

@ -1,17 +1,32 @@
package kscience.kmath.commons.optimization package kscience.kmath.commons.optimization
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol import kscience.kmath.expressions.Symbol
import kotlin.reflect.KClass
public typealias ParameterSpacePoint<T> = Map<Symbol, T> public interface OptimizationResultFeature
public class OptimizationResult<T>( public class OptimizationResult<T>(
public val point: ParameterSpacePoint<T>, public val point: Map<Symbol, T>,
public val value: T, public val value: T,
public val extra: Map<KClass<*>, Any> = emptyMap() public val features: Set<OptimizationResultFeature> = emptySet(),
) )
/**
* A configuration builder for optimization problem
*/
public interface OptimizationProblem<T : Any> { public interface OptimizationProblem<T : Any> {
/**
* Set an objective function expression
*/
public fun expression(expression: Expression<Double>): Unit
/**
*
*/
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
public fun update(result: OptimizationResult<T>)
public fun optimize(): OptimizationResult<T> public fun optimize(): OptimizationResult<T>
} }

View File

@ -13,7 +13,7 @@ public fun Expression<Double>.optimize(
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> { ): OptimizationResult<Double> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.expression(this) problem.expression(this)
return problem.optimize() return problem.optimize()
} }
@ -26,7 +26,7 @@ public fun DifferentiableExpression<Double>.optimize(
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> { ): OptimizationResult<Double> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" } require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration) val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.derivatives(this) problem.diffExpression(this)
return problem.optimize() return problem.optimize()
} }

View File

@ -9,16 +9,14 @@ internal class OptimizeTest {
val y by symbol val y by symbol
val normal = DerivativeStructureExpression { val normal = DerivativeStructureExpression {
val x = bind(x) exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2)
val y = bind(y)
exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2)
} }
@Test @Test
fun testOptimization() { fun testOptimization() {
val result = normal.optimize(x, y) { val result = normal.optimize(x, y) {
initialGuess(x to 1.0, y to 1.0) initialGuess(x to 1.0, y to 1.0)
//no need to select optimizer. Gradient optimizer is used by default //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function
} }
println(result.point) println(result.point)
println(result.value) println(result.value)