[WIP] optimization refactor in process
This commit is contained in:
parent
257337f4fb
commit
7b6361e59d
@ -11,10 +11,8 @@ import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
|
|||||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex
|
import space.kscience.kmath.expressions.derivative
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
import space.kscience.kmath.expressions.withSymbols
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
|
||||||
import space.kscience.kmath.expressions.*
|
|
||||||
import space.kscience.kmath.misc.Symbol
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.optimization.*
|
import space.kscience.kmath.optimization.*
|
||||||
@ -24,107 +22,73 @@ public operator fun PointValuePair.component1(): DoubleArray = point
|
|||||||
public operator fun PointValuePair.component2(): Double = value
|
public operator fun PointValuePair.component2(): Double = value
|
||||||
|
|
||||||
public class CMOptimizerFactory(public val optimizerBuilder: () -> MultivariateOptimizer) : OptimizationFeature
|
public class CMOptimizerFactory(public val optimizerBuilder: () -> MultivariateOptimizer) : OptimizationFeature
|
||||||
//public class CMOptimizerData(public val )
|
public class CMOptimizerData(public val data: List<OptimizationData>) : OptimizationFeature {
|
||||||
|
public constructor(vararg data: OptimizationData) : this(data.toList())
|
||||||
|
}
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public class CMOptimization : Optimizer<FunctionOptimization<Double>> {
|
public class CMOptimization : Optimizer<FunctionOptimization<Double>> {
|
||||||
|
|
||||||
override suspend fun process(
|
override suspend fun process(
|
||||||
problem: FunctionOptimization<Double>
|
problem: FunctionOptimization<Double>,
|
||||||
): FunctionOptimization<Double> = withSymbols(problem.parameters) {
|
): FunctionOptimization<Double> = withSymbols(problem.parameters) {
|
||||||
val cmOptimizer: MultivariateOptimizer =
|
|
||||||
problem.getFeature<CMOptimizerFactory>()?.optimizerBuilder?.invoke() ?: SimplexOptimizer()
|
|
||||||
|
|
||||||
val convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(
|
val convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(
|
||||||
DEFAULT_RELATIVE_TOLERANCE,
|
DEFAULT_RELATIVE_TOLERANCE,
|
||||||
DEFAULT_ABSOLUTE_TOLERANCE,
|
DEFAULT_ABSOLUTE_TOLERANCE,
|
||||||
DEFAULT_MAX_ITER
|
DEFAULT_MAX_ITER
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val cmOptimizer: MultivariateOptimizer = problem.getFeature<CMOptimizerFactory>()?.optimizerBuilder?.invoke()
|
||||||
|
?: NonLinearConjugateGradientOptimizer(
|
||||||
|
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
||||||
|
convergenceChecker
|
||||||
|
)
|
||||||
|
|
||||||
val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||||
|
|
||||||
fun addOptimizationData(data: OptimizationData) {
|
fun addOptimizationData(data: OptimizationData) {
|
||||||
optimizationData[data::class] = data
|
optimizationData[data::class] = data
|
||||||
}
|
}
|
||||||
|
|
||||||
addOptimizationData(MaxEval.unlimited())
|
addOptimizationData(MaxEval.unlimited())
|
||||||
addOptimizationData(InitialGuess(problem.initialGuess.toDoubleArray()))
|
addOptimizationData(InitialGuess(problem.initialGuess.toDoubleArray()))
|
||||||
|
|
||||||
fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Register no-deriv function instead of differentiable function
|
|
||||||
*/
|
|
||||||
/**
|
|
||||||
* Register no-deriv function instead of differentiable function
|
|
||||||
*/
|
|
||||||
fun noDerivFunction(expression: Expression<Double>): Unit {
|
|
||||||
val objectiveFunction = ObjectiveFunction {
|
val objectiveFunction = ObjectiveFunction {
|
||||||
val args = problem.initialGuess + it.toMap()
|
val args = problem.initialGuess + it.toMap()
|
||||||
expression(args)
|
problem.expression(args)
|
||||||
}
|
}
|
||||||
addOptimizationData(objectiveFunction)
|
addOptimizationData(objectiveFunction)
|
||||||
}
|
|
||||||
|
|
||||||
public override fun function(expression: DifferentiableExpression<Double, Expression<Double>>) {
|
|
||||||
noDerivFunction(expression)
|
|
||||||
val gradientFunction = ObjectiveFunctionGradient {
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
val args = startingPoint + it.toMap()
|
val args = problem.initialGuess + it.toMap()
|
||||||
DoubleArray(symbols.size) { index ->
|
DoubleArray(symbols.size) { index ->
|
||||||
expression.derivative(symbols[index])(args)
|
problem.expression.derivative(symbols[index])(args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
addOptimizationData(gradientFunction)
|
addOptimizationData(gradientFunction)
|
||||||
if (optimizerBuilder == null) {
|
|
||||||
optimizerBuilder = {
|
val logger = problem.getFeature<OptimizationLog>()
|
||||||
NonLinearConjugateGradientOptimizer(
|
|
||||||
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
for (feature in problem.features) {
|
||||||
convergenceChecker
|
when (feature) {
|
||||||
)
|
is CMOptimizerData -> feature.data.forEach { addOptimizationData(it) }
|
||||||
|
is FunctionOptimizationTarget -> when(feature){
|
||||||
|
FunctionOptimizationTarget.MAXIMIZE -> addOptimizationData(GoalType.MAXIMIZE)
|
||||||
|
FunctionOptimizationTarget.MINIMIZE -> addOptimizationData(GoalType.MINIMIZE)
|
||||||
}
|
}
|
||||||
|
else -> logger?.log { "The feature $feature is unused in optimization" }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun simplex(simplex: AbstractSimplex) {
|
val (point, value) = cmOptimizer.optimize(*optimizationData.values.toTypedArray())
|
||||||
addOptimizationData(simplex)
|
return problem.withFeatures(FunctionOptimizationResult(point.toMap(), value))
|
||||||
//Set optimization builder to simplex if it is not present
|
|
||||||
if (optimizerBuilder == null) {
|
|
||||||
optimizerBuilder = { SimplexOptimizer(convergenceChecker) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun simplexSteps(steps: Map<Symbol, Double>) {
|
public companion object {
|
||||||
simplex(NelderMeadSimplex(steps.toDoubleArray()))
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun goal(goalType: GoalType) {
|
|
||||||
addOptimizationData(goalType)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun optimizer(block: () -> MultivariateOptimizer) {
|
|
||||||
optimizerBuilder = block
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun update(result: OptimizationResult<Double>) {
|
|
||||||
initialGuess(result.point)
|
|
||||||
}
|
|
||||||
|
|
||||||
override suspend fun optimize(): OptimizationResult<Double> {
|
|
||||||
val optimizer = optimizerBuilder?.invoke() ?: error("Optimizer not defined")
|
|
||||||
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
|
||||||
return OptimizationResult(point.toMap(), value)
|
|
||||||
}
|
|
||||||
return@withSymbols TODO()
|
|
||||||
}
|
|
||||||
|
|
||||||
public companion object : OptimizationProblemFactory<Double, CMOptimization> {
|
|
||||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||||
|
|
||||||
override fun build(symbols: List<Symbol>): CMOptimization = CMOptimization(symbols)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun CMOptimization.initialGuess(vararg pairs: Pair<Symbol, Double>): Unit = initialGuess(pairs.toMap())
|
|
||||||
public fun CMOptimization.simplexSteps(vararg pairs: Pair<Symbol, Double>): Unit = simplexSteps(pairs.toMap())
|
|
||||||
|
@ -17,22 +17,22 @@ import space.kscience.kmath.structures.asBuffer
|
|||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
public fun FunctionOptimization.Companion.chiSquared(
|
public fun FunctionOptimization.Companion.chiSquaredExpression(
|
||||||
x: Buffer<Double>,
|
x: Buffer<Double>,
|
||||||
y: Buffer<Double>,
|
y: Buffer<Double>,
|
||||||
yErr: Buffer<Double>,
|
yErr: Buffer<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
): DifferentiableExpression<Double, Expression<Double>> = chiSquaredExpression(DerivativeStructureField, x, y, yErr, model)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
public fun FunctionOptimization.Companion.chiSquared(
|
public fun FunctionOptimization.Companion.chiSquaredExpression(
|
||||||
x: Iterable<Double>,
|
x: Iterable<Double>,
|
||||||
y: Iterable<Double>,
|
y: Iterable<Double>,
|
||||||
yErr: Iterable<Double>,
|
yErr: Iterable<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(
|
): DifferentiableExpression<Double, Expression<Double>> = chiSquaredExpression(
|
||||||
DerivativeStructureField,
|
DerivativeStructureField,
|
||||||
x.toList().asBuffer(),
|
x.toList().asBuffer(),
|
||||||
y.toList().asBuffer(),
|
y.toList().asBuffer(),
|
||||||
|
@ -17,7 +17,7 @@ public interface Featured<F : Any> {
|
|||||||
/**
|
/**
|
||||||
* A container for a set of features
|
* A container for a set of features
|
||||||
*/
|
*/
|
||||||
public class FeatureSet<F : Any> private constructor(public val features: Map<KClass<out F>, Any>) : Featured<F> {
|
public class FeatureSet<F : Any> private constructor(public val features: Map<KClass<out F>, F>) : Featured<F> {
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun <T : F> getFeature(type: KClass<out T>): T? = features[type] as? T
|
override fun <T : F> getFeature(type: KClass<out T>): T? = features[type] as? T
|
||||||
|
|
||||||
@ -31,6 +31,8 @@ public class FeatureSet<F : Any> private constructor(public val features: Map<KC
|
|||||||
public fun with(vararg otherFeatures: F): FeatureSet<F> =
|
public fun with(vararg otherFeatures: F): FeatureSet<F> =
|
||||||
FeatureSet(features + otherFeatures.associateBy { it::class })
|
FeatureSet(features + otherFeatures.associateBy { it::class })
|
||||||
|
|
||||||
|
public operator fun iterator(): Iterator<F> = features.values.iterator()
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun <F : Any> of(vararg features: F): FeatureSet<F> = FeatureSet(features.associateBy { it::class })
|
public fun <F : Any> of(vararg features: F): FeatureSet<F> = FeatureSet(features.associateBy { it::class })
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.misc
|
||||||
|
|
||||||
|
public interface Loggable {
|
||||||
|
public fun log(tag: String = INFO, block: () -> String)
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
public const val INFO: String = "INFO"
|
||||||
|
}
|
||||||
|
}
|
@ -15,91 +15,57 @@ import space.kscience.kmath.operations.ExtendedField
|
|||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.indices
|
import space.kscience.kmath.structures.indices
|
||||||
|
|
||||||
|
public class FunctionOptimizationResult<T>(point: Map<Symbol, T>, public val value: T) : OptimizationResult<T>(point)
|
||||||
|
|
||||||
public class FunctionOptimization<T : Any>(
|
public enum class FunctionOptimizationTarget : OptimizationFeature {
|
||||||
|
MAXIMIZE,
|
||||||
|
MINIMIZE
|
||||||
|
}
|
||||||
|
|
||||||
|
public class FunctionOptimization<T>(
|
||||||
override val features: FeatureSet<OptimizationFeature>,
|
override val features: FeatureSet<OptimizationFeature>,
|
||||||
public val expression: DifferentiableExpression<T, Expression<T>>,
|
public val expression: DifferentiableExpression<T, Expression<T>>,
|
||||||
public val initialGuess: Map<Symbol, T>,
|
public val initialGuess: Map<Symbol, T>,
|
||||||
public val parameters: Collection<Symbol>,
|
public val parameters: Collection<Symbol>,
|
||||||
public val maximize: Boolean,
|
) : OptimizationProblem{
|
||||||
) : OptimizationProblem
|
public companion object{
|
||||||
|
/**
|
||||||
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
|
*/
|
||||||
|
public fun <T : Any, I : Any, A> chiSquaredExpression(
|
||||||
|
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
|
x: Buffer<T>,
|
||||||
|
y: Buffer<T>,
|
||||||
|
yErr: Buffer<T>,
|
||||||
|
model: A.(I) -> I,
|
||||||
|
): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
|
||||||
|
return autoDiff.process {
|
||||||
|
var sum = zero
|
||||||
|
|
||||||
|
x.indices.forEach {
|
||||||
|
val xValue = const(x[it])
|
||||||
|
val yValue = const(y[it])
|
||||||
|
val yErrValue = const(yErr[it])
|
||||||
|
val modelValue = model(xValue)
|
||||||
|
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T> FunctionOptimization<T>.withFeatures(
|
||||||
|
vararg newFeature: OptimizationFeature,
|
||||||
|
): FunctionOptimization<T> = FunctionOptimization(
|
||||||
|
features.with(*newFeature),
|
||||||
|
expression,
|
||||||
|
initialGuess,
|
||||||
|
parameters
|
||||||
|
)
|
||||||
|
|
||||||
//
|
|
||||||
///**
|
|
||||||
// * A likelihood function optimization problem with provided derivatives
|
|
||||||
// */
|
|
||||||
//public interface FunctionOptimizationBuilder<T : Any> {
|
|
||||||
// /**
|
|
||||||
// * The optimization direction. If true search for function maximum, if false, search for the minimum
|
|
||||||
// */
|
|
||||||
// public var maximize: Boolean
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Define the initial guess for the optimization problem
|
|
||||||
// */
|
|
||||||
// public fun initialGuess(map: Map<Symbol, T>)
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Set a differentiable expression as objective function as function and gradient provider
|
|
||||||
// */
|
|
||||||
// public fun function(expression: DifferentiableExpression<T, Expression<T>>)
|
|
||||||
//
|
|
||||||
// public companion object {
|
|
||||||
// /**
|
|
||||||
// * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
|
||||||
// */
|
|
||||||
// public fun <T : Any, I : Any, A> chiSquared(
|
|
||||||
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
|
||||||
// x: Buffer<T>,
|
|
||||||
// y: Buffer<T>,
|
|
||||||
// yErr: Buffer<T>,
|
|
||||||
// model: A.(I) -> I,
|
|
||||||
// ): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
|
||||||
// require(x.size == y.size) { "X and y buffers should be of the same size" }
|
|
||||||
// require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
|
||||||
//
|
|
||||||
// return autoDiff.process {
|
|
||||||
// var sum = zero
|
|
||||||
//
|
|
||||||
// x.indices.forEach {
|
|
||||||
// val xValue = const(x[it])
|
|
||||||
// val yValue = const(y[it])
|
|
||||||
// val yErrValue = const(yErr[it])
|
|
||||||
// val modelValue = model(xValue)
|
|
||||||
// sum += ((yValue - modelValue) / yErrValue).pow(2)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// sum
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
///**
|
|
||||||
// * Define a chi-squared-based objective function
|
|
||||||
// */
|
|
||||||
//public fun <T : Any, I : Any, A> FunctionOptimization<T>.chiSquared(
|
|
||||||
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
|
||||||
// x: Buffer<T>,
|
|
||||||
// y: Buffer<T>,
|
|
||||||
// yErr: Buffer<T>,
|
|
||||||
// model: A.(I) -> I,
|
|
||||||
//) where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
|
||||||
// val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model)
|
|
||||||
// function(chiSquared)
|
|
||||||
// maximize = false
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
///**
|
|
||||||
// * Optimize differentiable expression using specific [OptimizationProblemFactory]
|
|
||||||
// */
|
|
||||||
//public suspend fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
|
||||||
// factory: OptimizationProblemFactory<T, F>,
|
|
||||||
// vararg symbols: Symbol,
|
|
||||||
// configuration: F.() -> Unit,
|
|
||||||
//): OptimizationResult<T> {
|
|
||||||
// require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
|
||||||
// val problem = factory(symbols.toList(), configuration)
|
|
||||||
// problem.function(this)
|
|
||||||
// return problem.optimize()
|
|
||||||
//}
|
|
||||||
|
@ -7,6 +7,8 @@ package space.kscience.kmath.optimization
|
|||||||
|
|
||||||
import space.kscience.kmath.misc.FeatureSet
|
import space.kscience.kmath.misc.FeatureSet
|
||||||
import space.kscience.kmath.misc.Featured
|
import space.kscience.kmath.misc.Featured
|
||||||
|
import space.kscience.kmath.misc.Loggable
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
public interface OptimizationFeature
|
public interface OptimizationFeature
|
||||||
@ -18,6 +20,10 @@ public interface OptimizationProblem : Featured<OptimizationFeature> {
|
|||||||
|
|
||||||
public inline fun <reified T : OptimizationFeature> OptimizationProblem.getFeature(): T? = getFeature(T::class)
|
public inline fun <reified T : OptimizationFeature> OptimizationProblem.getFeature(): T? = getFeature(T::class)
|
||||||
|
|
||||||
|
public open class OptimizationResult<T>(public val point: Map<Symbol, T>) : OptimizationFeature
|
||||||
|
|
||||||
|
public class OptimizationLog(private val loggable: Loggable) : Loggable by loggable, OptimizationFeature
|
||||||
|
|
||||||
//public class OptimizationResult<T>(
|
//public class OptimizationResult<T>(
|
||||||
// public val point: Map<Symbol, T>,
|
// public val point: Map<Symbol, T>,
|
||||||
// public val value: T,
|
// public val value: T,
|
||||||
|
@ -14,9 +14,11 @@ import space.kscience.kmath.misc.Symbol
|
|||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.operations.Field
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.indices
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public interface XYFit<T> : OptimizationProblem<T> {
|
public interface XYFit<T> : OptimizationProblem {
|
||||||
|
|
||||||
public val algebra: Field<T>
|
public val algebra: Field<T>
|
||||||
|
|
||||||
@ -43,3 +45,33 @@ public interface XYFit<T> : OptimizationProblem<T> {
|
|||||||
autoDiff.process { modelFunction(const(arg)) }
|
autoDiff.process { modelFunction(const(arg)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
///**
|
||||||
|
// * Define a chi-squared-based objective function
|
||||||
|
// */
|
||||||
|
//public fun <T : Any, I : Any, A> FunctionOptimization<T>.chiSquared(
|
||||||
|
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
|
// x: Buffer<T>,
|
||||||
|
// y: Buffer<T>,
|
||||||
|
// yErr: Buffer<T>,
|
||||||
|
// model: A.(I) -> I,
|
||||||
|
//) where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
|
// val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model)
|
||||||
|
// function(chiSquared)
|
||||||
|
// maximize = false
|
||||||
|
//}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
||||||
|
*/
|
||||||
|
public suspend fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
||||||
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: F.() -> Unit,
|
||||||
|
): OptimizationResult<T> {
|
||||||
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
|
val problem = factory(symbols.toList(), configuration)
|
||||||
|
problem.function(this)
|
||||||
|
return problem.optimize()
|
||||||
|
}
|
@ -27,10 +27,6 @@ import kotlin.math.pow
|
|||||||
|
|
||||||
private typealias ParamSet = Map<Symbol, Double>
|
private typealias ParamSet = Map<Symbol, Double>
|
||||||
|
|
||||||
public fun interface FitLogger {
|
|
||||||
public fun log(block: () -> String)
|
|
||||||
}
|
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public class QowFit(
|
public class QowFit(
|
||||||
override val symbols: List<Symbol>,
|
override val symbols: List<Symbol>,
|
||||||
|
Loading…
Reference in New Issue
Block a user