[WIP] optimization refactoring

This commit is contained in:
Alexander Nozik 2021-05-25 16:42:23 +03:00
parent 88e94a7fd9
commit c240fa4a9e
7 changed files with 141 additions and 82 deletions

View File

@ -45,6 +45,27 @@ public fun <T, X : T, Y : T> XYColumnarData(x: Buffer<X>, y: Buffer<Y>): XYColum
}
}
/**
* Represent a [ColumnarData] as an [XYColumnarData]. The presence or respective columns is checked on creation.
*/
@UnstableKMathAPI
public fun <T> ColumnarData<T>.asXYData(
xSymbol: Symbol,
ySymbol: Symbol,
): XYColumnarData<T, T, T> = object : XYColumnarData<T, T, T> {
init {
requireNotNull(this@asXYData[xSymbol]){"The column with name $xSymbol is not present in $this"}
requireNotNull(this@asXYData[ySymbol]){"The column with name $ySymbol is not present in $this"}
}
override val size: Int get() = this@asXYData.size
override val x: Buffer<T> get() = this@asXYData[xSymbol]!!
override val y: Buffer<T> get() = this@asXYData[ySymbol]!!
override fun get(symbol: Symbol): Buffer<T>? = when (symbol) {
Symbol.x -> x
Symbol.y -> y
else -> this@asXYData.get(symbol)
}
}
/**
* A zero-copy method to represent a [Structure2D] as a two-column x-y data.

View File

@ -19,9 +19,12 @@ public interface Symbol : MST {
public val identity: String
public companion object {
public val x: StringSymbol = StringSymbol("x")
public val y: StringSymbol = StringSymbol("y")
public val z: StringSymbol = StringSymbol("z")
public val x: Symbol = Symbol("x")
public val xError: Symbol = Symbol("x.error")
public val y: Symbol = Symbol("y")
public val yError: Symbol = Symbol("y.error")
public val z: Symbol = Symbol("z")
public val zError: Symbol = Symbol("z.error")
}
}
@ -29,10 +32,15 @@ public interface Symbol : MST {
* A [Symbol] with a [String] identity
*/
@JvmInline
public value class StringSymbol(override val identity: String) : Symbol {
internal value class StringSymbol(override val identity: String) : Symbol {
override fun toString(): String = identity
}
/**
* Create s Symbols with a string identity
*/
public fun Symbol(identity: String): Symbol = StringSymbol(identity)
/**
* A delegate to create a symbol with a string identity in this scope
*/

View File

@ -5,10 +5,7 @@
package space.kscience.kmath.optimization
import space.kscience.kmath.expressions.AutoDiffProcessor
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.ExpressionAlgebra
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.FeatureSet
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.structures.Buffer
@ -67,3 +64,15 @@ public fun <T> FunctionOptimization<T>.withFeatures(
expression,
)
/**
* Optimize differentiable expression using specific [optimizer] form given [startingPoint]
*/
public suspend fun <T : Any> DifferentiableExpression<T, Expression<T>>.optimizeWith(
optimizer: Optimizer<FunctionOptimization<T>>,
startingPoint: Map<Symbol, T>,
vararg features: OptimizationFeature,
): FunctionOptimization<T> {
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
return optimizer.process(problem)
}

View File

@ -35,6 +35,7 @@ public class OptimizationLog(private val loggable: Loggable) : Loggable by logga
}
public class OptimizationParameters(public val symbols: List<Symbol>): OptimizationFeature{
public constructor(vararg symbols: Symbol) : this(listOf(*symbols))
override fun toString(): String = "Parameters($symbols)"
}

View File

@ -1,72 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.optimization
import space.kscience.kmath.data.ColumnarData
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.Field
@UnstableKMathAPI
public interface XYFit<T> : OptimizationProblem {
public val algebra: Field<T>
/**
* Set X-Y data for this fit optionally including x and y errors
*/
public fun data(
dataSet: ColumnarData<T>,
xSymbol: Symbol,
ySymbol: Symbol,
xErrSymbol: Symbol? = null,
yErrSymbol: Symbol? = null,
)
public fun model(model: (T) -> DifferentiableExpression<T, *>)
/**
* Set the differentiable model for this fit
*/
public fun <I : Any, A> model(
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
modelFunction: A.(I) -> I,
): Unit where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> = model { arg ->
autoDiff.process { modelFunction(const(arg)) }
}
}
//
///**
// * Define a chi-squared-based objective function
// */
//public fun <T : Any, I : Any, A> FunctionOptimization<T>.chiSquared(
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
// x: Buffer<T>,
// y: Buffer<T>,
// yErr: Buffer<T>,
// model: A.(I) -> I,
//) where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
// val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model)
// function(chiSquared)
// maximize = false
//}
/**
* Optimize differentiable expression using specific [Optimizer]
*/
public suspend fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
optimizer: Optimizer<F>,
startingPoint: Map<Symbol,T>,
vararg features: OptimizationFeature
): OptimizationProblem {
// require(startingPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
// val problem = factory(symbols.toList(), configuration)
// problem.function(this)
// return problem.optimize()
val problem = FunctionOptimization<T>()
}

View File

@ -0,0 +1,92 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
@file:OptIn(UnstableKMathAPI::class)
package space.kscience.kmath.optimization
import space.kscience.kmath.data.XYColumnarData
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.FeatureSet
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Field
import space.kscience.kmath.operations.invoke
public fun interface PointToCurveDistance<T> {
public fun distance(problem: XYOptimization<T>, index: Int): DifferentiableExpression<T, Expression<T>>
public companion object {
public fun <T> byY(
algebra: Field<T>,
): PointToCurveDistance<T> = PointToCurveDistance { problem, index ->
algebra {
val x = problem.data.x[index]
val y = problem.data.y[index]
val model = problem.model(args + (Symbol.x to x))
model - y
}
}
// val default = PointToCurveDistance<Double>{args, data, index ->
//
// }
}
}
public class XYOptimization<T>(
override val features: FeatureSet<OptimizationFeature>,
public val data: XYColumnarData<T, T, T>,
public val model: DifferentiableExpression<T, Expression<T>>,
) : OptimizationProblem
//
//@UnstableKMathAPI
//public interface XYFit<T> : OptimizationProblem {
//
// public val algebra: Field<T>
//
// /**
// * Set X-Y data for this fit optionally including x and y errors
// */
// public fun data(
// dataSet: ColumnarData<T>,
// xSymbol: Symbol,
// ySymbol: Symbol,
// xErrSymbol: Symbol? = null,
// yErrSymbol: Symbol? = null,
// )
//
// public fun model(model: (T) -> DifferentiableExpression<T, *>)
//
// /**
// * Set the differentiable model for this fit
// */
// public fun <I : Any, A> model(
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
// modelFunction: A.(I) -> I,
// ): Unit where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> = model { arg ->
// autoDiff.process { modelFunction(const(arg)) }
// }
//}
//
///**
// * Define a chi-squared-based objective function
// */
//public fun <T : Any, I : Any, A> FunctionOptimization<T>.chiSquared(
// autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
// x: Buffer<T>,
// y: Buffer<T>,
// yErr: Buffer<T>,
// model: A.(I) -> I,
//) where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
// val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model)
// function(chiSquared)
// maximize = false
//}

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.operations.Field
import space.kscience.kmath.optimization.OptimizationFeature
import space.kscience.kmath.optimization.OptimizationProblemFactory
import space.kscience.kmath.optimization.OptimizationResult
import space.kscience.kmath.optimization.XYFit
import space.kscience.kmath.optimization.XYOptimization
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.DoubleL2Norm
import kotlin.math.pow
@ -28,7 +28,7 @@ public class QowFit(
override val symbols: List<Symbol>,
private val space: LinearSpace<Double, DoubleField>,
private val solver: LinearSolver<Double>,
) : XYFit<Double>, SymbolIndexer {
) : XYOptimization<Double>, SymbolIndexer {
private var logger: FitLogger? = null