Compare commits

...

1 Commits

Author SHA1 Message Date
4e34d98dc0 Annealing optimization 2024-05-22 12:33:05 +03:00
4 changed files with 410 additions and 0 deletions

View File

@ -0,0 +1,45 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.stat
import kotlinx.coroutines.runBlocking
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.invoke
import space.kscience.plotly.Plotly
import space.kscience.plotly.ResourceLocation
import space.kscience.plotly.makeFile
import space.kscience.plotly.models.ScatterMode
import space.kscience.plotly.scatter
fun main() {
fun schaffer(x: Double) = (Float64Field) {
if (x in -5.0..5.0) 0.5 + ((sin(x pow 2) pow 2) - 0.5) / (1.0 + 0.01 * (x pow 2))
else Double.POSITIVE_INFINITY
}
val optimizer = DoubleField.annealing(targetTemperature = 0.0, maxIterations = 150)
val optimum = runBlocking {
optimizer.minimize(1.0, ::schaffer)
}
println(optimum)
val xSpace = DoubleArray(200) { -5.0 + it / 20.0 }
Plotly.plot {
scatter {
name = "Function"
x.doubles = xSpace
y.doubles = DoubleArray(200) { schaffer(xSpace[it]) }
}
scatter {
name = "Optimum"
x.numbers = listOf(optimum.x)
y.numbers = listOf(optimum.y)
mode = ScatterMode.markers
marker.color("red")
}
}.makeFile(resourceLocation = ResourceLocation.REMOTE)
}

View File

@ -17,6 +17,12 @@ kotlin.sourceSets {
}
}
commonTest {
dependencies {
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.8.1")
}
}
getByName("jvmMain") {
dependencies {
api("org.apache.commons:commons-rng-sampling:1.3")

View File

@ -0,0 +1,251 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.stat
import space.kscience.kmath.distributions.UniformDistribution
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.samplers.GaussianSampler
import kotlin.math.exp
/**
* Temperature schedule - function that gives temperature on a specific iteration.
*/
public fun interface Temperature<T: Comparable<T>> {
/**
* Function that gives temperature on a specific iteration.
*/
public operator fun invoke(iteration: Int): T
}
/**
* Condition that determines whether a new point should be selected in case it does not
* deliver a more optimal value that the current point.
*/
public fun interface Condition<T: Comparable<T>> {
/**
* Whether a point should be selected as a new point.
*
* @param currentValue function value in the current point
* @param proposedValue function value in the point proposed to be chosen as the new point
* @param temperature temperature on the current iteration
*/
public suspend operator fun invoke(currentValue: T, proposedValue: T, temperature: T): Boolean
}
/**
* Result of optimization
*
*
*/
public data class OptimizationResult<X: Comparable<X>, Y: Comparable<Y>>(
val y: Y,
val x: X,
val iterations: Int,
val temperature: Y
)
private data class AnnealingPoint<X: Comparable<X>, Y: Comparable<Y>>(val x: X, val y: Y)
/**
* Simulated annealing - stochastic function optimization.
*
* @param targetTemperature temperature required to reach to stop the optimization
* @param temperature temperature decay law used for optimization. By default,
* an exponential law is used.
* @param condition
* @param sampler sampler used to draw deltas from. By default, an appropriate Gaussian distribution
* should be used
* @param generator
* @param maxIterations
* @param samplingStride
* @param algebra
*/
public class Annealing<X: Comparable<X>, Y: Comparable<Y>>(
private val targetTemperature: Y,
private val temperature: Temperature<Y>,
private val condition: Condition<Y>,
sampler: Sampler<X>,
generator: RandomGenerator,
private val maxIterations: Int? = null,
private val samplingStride: Int = 1,
private val algebra: GroupOps<X>
) {
init {
require(samplingStride >= 1)
maxIterations?.let {
require(maxIterations > 0)
}
}
private val samples = sampler.sample(generator)
/**
* A single optimization iteration
*/
private tailrec suspend fun optimizeInner(
f: (X) -> Y,
iteration: Int,
current: AnnealingPoint<X, Y>,
best: AnnealingPoint<X, Y>
): OptimizationResult<X, Y> {
val temp = temperature(iteration)
maxIterations?.let {
if (iteration >= maxIterations) return OptimizationResult(current.y, current.x, iteration, temp)
}
if (temp <= targetTemperature) return OptimizationResult(current.y, current.x, iteration, temp)
var xProposed = current.x
repeat(samplingStride) { xProposed = (algebra) { samples.next() + xProposed } }
val yProposed = f(xProposed)
val next = when {
// In case a condition different from the standard Gibbs is used
yProposed < current.y -> AnnealingPoint(xProposed, f(xProposed))
(condition(current.y, yProposed, temp)) -> AnnealingPoint(xProposed, f(xProposed))
else -> current
}
val newBest = if (next.y < best.y) next else best
return optimizeInner(f, iteration + 1, next, newBest)
}
/**
* Minimize function using parameters defined when the optimizer is constructed.
*
* @param x0 initial point for optimization
* @param function objective function
*
* @return optimization result, see documentation entry for [OptimizationResult]
*/
public suspend fun minimize(x0: X, function: (X) -> Y): OptimizationResult<X, Y> {
val initial = AnnealingPoint(x0, function(x0))
return optimizeInner(function, iteration = 0, initial, best = initial)
}
}
/**
* Retrieve an [Annealing] optimizer for this algebra.
*
* @param targetTemperature temperature required to reach to stop the optimization
* @param maxIterations (optional) number of iterations required to stop the algorithm.
* If `null`, [targetTemperature] solely is used as a stoppage criterion.
* @param samplingStride
* @param temperature temperature decay function used for optimization. By default,
* an exponential law is used.
* @param condition
* @param sampler sampler used to draw deltas from. By default, an appropriate Gaussian distribution
* should be used
* @param generator (optional) random generator used by the sampler.
* By default, [RandomGenerator.default] is used.
*/
public fun <X: Comparable<X>, Y: Comparable<Y>> GroupOps<X>.annealing(
targetTemperature: Y,
maxIterations: Int? = null,
samplingStride: Int = 1,
temperature: Temperature<Y>,
condition: Condition<Y>,
sampler: Sampler<X>,
generator: RandomGenerator = RandomGenerator.default,
): Annealing<X, Y> = Annealing(
targetTemperature,
temperature,
condition,
sampler,
generator,
maxIterations,
samplingStride,
algebra = this
)
/**
* Retrieve an [Annealing] optimizer for Double field.
*
* @param targetTemperature temperature required to reach to stop the optimization
* @param maxIterations (optional) number of iterations required to stop the algorithm.
* If `null`, [targetTemperature] solely is used as a stoppage criterion.
* @param samplingStride
* @param temperature temperature decay function used for optimization. By default,
* an exponential law is used.
* @param condition
* @param sampler sampler used to draw deltas from. By default, an appropriate Gaussian distribution
* should be used
* @param generator (optional) random generator used by the sampler.
* By default, [RandomGenerator.default] is used.
*/
public fun GroupOps<Double>.annealing(
targetTemperature: Double,
maxIterations: Int? = null,
samplingStride: Int = 1,
temperature: Temperature<Double> = BoltzmannTemp(),
condition: Condition<Double> = GibbsCondition(),
sampler: Sampler<Double> = GaussianSampler(0.0, 1.0),
generator: RandomGenerator = RandomGenerator.default,
): Annealing<Double, Double> = Annealing(
targetTemperature,
temperature,
condition,
sampler,
generator,
maxIterations,
samplingStride,
algebra = this
)
/**
* Temperature schedule that uses the following formula:
* `tk = t0 - slope * k`,
* where `k` is the iteration number
*
* @param t0 Initial temperature (`k = 0`)
* @param slope Slope of the temperature decay line
*/
public class LinearTemp(
private val t0: Double = 1000.0,
private val slope: Double = 1.0,
): Temperature<Double> {
init {
require(slope > 0.0)
}
override fun invoke(iteration: Int): Double = t0 - slope * iteration.toDouble()
}
/**
* Temperature schedule that uses the following formula:
* `tk = t0 / exp(rate * k)`,
* where `k` is the iteration number.
*
* @param t0 Initial temperature (`k = 0`)
* @param rate Temperature decay rate
*/
public class BoltzmannTemp(
private val t0: Double = 1000.0,
private val rate: Double = 1.0
): Temperature<Double> {
init {
require(t0 > 0.0)
require(rate > 0.0)
}
override fun invoke(iteration: Int): Double = t0 / exp(iteration * rate)
}
/**
* Jump condition that uses Gibbs distribution formula and uniform sampling to determine
* the next point. If value `exp(-(y{k} - y{k-1}) / T)` is greater than the sampled value,
* the new point is taken.
*/
public class GibbsCondition: Condition<Double> {
private val distribution = UniformDistribution(0.0..1.0)
override suspend fun invoke(
currentValue: Double,
proposedValue: Double,
temperature: Double
): Boolean = (Float64Field) {
exp(-(proposedValue - currentValue) / temperature) > distribution.next(RandomGenerator.default)
}
}

View File

@ -0,0 +1,108 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.stat
import kotlinx.coroutines.test.runTest
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.Float64Field.cos
import space.kscience.kmath.operations.invoke
import kotlin.math.E
import kotlin.math.PI
import kotlin.math.exp
import kotlin.math.sqrt
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertTrue
class TestAnnealing {
companion object {
fun rastrigin(x: Double, a: Double = 10.0) =
if (x in -5.0..5.0) a + x * x - a * cos(2 * PI * x) else Double.POSITIVE_INFINITY
fun ackley(x: Double) = if (x in -5.0..5.0) {
-20.0 * exp(-0.2 * sqrt(0.5 * x * x)) - exp(2 * cos(2 * PI * x)) + E + 20.0
} else {
Double.POSITIVE_INFINITY
}
fun schaffer(x: Double) = (Float64Field) {
if (x in -5.0..5.0) 0.5 + ((sin(x pow 2) pow 2) - 0.5) / (1.0 + 0.01 * (x pow 2))
else Double.POSITIVE_INFINITY
}
fun melon(x: Double) =
if (x in -4.0..4.0) (x - 1.0) * (x + 3.0) * (x - 3.0) * (x + 2.0) else Double.POSITIVE_INFINITY
suspend fun countSuccessRate(
function: (Double) -> Double,
objective: Double = 0.0,
tolerance: Double = 0.1,
iterations: Int = 150
): Int = (Float64Field) {
val optimizer = annealing(
targetTemperature = 0.0,
maxIterations = iterations,
temperature = BoltzmannTemp(),
condition = GibbsCondition()
)
val initial = { ((Random.nextDouble() - 0.5) * 6.0).coerceIn(-5.0..5.0) }
val results = List(100) { optimizer.minimize(initial()) { x -> function(x) }.x }
return results.count { it in objective - tolerance .. objective + tolerance }
}
}
@Test
fun testRastrigin() = runTest {
val successRate = countSuccessRate(::rastrigin)
println("Rastrigin success rate: $successRate%")
assertTrue { successRate >= 98 }
}
@Test
fun testAckley() = runTest {
val successRate = countSuccessRate(::ackley)
println("Ackley success rate: $successRate%")
assertTrue { successRate >= 98 }
}
@Test
fun testSchaffer() = runTest {
val successRate = countSuccessRate(::schaffer, iterations = 180)
println("Schaffer success rate: $successRate%")
assertTrue { successRate >= 98 }
}
@Test
fun testTailCallOptimization() = runTest {
(Float64Field) {
val optimizer = annealing(
targetTemperature = -1.0,
maxIterations = 1000000,
temperature = BoltzmannTemp(),
condition = GibbsCondition()
)
val result = optimizer.minimize(3.0) { x -> rastrigin(x) }
println("Sustained ${result.iterations} iterations without stack overflow")
}
}
@Test
fun testGreediness() = runTest {
(Float64Field) {
val optimizer = annealing(
targetTemperature = -1.0,
maxIterations = 100,
temperature = { 0.0 },
condition = GibbsCondition()
)
val results = List(100) { optimizer.minimize(-3.2, ::melon).x }
val rate = results.count { it in -2.6..-2.5}
println("Greediness rate: $rate")
assertTrue { rate >= 95 }
}
}
}