From f2b7a08ad8018d508a1edc3c5af9894fb89bafd5 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 25 May 2021 16:53:53 +0300 Subject: [PATCH] Remove second generic from DifferentiableExpression --- CHANGELOG.md | 1 + .../DerivativeStructureExpression.kt | 4 +-- .../commons/optimization/CMOptimization.kt | 9 +---- .../kmath/commons/optimization/cmFit.kt | 8 ++--- .../expressions/DifferentiableExpression.kt | 34 ++++++++++++++----- .../kmath/expressions/SimpleAutoDiff.kt | 2 +- .../kmath/kotlingrad/KotlingradExpression.kt | 2 +- .../optimization/FunctionOptimization.kt | 6 ++-- .../kscience/kmath/optimization/XYFit.kt | 2 +- 9 files changed, 39 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c6b14b95..524d2a1de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ - MSTExpression - Expression algebra builders - Complex and Quaternion no longer are elements. +- Second generic from DifferentiableExpression ### Fixed - Ring inherits RingOperations, not GroupOperations diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 89e216601..361027968 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -106,7 +106,7 @@ public class DerivativeStructureField( public companion object : AutoDiffProcessor> { - public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression> = + public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression = DerivativeStructureExpression(function) } } @@ -116,7 +116,7 @@ public class DerivativeStructureField( */ public class DerivativeStructureExpression( public val function: DerivativeStructureField.() -> DerivativeStructure, -) : DifferentiableExpression> { +) : DifferentiableExpression { public override operator fun invoke(arguments: Map): Double = DerivativeStructureField(0, arguments).function().value diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimization.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimization.kt index bca00de46..400ee0310 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimization.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/CMOptimization.kt @@ -17,14 +17,7 @@ import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer import space.kscience.kmath.expressions.* import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.optimization.* -import kotlin.collections.HashMap -import kotlin.collections.List -import kotlin.collections.Map import kotlin.collections.set -import kotlin.collections.setOf -import kotlin.collections.toList -import kotlin.collections.toMap -import kotlin.collections.toTypedArray import kotlin.reflect.KClass public operator fun PointValuePair.component1(): DoubleArray = point @@ -71,7 +64,7 @@ public class CMOptimization( addOptimizationData(objectiveFunction) } - public override fun diffFunction(expression: DifferentiableExpression>) { + public override fun diffFunction(expression: DifferentiableExpression) { function(expression) val gradientFunction = ObjectiveFunctionGradient { val args = it.toMap() diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/cmFit.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/cmFit.kt index a5a913623..645c41291 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/cmFit.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/optimization/cmFit.kt @@ -25,7 +25,7 @@ public fun FunctionOptimization.Companion.chiSquared( y: Buffer, yErr: Buffer, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, -): DifferentiableExpression> = chiSquared(DerivativeStructureField, x, y, yErr, model) +): DifferentiableExpression = chiSquared(DerivativeStructureField, x, y, yErr, model) /** * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation @@ -35,7 +35,7 @@ public fun FunctionOptimization.Companion.chiSquared( y: Iterable, yErr: Iterable, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, -): DifferentiableExpression> = chiSquared( +): DifferentiableExpression = chiSquared( DerivativeStructureField, x.toList().asBuffer(), y.toList().asBuffer(), @@ -54,12 +54,12 @@ public fun Expression.optimize( /** * Optimize differentiable expression */ -public fun DifferentiableExpression>.optimize( +public fun DifferentiableExpression.optimize( vararg symbols: Symbol, configuration: CMOptimization.() -> Unit, ): OptimizationResult = optimizeWith(CMOptimization, symbols = symbols, configuration) -public fun DifferentiableExpression>.minimize( +public fun DifferentiableExpression.minimize( vararg startPoint: Pair, configuration: CMOptimization.() -> Unit = {}, ): OptimizationResult { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt index 33d72afad..1dcada6d3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DifferentiableExpression.kt @@ -11,35 +11,51 @@ package space.kscience.kmath.expressions * @param T the type this expression takes as argument and returns. * @param R the type of expression this expression can be differentiated to. */ -public interface DifferentiableExpression> : Expression { +public interface DifferentiableExpression : Expression { /** * Differentiates this expression by ordered collection of [symbols]. * * @param symbols the symbols. * @return the derivative or `null`. */ - public fun derivativeOrNull(symbols: List): R? + public fun derivativeOrNull(symbols: List): Expression? } -public fun > DifferentiableExpression.derivative(symbols: List): R = +public fun DifferentiableExpression.derivative(symbols: List): Expression = derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") -public fun > DifferentiableExpression.derivative(vararg symbols: Symbol): R = +public fun DifferentiableExpression.derivative(vararg symbols: Symbol): Expression = derivative(symbols.toList()) -public fun > DifferentiableExpression.derivative(name: String): R = +public fun DifferentiableExpression.derivative(name: String): Expression = + derivative(StringSymbol(name)) + +/** + * A special type of [DifferentiableExpression] which returns typed expressions as derivatives + */ +public interface SpecialDifferentiableExpression>: DifferentiableExpression { + override fun derivativeOrNull(symbols: List): R? +} + +public fun > SpecialDifferentiableExpression.derivative(symbols: List): R = + derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") + +public fun > SpecialDifferentiableExpression.derivative(vararg symbols: Symbol): R = + derivative(symbols.toList()) + +public fun > SpecialDifferentiableExpression.derivative(name: String): R = derivative(StringSymbol(name)) /** * A [DifferentiableExpression] that defines only first derivatives */ -public abstract class FirstDerivativeExpression> : DifferentiableExpression { +public abstract class FirstDerivativeExpression : DifferentiableExpression { /** * Returns first derivative of this expression by given [symbol]. */ - public abstract fun derivativeOrNull(symbol: Symbol): R? + public abstract fun derivativeOrNull(symbol: Symbol): Expression? - public final override fun derivativeOrNull(symbols: List): R? { + public final override fun derivativeOrNull(symbols: List): Expression? { val dSymbol = symbols.firstOrNull() ?: return null return derivativeOrNull(dSymbol) } @@ -49,5 +65,5 @@ public abstract class FirstDerivativeExpression> : Differen * A factory that converts an expression in autodiff variables to a [DifferentiableExpression] */ public fun interface AutoDiffProcessor, out R : Expression> { - public fun process(function: A.() -> I): DifferentiableExpression + public fun process(function: A.() -> I): DifferentiableExpression } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index 254d60b3d..478b85620 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -232,7 +232,7 @@ public fun > F.simpleAutoDiff( public class SimpleAutoDiffExpression>( public val field: F, public val function: SimpleAutoDiffField.() -> AutoDiffValue, -) : FirstDerivativeExpression>() { +) : FirstDerivativeExpression() { public override operator fun invoke(arguments: Map): T { //val bindings = arguments.entries.map { it.key.bind(it.value) } return SimpleAutoDiffField(field, arguments).function().value diff --git a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt index 72ecee4f1..4294462c0 100644 --- a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt +++ b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt @@ -24,7 +24,7 @@ import space.kscience.kmath.operations.NumericAlgebra public class KotlingradExpression>( public val algebra: A, public val mst: MST, -) : DifferentiableExpression> { +) : SpecialDifferentiableExpression> { public override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments) public override fun derivativeOrNull(symbols: List): KotlingradExpression = diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt index 4cf5aea84..f54ba5723 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt @@ -27,7 +27,7 @@ public interface FunctionOptimization : Optimization { /** * Set a differentiable expression as objective function as function and gradient provider */ - public fun diffFunction(expression: DifferentiableExpression>) + public fun diffFunction(expression: DifferentiableExpression) public companion object { /** @@ -39,7 +39,7 @@ public interface FunctionOptimization : Optimization { y: Buffer, yErr: Buffer, model: A.(I) -> I, - ): DifferentiableExpression> where A : ExtendedField, A : ExpressionAlgebra { + ): DifferentiableExpression where A : ExtendedField, A : ExpressionAlgebra { require(x.size == y.size) { "X and y buffers should be of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } @@ -78,7 +78,7 @@ public fun FunctionOptimization.chiSquared( /** * Optimize differentiable expression using specific [OptimizationProblemFactory] */ -public fun > DifferentiableExpression>.optimizeWith( +public fun > DifferentiableExpression.optimizeWith( factory: OptimizationProblemFactory, vararg symbols: Symbol, configuration: F.() -> Unit, diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt index 633e9ae0e..70d7fdf79 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt @@ -27,7 +27,7 @@ public interface XYFit : Optimization { yErrSymbol: Symbol? = null, ) - public fun model(model: (T) -> DifferentiableExpression) + public fun model(model: (T) -> DifferentiableExpression) /** * Set the differentiable model for this fit