forked from kscience/kmath
[WIP] optimization refactoring
This commit is contained in:
parent
88e94a7fd9
commit
c240fa4a9e
@ -45,6 +45,27 @@ public fun <T, X : T, Y : T> XYColumnarData(x: Buffer<X>, y: Buffer<Y>): XYColum
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent a [ColumnarData] as an [XYColumnarData]. The presence or respective columns is checked on creation.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> ColumnarData<T>.asXYData(
|
||||||
|
xSymbol: Symbol,
|
||||||
|
ySymbol: Symbol,
|
||||||
|
): XYColumnarData<T, T, T> = object : XYColumnarData<T, T, T> {
|
||||||
|
init {
|
||||||
|
requireNotNull(this@asXYData[xSymbol]){"The column with name $xSymbol is not present in $this"}
|
||||||
|
requireNotNull(this@asXYData[ySymbol]){"The column with name $ySymbol is not present in $this"}
|
||||||
|
}
|
||||||
|
override val size: Int get() = this@asXYData.size
|
||||||
|
override val x: Buffer<T> get() = this@asXYData[xSymbol]!!
|
||||||
|
override val y: Buffer<T> get() = this@asXYData[ySymbol]!!
|
||||||
|
override fun get(symbol: Symbol): Buffer<T>? = when (symbol) {
|
||||||
|
Symbol.x -> x
|
||||||
|
Symbol.y -> y
|
||||||
|
else -> this@asXYData.get(symbol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A zero-copy method to represent a [Structure2D] as a two-column x-y data.
|
* A zero-copy method to represent a [Structure2D] as a two-column x-y data.
|
||||||
|
@ -19,9 +19,12 @@ public interface Symbol : MST {
|
|||||||
public val identity: String
|
public val identity: String
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val x: StringSymbol = StringSymbol("x")
|
public val x: Symbol = Symbol("x")
|
||||||
public val y: StringSymbol = StringSymbol("y")
|
public val xError: Symbol = Symbol("x.error")
|
||||||
public val z: StringSymbol = StringSymbol("z")
|
public val y: Symbol = Symbol("y")
|
||||||
|
public val yError: Symbol = Symbol("y.error")
|
||||||
|
public val z: Symbol = Symbol("z")
|
||||||
|
public val zError: Symbol = Symbol("z.error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,10 +32,15 @@ public interface Symbol : MST {
|
|||||||
* A [Symbol] with a [String] identity
|
* A [Symbol] with a [String] identity
|
||||||
*/
|
*/
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class StringSymbol(override val identity: String) : Symbol {
|
internal value class StringSymbol(override val identity: String) : Symbol {
|
||||||
override fun toString(): String = identity
|
override fun toString(): String = identity
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create s Symbols with a string identity
|
||||||
|
*/
|
||||||
|
public fun Symbol(identity: String): Symbol = StringSymbol(identity)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A delegate to create a symbol with a string identity in this scope
|
* A delegate to create a symbol with a string identity in this scope
|
||||||
*/
|
*/
|
||||||
|
@ -5,10 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.optimization
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.AutoDiffProcessor
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
|
||||||
import space.kscience.kmath.expressions.Expression
|
|
||||||
import space.kscience.kmath.expressions.ExpressionAlgebra
|
|
||||||
import space.kscience.kmath.misc.FeatureSet
|
import space.kscience.kmath.misc.FeatureSet
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
@ -67,3 +64,15 @@ public fun <T> FunctionOptimization<T>.withFeatures(
|
|||||||
expression,
|
expression,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize differentiable expression using specific [optimizer] form given [startingPoint]
|
||||||
|
*/
|
||||||
|
public suspend fun <T : Any> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
||||||
|
optimizer: Optimizer<FunctionOptimization<T>>,
|
||||||
|
startingPoint: Map<Symbol, T>,
|
||||||
|
vararg features: OptimizationFeature,
|
||||||
|
): FunctionOptimization<T> {
|
||||||
|
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
|
||||||
|
return optimizer.process(problem)
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ public class OptimizationLog(private val loggable: Loggable) : Loggable by logga
|
|||||||
}
|
}
|
||||||
|
|
||||||
public class OptimizationParameters(public val symbols: List<Symbol>): OptimizationFeature{
|
public class OptimizationParameters(public val symbols: List<Symbol>): OptimizationFeature{
|
||||||
|
public constructor(vararg symbols: Symbol) : this(listOf(*symbols))
|
||||||
override fun toString(): String = "Parameters($symbols)"
|
override fun toString(): String = "Parameters($symbols)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,72 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.optimization
|
|
||||||
|
|
||||||
import space.kscience.kmath.data.ColumnarData
|
|
||||||
import space.kscience.kmath.expressions.*
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
|
||||||
import space.kscience.kmath.operations.Field
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public interface XYFit<T> : OptimizationProblem {
|
|
||||||
|
|
||||||
public val algebra: Field<T>
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Set X-Y data for this fit optionally including x and y errors
|
|
||||||
*/
|
|
||||||
public fun data(
|
|
||||||
dataSet: ColumnarData<T>,
|
|
||||||
xSymbol: Symbol,
|
|
||||||
ySymbol: Symbol,
|
|
||||||
xErrSymbol: Symbol? = null,
|
|
||||||
yErrSymbol: Symbol? = null,
|
|
||||||
)
|
|
||||||
|
|
||||||
public fun model(model: (T) -> DifferentiableExpression<T, *>)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Set the differentiable model for this fit
|
|
||||||
*/
|
|
||||||
public fun <I : Any, A> model(
|
|
||||||
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
|
||||||
modelFunction: A.(I) -> I,
|
|
||||||
): Unit where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> = model { arg ->
|
|
||||||
autoDiff.process { modelFunction(const(arg)) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
///**
|
|
||||||
// * Define a chi-squared-based objective function
|
|
||||||
// */
|
|
||||||
//public fun <T : Any, I : Any, A> FunctionOptimization<T>.chiSquared(
|
|
||||||
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
|
||||||
// x: Buffer<T>,
|
|
||||||
// y: Buffer<T>,
|
|
||||||
// yErr: Buffer<T>,
|
|
||||||
// model: A.(I) -> I,
|
|
||||||
//) where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
|
||||||
// val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model)
|
|
||||||
// function(chiSquared)
|
|
||||||
// maximize = false
|
|
||||||
//}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Optimize differentiable expression using specific [Optimizer]
|
|
||||||
*/
|
|
||||||
public suspend fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
|
||||||
optimizer: Optimizer<F>,
|
|
||||||
startingPoint: Map<Symbol,T>,
|
|
||||||
vararg features: OptimizationFeature
|
|
||||||
): OptimizationProblem {
|
|
||||||
// require(startingPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
|
||||||
// val problem = factory(symbols.toList(), configuration)
|
|
||||||
// problem.function(this)
|
|
||||||
// return problem.optimize()
|
|
||||||
val problem = FunctionOptimization<T>()
|
|
||||||
}
|
|
@ -0,0 +1,92 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
@file:OptIn(UnstableKMathAPI::class)
|
||||||
|
|
||||||
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
|
import space.kscience.kmath.data.XYColumnarData
|
||||||
|
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.FeatureSet
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.Field
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
|
|
||||||
|
public fun interface PointToCurveDistance<T> {
|
||||||
|
public fun distance(problem: XYOptimization<T>, index: Int): DifferentiableExpression<T, Expression<T>>
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
public fun <T> byY(
|
||||||
|
algebra: Field<T>,
|
||||||
|
): PointToCurveDistance<T> = PointToCurveDistance { problem, index ->
|
||||||
|
algebra {
|
||||||
|
val x = problem.data.x[index]
|
||||||
|
val y = problem.data.y[index]
|
||||||
|
|
||||||
|
val model = problem.model(args + (Symbol.x to x))
|
||||||
|
model - y
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// val default = PointToCurveDistance<Double>{args, data, index ->
|
||||||
|
//
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class XYOptimization<T>(
|
||||||
|
override val features: FeatureSet<OptimizationFeature>,
|
||||||
|
public val data: XYColumnarData<T, T, T>,
|
||||||
|
public val model: DifferentiableExpression<T, Expression<T>>,
|
||||||
|
) : OptimizationProblem
|
||||||
|
|
||||||
|
//
|
||||||
|
//@UnstableKMathAPI
|
||||||
|
//public interface XYFit<T> : OptimizationProblem {
|
||||||
|
//
|
||||||
|
// public val algebra: Field<T>
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Set X-Y data for this fit optionally including x and y errors
|
||||||
|
// */
|
||||||
|
// public fun data(
|
||||||
|
// dataSet: ColumnarData<T>,
|
||||||
|
// xSymbol: Symbol,
|
||||||
|
// ySymbol: Symbol,
|
||||||
|
// xErrSymbol: Symbol? = null,
|
||||||
|
// yErrSymbol: Symbol? = null,
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// public fun model(model: (T) -> DifferentiableExpression<T, *>)
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Set the differentiable model for this fit
|
||||||
|
// */
|
||||||
|
// public fun <I : Any, A> model(
|
||||||
|
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
|
// modelFunction: A.(I) -> I,
|
||||||
|
// ): Unit where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> = model { arg ->
|
||||||
|
// autoDiff.process { modelFunction(const(arg)) }
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
//
|
||||||
|
///**
|
||||||
|
// * Define a chi-squared-based objective function
|
||||||
|
// */
|
||||||
|
//public fun <T : Any, I : Any, A> FunctionOptimization<T>.chiSquared(
|
||||||
|
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
|
// x: Buffer<T>,
|
||||||
|
// y: Buffer<T>,
|
||||||
|
// yErr: Buffer<T>,
|
||||||
|
// model: A.(I) -> I,
|
||||||
|
//) where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
|
// val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model)
|
||||||
|
// function(chiSquared)
|
||||||
|
// maximize = false
|
||||||
|
//}
|
@ -15,7 +15,7 @@ import space.kscience.kmath.operations.Field
|
|||||||
import space.kscience.kmath.optimization.OptimizationFeature
|
import space.kscience.kmath.optimization.OptimizationFeature
|
||||||
import space.kscience.kmath.optimization.OptimizationProblemFactory
|
import space.kscience.kmath.optimization.OptimizationProblemFactory
|
||||||
import space.kscience.kmath.optimization.OptimizationResult
|
import space.kscience.kmath.optimization.OptimizationResult
|
||||||
import space.kscience.kmath.optimization.XYFit
|
import space.kscience.kmath.optimization.XYOptimization
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.DoubleL2Norm
|
import space.kscience.kmath.structures.DoubleL2Norm
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
@ -28,7 +28,7 @@ public class QowFit(
|
|||||||
override val symbols: List<Symbol>,
|
override val symbols: List<Symbol>,
|
||||||
private val space: LinearSpace<Double, DoubleField>,
|
private val space: LinearSpace<Double, DoubleField>,
|
||||||
private val solver: LinearSolver<Double>,
|
private val solver: LinearSolver<Double>,
|
||||||
) : XYFit<Double>, SymbolIndexer {
|
) : XYOptimization<Double>, SymbolIndexer {
|
||||||
|
|
||||||
private var logger: FitLogger? = null
|
private var logger: FitLogger? = null
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user