Remove second generic from DifferentiableExpression

This commit is contained in:
Alexander Nozik 2021-05-25 16:53:53 +03:00
parent 12805712d3
commit f2b7a08ad8
9 changed files with 39 additions and 29 deletions

View File

@ -45,6 +45,7 @@
- MSTExpression - MSTExpression
- Expression algebra builders - Expression algebra builders
- Complex and Quaternion no longer are elements. - Complex and Quaternion no longer are elements.
- Second generic from DifferentiableExpression
### Fixed ### Fixed
- Ring inherits RingOperations, not GroupOperations - Ring inherits RingOperations, not GroupOperations

View File

@ -106,7 +106,7 @@ public class DerivativeStructureField(
public companion object : public companion object :
AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> { AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> = public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> =
DerivativeStructureExpression(function) DerivativeStructureExpression(function)
} }
} }
@ -116,7 +116,7 @@ public class DerivativeStructureField(
*/ */
public class DerivativeStructureExpression( public class DerivativeStructureExpression(
public val function: DerivativeStructureField.() -> DerivativeStructure, public val function: DerivativeStructureField.() -> DerivativeStructure,
) : DifferentiableExpression<Double, Expression<Double>> { ) : DifferentiableExpression<Double> {
public override operator fun invoke(arguments: Map<Symbol, Double>): Double = public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
DerivativeStructureField(0, arguments).function().value DerivativeStructureField(0, arguments).function().value

View File

@ -17,14 +17,7 @@ import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.optimization.* import space.kscience.kmath.optimization.*
import kotlin.collections.HashMap
import kotlin.collections.List
import kotlin.collections.Map
import kotlin.collections.set import kotlin.collections.set
import kotlin.collections.setOf
import kotlin.collections.toList
import kotlin.collections.toMap
import kotlin.collections.toTypedArray
import kotlin.reflect.KClass import kotlin.reflect.KClass
public operator fun PointValuePair.component1(): DoubleArray = point public operator fun PointValuePair.component1(): DoubleArray = point
@ -71,7 +64,7 @@ public class CMOptimization(
addOptimizationData(objectiveFunction) addOptimizationData(objectiveFunction)
} }
public override fun diffFunction(expression: DifferentiableExpression<Double, Expression<Double>>) { public override fun diffFunction(expression: DifferentiableExpression<Double>) {
function(expression) function(expression)
val gradientFunction = ObjectiveFunctionGradient { val gradientFunction = ObjectiveFunctionGradient {
val args = it.toMap() val args = it.toMap()

View File

@ -25,7 +25,7 @@ public fun FunctionOptimization.Companion.chiSquared(
y: Buffer<Double>, y: Buffer<Double>,
yErr: Buffer<Double>, yErr: Buffer<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model) ): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model)
/** /**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
@ -35,7 +35,7 @@ public fun FunctionOptimization.Companion.chiSquared(
y: Iterable<Double>, y: Iterable<Double>,
yErr: Iterable<Double>, yErr: Iterable<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double, Expression<Double>> = chiSquared( ): DifferentiableExpression<Double> = chiSquared(
DerivativeStructureField, DerivativeStructureField,
x.toList().asBuffer(), x.toList().asBuffer(),
y.toList().asBuffer(), y.toList().asBuffer(),
@ -54,12 +54,12 @@ public fun Expression<Double>.optimize(
/** /**
* Optimize differentiable expression * Optimize differentiable expression
*/ */
public fun DifferentiableExpression<Double, Expression<Double>>.optimize( public fun DifferentiableExpression<Double>.optimize(
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: CMOptimization.() -> Unit, configuration: CMOptimization.() -> Unit,
): OptimizationResult<Double> = optimizeWith(CMOptimization, symbols = symbols, configuration) ): OptimizationResult<Double> = optimizeWith(CMOptimization, symbols = symbols, configuration)
public fun DifferentiableExpression<Double, Expression<Double>>.minimize( public fun DifferentiableExpression<Double>.minimize(
vararg startPoint: Pair<Symbol, Double>, vararg startPoint: Pair<Symbol, Double>,
configuration: CMOptimization.() -> Unit = {}, configuration: CMOptimization.() -> Unit = {},
): OptimizationResult<Double> { ): OptimizationResult<Double> {

View File

@ -11,35 +11,51 @@ package space.kscience.kmath.expressions
* @param T the type this expression takes as argument and returns. * @param T the type this expression takes as argument and returns.
* @param R the type of expression this expression can be differentiated to. * @param R the type of expression this expression can be differentiated to.
*/ */
public interface DifferentiableExpression<T, out R : Expression<T>> : Expression<T> { public interface DifferentiableExpression<T> : Expression<T> {
/** /**
* Differentiates this expression by ordered collection of [symbols]. * Differentiates this expression by ordered collection of [symbols].
* *
* @param symbols the symbols. * @param symbols the symbols.
* @return the derivative or `null`. * @return the derivative or `null`.
*/ */
public fun derivativeOrNull(symbols: List<Symbol>): R? public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
} }
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R = public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R = public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
derivative(symbols.toList()) derivative(symbols.toList())
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R = public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
derivative(StringSymbol(name))
/**
* A special type of [DifferentiableExpression] which returns typed expressions as derivatives
*/
public interface SpecialDifferentiableExpression<T, R: Expression<T>>: DifferentiableExpression<T> {
override fun derivativeOrNull(symbols: List<Symbol>): R?
}
public fun <T, R : Expression<T>> SpecialDifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
public fun <T, R : Expression<T>> SpecialDifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
derivative(symbols.toList())
public fun <T, R : Expression<T>> SpecialDifferentiableExpression<T, R>.derivative(name: String): R =
derivative(StringSymbol(name)) derivative(StringSymbol(name))
/** /**
* A [DifferentiableExpression] that defines only first derivatives * A [DifferentiableExpression] that defines only first derivatives
*/ */
public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T, R> { public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
/** /**
* Returns first derivative of this expression by given [symbol]. * Returns first derivative of this expression by given [symbol].
*/ */
public abstract fun derivativeOrNull(symbol: Symbol): R? public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
public final override fun derivativeOrNull(symbols: List<Symbol>): R? { public final override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
val dSymbol = symbols.firstOrNull() ?: return null val dSymbol = symbols.firstOrNull() ?: return null
return derivativeOrNull(dSymbol) return derivativeOrNull(dSymbol)
} }
@ -49,5 +65,5 @@ public abstract class FirstDerivativeExpression<T, R : Expression<T>> : Differen
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression] * A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
*/ */
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> { public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
public fun process(function: A.() -> I): DifferentiableExpression<T, R> public fun process(function: A.() -> I): DifferentiableExpression<T>
} }

View File

@ -232,7 +232,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>( public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F, public val field: F,
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>, public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T, Expression<T>>() { ) : FirstDerivativeExpression<T>() {
public override operator fun invoke(arguments: Map<Symbol, T>): T { public override operator fun invoke(arguments: Map<Symbol, T>): T {
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
return SimpleAutoDiffField(field, arguments).function().value return SimpleAutoDiffField(field, arguments).function().value

View File

@ -24,7 +24,7 @@ import space.kscience.kmath.operations.NumericAlgebra
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>( public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
public val algebra: A, public val algebra: A,
public val mst: MST, public val mst: MST,
) : DifferentiableExpression<T, KotlingradExpression<T, A>> { ) : SpecialDifferentiableExpression<T, KotlingradExpression<T, A>> {
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments) public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> = public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> =

View File

@ -27,7 +27,7 @@ public interface FunctionOptimization<T : Any> : Optimization<T> {
/** /**
* Set a differentiable expression as objective function as function and gradient provider * Set a differentiable expression as objective function as function and gradient provider
*/ */
public fun diffFunction(expression: DifferentiableExpression<T, Expression<T>>) public fun diffFunction(expression: DifferentiableExpression<T>)
public companion object { public companion object {
/** /**
@ -39,7 +39,7 @@ public interface FunctionOptimization<T : Any> : Optimization<T> {
y: Buffer<T>, y: Buffer<T>,
yErr: Buffer<T>, yErr: Buffer<T>,
model: A.(I) -> I, model: A.(I) -> I,
): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> { ): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
require(x.size == y.size) { "X and y buffers should be of the same size" } require(x.size == y.size) { "X and y buffers should be of the same size" }
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
@ -78,7 +78,7 @@ public fun <T: Any, I : Any, A> FunctionOptimization<T>.chiSquared(
/** /**
* Optimize differentiable expression using specific [OptimizationProblemFactory] * Optimize differentiable expression using specific [OptimizationProblemFactory]
*/ */
public fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith( public fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T>.optimizeWith(
factory: OptimizationProblemFactory<T, F>, factory: OptimizationProblemFactory<T, F>,
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: F.() -> Unit, configuration: F.() -> Unit,

View File

@ -27,7 +27,7 @@ public interface XYFit<T : Any> : Optimization<T> {
yErrSymbol: Symbol? = null, yErrSymbol: Symbol? = null,
) )
public fun model(model: (T) -> DifferentiableExpression<T, *>) public fun model(model: (T) -> DifferentiableExpression<T>)
/** /**
* Set the differentiable model for this fit * Set the differentiable model for this fit