forked from kscience/kmath
Initial optimization implementation for CM
This commit is contained in:
parent
94df61cd43
commit
d826dd9e83
@ -8,6 +8,9 @@
|
||||
- Automatic README generation for features (#139)
|
||||
- Native support for `memory`, `core` and `dimensions`
|
||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
|
||||
- A separate `Symbol` entity, which is used for global unbound symbol.
|
||||
- A `Symbol` indexing scope.
|
||||
- Basic optimization API for Commons-math.
|
||||
|
||||
### Changed
|
||||
- Package changed from `scientifik` to `kscience.kmath`.
|
||||
@ -16,6 +19,7 @@
|
||||
- `Polynomial` secondary constructor made function.
|
||||
- Kotlin version: 1.3.72 -> 1.4.20-M1
|
||||
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
||||
- Full autodiff refactoring based on `Symbol`
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
@ -0,0 +1,103 @@
|
||||
package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.expressions.*
|
||||
import org.apache.commons.math3.optim.*
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||
|
||||
public typealias ParameterSpacePoint = Map<Symbol, Double>
|
||||
|
||||
public class OptimizationResult(public val point: ParameterSpacePoint, public val value: Double)
|
||||
|
||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||
public operator fun PointValuePair.component2(): Double = value
|
||||
|
||||
public object Optimization {
|
||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||
}
|
||||
|
||||
|
||||
private fun SymbolIndexer.objectiveFunction(expression: Expression<Double>) = ObjectiveFunction {
|
||||
val args = it.toMap()
|
||||
expression(args)
|
||||
}
|
||||
|
||||
private fun SymbolIndexer.objectiveFunctionGradient(
|
||||
expression: DifferentiableExpression<Double>,
|
||||
) = ObjectiveFunctionGradient {
|
||||
val args = it.toMap()
|
||||
DoubleArray(symbols.size) { index ->
|
||||
expression.derivative(symbols[index])(args)
|
||||
}
|
||||
}
|
||||
|
||||
private fun SymbolIndexer.initialGuess(point: ParameterSpacePoint) = InitialGuess(point.toArray())
|
||||
|
||||
/**
|
||||
* Optimize expression without derivatives
|
||||
*/
|
||||
public fun Expression<Double>.optimize(
|
||||
startingPoint: ParameterSpacePoint,
|
||||
goalType: GoalType = GoalType.MAXIMIZE,
|
||||
vararg additionalArguments: OptimizationData,
|
||||
optimizerBuilder: () -> MultivariateOptimizer = {
|
||||
SimplexOptimizer(
|
||||
SimpleValueChecker(
|
||||
Optimization.DEFAULT_RELATIVE_TOLERANCE,
|
||||
Optimization.DEFAULT_ABSOLUTE_TOLERANCE,
|
||||
Optimization.DEFAULT_MAX_ITER
|
||||
)
|
||||
)
|
||||
},
|
||||
): OptimizationResult = withSymbols(startingPoint.keys) {
|
||||
val optimizer = optimizerBuilder()
|
||||
val objectiveFunction = objectiveFunction(this@optimize)
|
||||
val (point, value) = optimizer.optimize(
|
||||
objectiveFunction,
|
||||
initialGuess(startingPoint),
|
||||
goalType,
|
||||
MaxEval.unlimited(),
|
||||
NelderMeadSimplex(symbols.size, 1.0),
|
||||
*additionalArguments
|
||||
)
|
||||
OptimizationResult(point.toMap(), value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize differentiable expression
|
||||
*/
|
||||
public fun DifferentiableExpression<Double>.optimize(
|
||||
startingPoint: ParameterSpacePoint,
|
||||
goalType: GoalType = GoalType.MAXIMIZE,
|
||||
vararg additionalArguments: OptimizationData,
|
||||
optimizerBuilder: () -> NonLinearConjugateGradientOptimizer = {
|
||||
NonLinearConjugateGradientOptimizer(
|
||||
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
||||
SimpleValueChecker(
|
||||
Optimization.DEFAULT_RELATIVE_TOLERANCE,
|
||||
Optimization.DEFAULT_ABSOLUTE_TOLERANCE,
|
||||
Optimization.DEFAULT_MAX_ITER
|
||||
)
|
||||
)
|
||||
},
|
||||
): OptimizationResult = withSymbols(startingPoint.keys) {
|
||||
val optimizer = optimizerBuilder()
|
||||
val objectiveFunction = objectiveFunction(this@optimize)
|
||||
val objectiveGradient = objectiveFunctionGradient(this@optimize)
|
||||
val (point, value) = optimizer.optimize(
|
||||
objectiveFunction,
|
||||
objectiveGradient,
|
||||
initialGuess(startingPoint),
|
||||
goalType,
|
||||
MaxEval.unlimited(),
|
||||
*additionalArguments
|
||||
)
|
||||
OptimizationResult(point.toMap(), value)
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.expressions.symbol
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
internal class OptimizeTest {
|
||||
val x by symbol
|
||||
val y by symbol
|
||||
|
||||
val normal = DerivativeStructureExpression {
|
||||
val x = bind(x)
|
||||
val y = bind(y)
|
||||
exp(-x.pow(2)/2) + exp(-y.pow(2)/2)
|
||||
}
|
||||
|
||||
val startingPoint: Map<Symbol, Double> = mapOf(x to 1.0, y to 1.0)
|
||||
|
||||
@Test
|
||||
fun testOptimization() {
|
||||
val result = normal.optimize(startingPoint)
|
||||
println(result.point)
|
||||
println(result.value)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSimplexOptimization() {
|
||||
val result = (normal as Expression<Double>).optimize(startingPoint){
|
||||
SimplexOptimizer(1e-4,1e-4)
|
||||
}
|
||||
println(result.point)
|
||||
println(result.value)
|
||||
}
|
||||
}
|
@ -35,7 +35,7 @@ public fun interface Expression<T> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Invlode an expression without parameters
|
||||
* Invoke an expression without parameters
|
||||
*/
|
||||
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||
//This method exists to avoid resolution ambiguity of vararg methods
|
||||
|
@ -0,0 +1,45 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
/**
|
||||
* An environment to easy transform indexed variables to symbols and back.
|
||||
*/
|
||||
public interface SymbolIndexer {
|
||||
public val symbols: List<Symbol>
|
||||
public fun indexOf(symbol: Symbol): Int = symbols.indexOf(symbol)
|
||||
|
||||
public operator fun <T> List<T>.get(symbol: Symbol): T {
|
||||
require(size == symbols.size) { "The input list size for indexer should be ${symbols.size} but $size found" }
|
||||
return get(this@SymbolIndexer.indexOf(symbol))
|
||||
}
|
||||
|
||||
public operator fun <T> Array<T>.get(symbol: Symbol): T {
|
||||
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||
return get(this@SymbolIndexer.indexOf(symbol))
|
||||
}
|
||||
|
||||
public operator fun DoubleArray.get(symbol: Symbol): Double {
|
||||
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||
return get(this@SymbolIndexer.indexOf(symbol))
|
||||
}
|
||||
|
||||
public fun DoubleArray.toMap(): Map<Symbol, Double> {
|
||||
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||
return symbols.indices.associate { symbols[it] to get(it) }
|
||||
}
|
||||
|
||||
|
||||
public fun <T> Map<Symbol, T>.toList(): List<T> = symbols.map { getValue(it) }
|
||||
|
||||
public fun Map<Symbol, Double>.toArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
|
||||
}
|
||||
|
||||
public inline class SimpleSymbolIndexer(override val symbols: List<Symbol>) : SymbolIndexer
|
||||
|
||||
/**
|
||||
* Execute the block with symbol indexer based on given symbol order
|
||||
*/
|
||||
public inline fun <R> withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R =
|
||||
with(SimpleSymbolIndexer(symbols.toList()), block)
|
||||
|
||||
public inline fun <R> withSymbols(symbols: Collection<Symbol>, block: SymbolIndexer.() -> R): R =
|
||||
with(SimpleSymbolIndexer(symbols.toList()), block)
|
Loading…
Reference in New Issue
Block a user