Feature/diff api #154

Merged
altavir merged 20 commits from feature/diff-api into dev 2020-10-28 13:25:24 +03:00
4 changed files with 30 additions and 13 deletions
Showing only changes of commit 57781678e5 - Show all commits

View File

@ -36,7 +36,7 @@ public class CMOptimizationProblem(
addOptimizationData(InitialGuess(map.toDoubleArray()))
}
public fun expression(expression: Expression<Double>): Unit {
public override fun expression(expression: Expression<Double>): Unit {
val objectiveFunction = ObjectiveFunction {
val args = it.toMap()
expression(args)
@ -44,7 +44,7 @@ public class CMOptimizationProblem(
addOptimizationData(objectiveFunction)
}
public fun derivatives(expression: DifferentiableExpression<Double>): Unit {
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
expression(expression)
val gradientFunction = ObjectiveFunctionGradient {
val args = it.toMap()
@ -83,6 +83,10 @@ public class CMOptimizationProblem(
optimizatorBuilder = block
}
override fun update(result: OptimizationResult<Double>) {
initialGuess(result.point)
}
override fun optimize(): OptimizationResult<Double> {
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())

View File

@ -1,17 +1,32 @@
package kscience.kmath.commons.optimization
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol
import kotlin.reflect.KClass
public typealias ParameterSpacePoint<T> = Map<Symbol, T>
public interface OptimizationResultFeature
public class OptimizationResult<T>(
public val point: ParameterSpacePoint<T>,
public val point: Map<Symbol, T>,
public val value: T,
public val extra: Map<KClass<*>, Any> = emptyMap()
public val features: Set<OptimizationResultFeature> = emptySet(),
)
/**
* A configuration builder for optimization problem
*/
public interface OptimizationProblem<T : Any> {
/**
* Set an objective function expression
*/
public fun expression(expression: Expression<Double>): Unit
/**
*
*/
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
public fun update(result: OptimizationResult<T>)
public fun optimize(): OptimizationResult<T>
}

View File

@ -13,7 +13,7 @@ public fun Expression<Double>.optimize(
configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.expression(this)
return problem.optimize()
}
@ -26,7 +26,7 @@ public fun DifferentiableExpression<Double>.optimize(
configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
problem.derivatives(this)
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.diffExpression(this)
return problem.optimize()
}

View File

@ -9,16 +9,14 @@ internal class OptimizeTest {
val y by symbol
val normal = DerivativeStructureExpression {
val x = bind(x)
val y = bind(y)
exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2)
exp(-bind(x).pow(2) / 2) + exp(- bind(y).pow(2) / 2)
}
@Test
fun testOptimization() {
val result = normal.optimize(x, y) {
initialGuess(x to 1.0, y to 1.0)
//no need to select optimizer. Gradient optimizer is used by default
//no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function
}
println(result.point)
println(result.value)