From 2e4be2aa3a3678ee5fed7de029106bb793ed4c32 Mon Sep 17 00:00:00 2001 From: darksnake Date: Tue, 24 Jan 2023 11:36:53 +0300 Subject: [PATCH] A minor change to XYfit arguments --- .../kscience/kmath/optimization/XYFit.kt | 67 ++++++++++++------- 1 file changed, 44 insertions(+), 23 deletions(-) diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt index d2182e91f..bba7a7113 100644 --- a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt @@ -30,7 +30,7 @@ public interface PointToCurveDistance : OptimizationFeature { return object : DifferentiableExpression { override fun derivativeOrNull( - symbols: List + symbols: List, ): Expression? = problem.model.derivativeOrNull(symbols)?.let { derivExpression -> Expression { arguments -> derivExpression.invoke(arguments + (Symbol.x to x)) @@ -93,24 +93,15 @@ public fun XYFit.withFeature(vararg features: OptimizationFeature): XYFit { return XYFit(data, model, this.features.with(*features), pointToCurveDistance, pointWeight) } -/** - * Fit given dta with - */ -public suspend fun XYColumnarData.fitWith( +public suspend fun XYColumnarData.fitWith( optimizer: Optimizer, - processor: AutoDiffProcessor, + modelExpression: DifferentiableExpression, startingPoint: Map, vararg features: OptimizationFeature = emptyArray(), xSymbol: Symbol = Symbol.x, pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, pointWeight: PointWeight = PointWeight.byYSigma, - model: A.(I) -> I -): XYFit where A : ExtendedField, A : ExpressionAlgebra { - val modelExpression = processor.differentiate { - val x = bindSymbol(xSymbol) - model(x) - } - +): XYFit { var actualFeatures = FeatureSet.of(*features, OptimizationStartPoint(startingPoint)) if (actualFeatures.getFeature() == null) { @@ -127,20 +118,50 @@ public suspend fun XYColumnarData.fitWith( return optimizer.optimize(problem) } +/** + * Fit given data with a model provided as an expression + */ +public suspend fun XYColumnarData.fitWith( + optimizer: Optimizer, + processor: AutoDiffProcessor, + startingPoint: Map, + vararg features: OptimizationFeature = emptyArray(), + xSymbol: Symbol = Symbol.x, + pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, + pointWeight: PointWeight = PointWeight.byYSigma, + model: A.(I) -> I, +): XYFit where A : ExtendedField, A : ExpressionAlgebra { + val modelExpression: DifferentiableExpression = processor.differentiate { + val x = bindSymbol(xSymbol) + model(x) + } + + return fitWith( + optimizer = optimizer, + modelExpression = modelExpression, + startingPoint = startingPoint, + features = features, + xSymbol = xSymbol, + pointToCurveDistance = pointToCurveDistance, + pointWeight = pointWeight + ) +} + /** * Compute chi squared value for completed fit. Return null for incomplete fit */ -public val XYFit.chiSquaredOrNull: Double? get() { - val result = resultPointOrNull ?: return null +public val XYFit.chiSquaredOrNull: Double? + get() { + val result = resultPointOrNull ?: return null - return data.indices.sumOf { index-> + return data.indices.sumOf { index -> - val x = data.x[index] - val y = data.y[index] - val yErr = data[Symbol.yError]?.get(index) ?: 1.0 + val x = data.x[index] + val y = data.y[index] + val yErr = data[Symbol.yError]?.get(index) ?: 1.0 - val mu = model.invoke(result + (xSymbol to x) ) + val mu = model.invoke(result + (xSymbol to x)) - ((y - mu)/yErr).pow(2) - } -} \ No newline at end of file + ((y - mu) / yErr).pow(2) + } + } \ No newline at end of file