Advanced configuration API for cm-optimization
This commit is contained in:
parent
d826dd9e83
commit
1fbe12149d
@ -1 +1,3 @@
|
||||
job("Build") { gradlew("openjdk:11", "build") }
|
||||
job("Build") {
|
||||
gradlew("openjdk:11", "build")
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ public class DerivativeStructureExpression(
|
||||
/**
|
||||
* Get the derivative expression with given orders
|
||||
*/
|
||||
public override fun derivative(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,100 @@
|
||||
package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.expressions.*
|
||||
import org.apache.commons.math3.optim.*
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||
public operator fun PointValuePair.component2(): Double = value
|
||||
|
||||
public class CMOptimizationProblem(
|
||||
override val symbols: List<Symbol>,
|
||||
) : OptimizationProblem<Double>, SymbolIndexer {
|
||||
protected val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||
|
||||
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||
DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
|
||||
|
||||
private fun addOptimizationData(data: OptimizationData) {
|
||||
optimizationData[data::class] = data
|
||||
}
|
||||
|
||||
init {
|
||||
addOptimizationData(MaxEval.unlimited())
|
||||
}
|
||||
|
||||
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||
}
|
||||
|
||||
public fun expression(expression: Expression<Double>): Unit {
|
||||
val objectiveFunction = ObjectiveFunction {
|
||||
val args = it.toMap()
|
||||
expression(args)
|
||||
}
|
||||
addOptimizationData(objectiveFunction)
|
||||
}
|
||||
|
||||
public fun derivatives(expression: DifferentiableExpression<Double>): Unit {
|
||||
expression(expression)
|
||||
val gradientFunction = ObjectiveFunctionGradient {
|
||||
val args = it.toMap()
|
||||
DoubleArray(symbols.size) { index ->
|
||||
expression.derivative(symbols[index])(args)
|
||||
}
|
||||
}
|
||||
addOptimizationData(gradientFunction)
|
||||
if (optimizatorBuilder == null) {
|
||||
optimizatorBuilder = {
|
||||
NonLinearConjugateGradientOptimizer(
|
||||
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
||||
convergenceChecker
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public fun simplex(simplex: AbstractSimplex) {
|
||||
addOptimizationData(simplex)
|
||||
//Set optimization builder to simplex if it is not present
|
||||
if (optimizatorBuilder == null) {
|
||||
optimizatorBuilder = { SimplexOptimizer(convergenceChecker) }
|
||||
}
|
||||
}
|
||||
|
||||
public fun simplexSteps(steps: Map<Symbol, Double>) {
|
||||
simplex(NelderMeadSimplex(steps.toDoubleArray()))
|
||||
}
|
||||
|
||||
public fun goal(goalType: GoalType) {
|
||||
addOptimizationData(goalType)
|
||||
}
|
||||
|
||||
public fun optimizer(block: () -> MultivariateOptimizer) {
|
||||
optimizatorBuilder = block
|
||||
}
|
||||
|
||||
override fun optimize(): OptimizationResult<Double> {
|
||||
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
||||
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
||||
return OptimizationResult(point.toMap(), value)
|
||||
}
|
||||
|
||||
public companion object {
|
||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||
}
|
||||
}
|
||||
|
||||
public fun CMOptimizationProblem.initialGuess(vararg pairs: Pair<Symbol, Double>): Unit = initialGuess(pairs.toMap())
|
||||
public fun CMOptimizationProblem.simplexSteps(vararg pairs: Pair<Symbol, Double>): Unit = simplexSteps(pairs.toMap())
|
@ -0,0 +1,17 @@
|
||||
package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
public typealias ParameterSpacePoint<T> = Map<Symbol, T>
|
||||
|
||||
public class OptimizationResult<T>(
|
||||
public val point: ParameterSpacePoint<T>,
|
||||
public val value: T,
|
||||
public val extra: Map<KClass<*>, Any> = emptyMap()
|
||||
)
|
||||
|
||||
public interface OptimizationProblem<T : Any> {
|
||||
public fun optimize(): OptimizationResult<T>
|
||||
}
|
||||
|
@ -1,103 +1,32 @@
|
||||
package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.expressions.*
|
||||
import org.apache.commons.math3.optim.*
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
|
||||
public typealias ParameterSpacePoint = Map<Symbol, Double>
|
||||
|
||||
public class OptimizationResult(public val point: ParameterSpacePoint, public val value: Double)
|
||||
|
||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||
public operator fun PointValuePair.component2(): Double = value
|
||||
|
||||
public object Optimization {
|
||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||
}
|
||||
|
||||
|
||||
private fun SymbolIndexer.objectiveFunction(expression: Expression<Double>) = ObjectiveFunction {
|
||||
val args = it.toMap()
|
||||
expression(args)
|
||||
}
|
||||
|
||||
private fun SymbolIndexer.objectiveFunctionGradient(
|
||||
expression: DifferentiableExpression<Double>,
|
||||
) = ObjectiveFunctionGradient {
|
||||
val args = it.toMap()
|
||||
DoubleArray(symbols.size) { index ->
|
||||
expression.derivative(symbols[index])(args)
|
||||
}
|
||||
}
|
||||
|
||||
private fun SymbolIndexer.initialGuess(point: ParameterSpacePoint) = InitialGuess(point.toArray())
|
||||
|
||||
/**
|
||||
* Optimize expression without derivatives
|
||||
*/
|
||||
public fun Expression<Double>.optimize(
|
||||
startingPoint: ParameterSpacePoint,
|
||||
goalType: GoalType = GoalType.MAXIMIZE,
|
||||
vararg additionalArguments: OptimizationData,
|
||||
optimizerBuilder: () -> MultivariateOptimizer = {
|
||||
SimplexOptimizer(
|
||||
SimpleValueChecker(
|
||||
Optimization.DEFAULT_RELATIVE_TOLERANCE,
|
||||
Optimization.DEFAULT_ABSOLUTE_TOLERANCE,
|
||||
Optimization.DEFAULT_MAX_ITER
|
||||
)
|
||||
)
|
||||
},
|
||||
): OptimizationResult = withSymbols(startingPoint.keys) {
|
||||
val optimizer = optimizerBuilder()
|
||||
val objectiveFunction = objectiveFunction(this@optimize)
|
||||
val (point, value) = optimizer.optimize(
|
||||
objectiveFunction,
|
||||
initialGuess(startingPoint),
|
||||
goalType,
|
||||
MaxEval.unlimited(),
|
||||
NelderMeadSimplex(symbols.size, 1.0),
|
||||
*additionalArguments
|
||||
)
|
||||
OptimizationResult(point.toMap(), value)
|
||||
vararg symbols: Symbol,
|
||||
configuration: CMOptimizationProblem.() -> Unit,
|
||||
): OptimizationResult<Double> {
|
||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
|
||||
problem.expression(this)
|
||||
return problem.optimize()
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize differentiable expression
|
||||
*/
|
||||
public fun DifferentiableExpression<Double>.optimize(
|
||||
startingPoint: ParameterSpacePoint,
|
||||
goalType: GoalType = GoalType.MAXIMIZE,
|
||||
vararg additionalArguments: OptimizationData,
|
||||
optimizerBuilder: () -> NonLinearConjugateGradientOptimizer = {
|
||||
NonLinearConjugateGradientOptimizer(
|
||||
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
||||
SimpleValueChecker(
|
||||
Optimization.DEFAULT_RELATIVE_TOLERANCE,
|
||||
Optimization.DEFAULT_ABSOLUTE_TOLERANCE,
|
||||
Optimization.DEFAULT_MAX_ITER
|
||||
)
|
||||
)
|
||||
},
|
||||
): OptimizationResult = withSymbols(startingPoint.keys) {
|
||||
val optimizer = optimizerBuilder()
|
||||
val objectiveFunction = objectiveFunction(this@optimize)
|
||||
val objectiveGradient = objectiveFunctionGradient(this@optimize)
|
||||
val (point, value) = optimizer.optimize(
|
||||
objectiveFunction,
|
||||
objectiveGradient,
|
||||
initialGuess(startingPoint),
|
||||
goalType,
|
||||
MaxEval.unlimited(),
|
||||
*additionalArguments
|
||||
)
|
||||
OptimizationResult(point.toMap(), value)
|
||||
vararg symbols: Symbol,
|
||||
configuration: CMOptimizationProblem.() -> Unit,
|
||||
): OptimizationResult<Double> {
|
||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
|
||||
problem.derivatives(this)
|
||||
return problem.optimize()
|
||||
}
|
@ -1,10 +1,7 @@
|
||||
package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.expressions.symbol
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
internal class OptimizeTest {
|
||||
@ -17,19 +14,22 @@ internal class OptimizeTest {
|
||||
exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2)
|
||||
}
|
||||
|
||||
val startingPoint: Map<Symbol, Double> = mapOf(x to 1.0, y to 1.0)
|
||||
|
||||
@Test
|
||||
fun testOptimization() {
|
||||
val result = normal.optimize(startingPoint)
|
||||
val result = normal.optimize(x, y) {
|
||||
initialGuess(x to 1.0, y to 1.0)
|
||||
//no need to select optimizer. Gradient optimizer is used by default
|
||||
}
|
||||
println(result.point)
|
||||
println(result.value)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSimplexOptimization() {
|
||||
val result = (normal as Expression<Double>).optimize(startingPoint){
|
||||
SimplexOptimizer(1e-4,1e-4)
|
||||
val result = normal.optimize(x, y) {
|
||||
initialGuess(x to 1.0, y to 1.0)
|
||||
simplexSteps(x to 2.0, y to 0.5)
|
||||
//this sets simplex optimizer
|
||||
}
|
||||
println(result.point)
|
||||
println(result.value)
|
||||
|
@ -4,9 +4,15 @@ package kscience.kmath.expressions
|
||||
* And object that could be differentiated
|
||||
*/
|
||||
public interface Differentiable<T> {
|
||||
public fun derivative(orders: Map<Symbol, Int>): T
|
||||
public fun derivativeOrNull(orders: Map<Symbol, Int>): T?
|
||||
}
|
||||
|
||||
public fun <T> Differentiable<T>.derivative(orders: Map<Symbol, Int>): T =
|
||||
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
|
||||
|
||||
/**
|
||||
* An expression that provid
|
||||
*/
|
||||
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||
@ -14,8 +20,19 @@ public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||
derivative(StringSymbol(name) to 1)
|
||||
|
||||
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
||||
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
||||
//}
|
||||
|
||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||
|
||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||
|
||||
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? {
|
||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
||||
return derivativeOrNull(dSymbol)
|
||||
}
|
||||
}
|
@ -221,23 +221,16 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||
public val field: F,
|
||||
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
) : DifferentiableExpression<T> {
|
||||
) : FirstDerivativeExpression<T>() {
|
||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
return AutoDiffContext(field, arguments).function().value
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the derivative expression with given orders
|
||||
*/
|
||||
public override fun derivative(orders: Map<Symbol, Int>): Expression<T> {
|
||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }
|
||||
?: error("SimpleAutoDiff supports only first order derivatives")
|
||||
return Expression { arguments ->
|
||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
|
||||
derivationResult.derivative(dSymbol.key)
|
||||
}
|
||||
derivationResult.derivative(symbol)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,12 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.structures.BufferFactory
|
||||
import kscience.kmath.structures.Structure2D
|
||||
|
||||
/**
|
||||
* An environment to easy transform indexed variables to symbols and back.
|
||||
* TODO requires multi-receivers to be beutiful
|
||||
*/
|
||||
public interface SymbolIndexer {
|
||||
public val symbols: List<Symbol>
|
||||
@ -22,15 +27,26 @@ public interface SymbolIndexer {
|
||||
return get(this@SymbolIndexer.indexOf(symbol))
|
||||
}
|
||||
|
||||
public operator fun <T> Point<T>.get(symbol: Symbol): T {
|
||||
require(size == symbols.size) { "The input buffer size for indexer should be ${symbols.size} but $size found" }
|
||||
return get(this@SymbolIndexer.indexOf(symbol))
|
||||
}
|
||||
|
||||
public fun DoubleArray.toMap(): Map<Symbol, Double> {
|
||||
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||
return symbols.indices.associate { symbols[it] to get(it) }
|
||||
}
|
||||
|
||||
public operator fun <T> Structure2D<T>.get(rowSymbol: Symbol, columnSymbol: Symbol): T =
|
||||
get(indexOf(rowSymbol), indexOf(columnSymbol))
|
||||
|
||||
|
||||
public fun <T> Map<Symbol, T>.toList(): List<T> = symbols.map { getValue(it) }
|
||||
|
||||
public fun Map<Symbol, Double>.toArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
|
||||
public fun <T> Map<Symbol, T>.toPoint(bufferFactory: BufferFactory<T>): Point<T> =
|
||||
bufferFactory(symbols.size) { getValue(symbols[it]) }
|
||||
|
||||
public fun Map<Symbol, Double>.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
|
||||
}
|
||||
|
||||
public inline class SimpleSymbolIndexer(override val symbols: List<Symbol>) : SymbolIndexer
|
||||
|
Loading…
Reference in New Issue
Block a user