Remove second generic from DifferentiableExpression
This commit is contained in:
parent
12805712d3
commit
f2b7a08ad8
@ -45,6 +45,7 @@
|
|||||||
- MSTExpression
|
- MSTExpression
|
||||||
- Expression algebra builders
|
- Expression algebra builders
|
||||||
- Complex and Quaternion no longer are elements.
|
- Complex and Quaternion no longer are elements.
|
||||||
|
- Second generic from DifferentiableExpression
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
- Ring inherits RingOperations, not GroupOperations
|
- Ring inherits RingOperations, not GroupOperations
|
||||||
|
@ -106,7 +106,7 @@ public class DerivativeStructureField(
|
|||||||
|
|
||||||
public companion object :
|
public companion object :
|
||||||
AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
|
AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
|
||||||
public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> =
|
public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> =
|
||||||
DerivativeStructureExpression(function)
|
DerivativeStructureExpression(function)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,7 +116,7 @@ public class DerivativeStructureField(
|
|||||||
*/
|
*/
|
||||||
public class DerivativeStructureExpression(
|
public class DerivativeStructureExpression(
|
||||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||||
) : DifferentiableExpression<Double, Expression<Double>> {
|
) : DifferentiableExpression<Double> {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
DerivativeStructureField(0, arguments).function().value
|
DerivativeStructureField(0, arguments).function().value
|
||||||
|
|
||||||
|
@ -17,14 +17,7 @@ import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
|||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.optimization.*
|
import space.kscience.kmath.optimization.*
|
||||||
import kotlin.collections.HashMap
|
|
||||||
import kotlin.collections.List
|
|
||||||
import kotlin.collections.Map
|
|
||||||
import kotlin.collections.set
|
import kotlin.collections.set
|
||||||
import kotlin.collections.setOf
|
|
||||||
import kotlin.collections.toList
|
|
||||||
import kotlin.collections.toMap
|
|
||||||
import kotlin.collections.toTypedArray
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||||
@ -71,7 +64,7 @@ public class CMOptimization(
|
|||||||
addOptimizationData(objectiveFunction)
|
addOptimizationData(objectiveFunction)
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun diffFunction(expression: DifferentiableExpression<Double, Expression<Double>>) {
|
public override fun diffFunction(expression: DifferentiableExpression<Double>) {
|
||||||
function(expression)
|
function(expression)
|
||||||
val gradientFunction = ObjectiveFunctionGradient {
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
val args = it.toMap()
|
val args = it.toMap()
|
||||||
|
@ -25,7 +25,7 @@ public fun FunctionOptimization.Companion.chiSquared(
|
|||||||
y: Buffer<Double>,
|
y: Buffer<Double>,
|
||||||
yErr: Buffer<Double>,
|
yErr: Buffer<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
@ -35,7 +35,7 @@ public fun FunctionOptimization.Companion.chiSquared(
|
|||||||
y: Iterable<Double>,
|
y: Iterable<Double>,
|
||||||
yErr: Iterable<Double>,
|
yErr: Iterable<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(
|
): DifferentiableExpression<Double> = chiSquared(
|
||||||
DerivativeStructureField,
|
DerivativeStructureField,
|
||||||
x.toList().asBuffer(),
|
x.toList().asBuffer(),
|
||||||
y.toList().asBuffer(),
|
y.toList().asBuffer(),
|
||||||
@ -54,12 +54,12 @@ public fun Expression<Double>.optimize(
|
|||||||
/**
|
/**
|
||||||
* Optimize differentiable expression
|
* Optimize differentiable expression
|
||||||
*/
|
*/
|
||||||
public fun DifferentiableExpression<Double, Expression<Double>>.optimize(
|
public fun DifferentiableExpression<Double>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimization.() -> Unit,
|
configuration: CMOptimization.() -> Unit,
|
||||||
): OptimizationResult<Double> = optimizeWith(CMOptimization, symbols = symbols, configuration)
|
): OptimizationResult<Double> = optimizeWith(CMOptimization, symbols = symbols, configuration)
|
||||||
|
|
||||||
public fun DifferentiableExpression<Double, Expression<Double>>.minimize(
|
public fun DifferentiableExpression<Double>.minimize(
|
||||||
vararg startPoint: Pair<Symbol, Double>,
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
configuration: CMOptimization.() -> Unit = {},
|
configuration: CMOptimization.() -> Unit = {},
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> {
|
||||||
|
@ -11,35 +11,51 @@ package space.kscience.kmath.expressions
|
|||||||
* @param T the type this expression takes as argument and returns.
|
* @param T the type this expression takes as argument and returns.
|
||||||
* @param R the type of expression this expression can be differentiated to.
|
* @param R the type of expression this expression can be differentiated to.
|
||||||
*/
|
*/
|
||||||
public interface DifferentiableExpression<T, out R : Expression<T>> : Expression<T> {
|
public interface DifferentiableExpression<T> : Expression<T> {
|
||||||
/**
|
/**
|
||||||
* Differentiates this expression by ordered collection of [symbols].
|
* Differentiates this expression by ordered collection of [symbols].
|
||||||
*
|
*
|
||||||
* @param symbols the symbols.
|
* @param symbols the symbols.
|
||||||
* @return the derivative or `null`.
|
* @return the derivative or `null`.
|
||||||
*/
|
*/
|
||||||
public fun derivativeOrNull(symbols: List<Symbol>): R?
|
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
|
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
|
||||||
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||||
|
|
||||||
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
|
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
|
||||||
derivative(symbols.toList())
|
derivative(symbols.toList())
|
||||||
|
|
||||||
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R =
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||||
|
derivative(StringSymbol(name))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A special type of [DifferentiableExpression] which returns typed expressions as derivatives
|
||||||
|
*/
|
||||||
|
public interface SpecialDifferentiableExpression<T, R: Expression<T>>: DifferentiableExpression<T> {
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): R?
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T, R : Expression<T>> SpecialDifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
|
||||||
|
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||||
|
|
||||||
|
public fun <T, R : Expression<T>> SpecialDifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
|
||||||
|
derivative(symbols.toList())
|
||||||
|
|
||||||
|
public fun <T, R : Expression<T>> SpecialDifferentiableExpression<T, R>.derivative(name: String): R =
|
||||||
derivative(StringSymbol(name))
|
derivative(StringSymbol(name))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A [DifferentiableExpression] that defines only first derivatives
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
*/
|
*/
|
||||||
public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T, R> {
|
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||||
/**
|
/**
|
||||||
* Returns first derivative of this expression by given [symbol].
|
* Returns first derivative of this expression by given [symbol].
|
||||||
*/
|
*/
|
||||||
public abstract fun derivativeOrNull(symbol: Symbol): R?
|
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||||
|
|
||||||
public final override fun derivativeOrNull(symbols: List<Symbol>): R? {
|
public final override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
|
||||||
val dSymbol = symbols.firstOrNull() ?: return null
|
val dSymbol = symbols.firstOrNull() ?: return null
|
||||||
return derivativeOrNull(dSymbol)
|
return derivativeOrNull(dSymbol)
|
||||||
}
|
}
|
||||||
@ -49,5 +65,5 @@ public abstract class FirstDerivativeExpression<T, R : Expression<T>> : Differen
|
|||||||
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||||
*/
|
*/
|
||||||
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
|
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
|
||||||
public fun process(function: A.() -> I): DifferentiableExpression<T, R>
|
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
||||||
}
|
}
|
||||||
|
@ -232,7 +232,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
|||||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : FirstDerivativeExpression<T, Expression<T>>() {
|
) : FirstDerivativeExpression<T>() {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return SimpleAutoDiffField(field, arguments).function().value
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
|
@ -24,7 +24,7 @@ import space.kscience.kmath.operations.NumericAlgebra
|
|||||||
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
|
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
|
||||||
public val algebra: A,
|
public val algebra: A,
|
||||||
public val mst: MST,
|
public val mst: MST,
|
||||||
) : DifferentiableExpression<T, KotlingradExpression<T, A>> {
|
) : SpecialDifferentiableExpression<T, KotlingradExpression<T, A>> {
|
||||||
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
||||||
|
|
||||||
public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> =
|
public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> =
|
||||||
|
@ -27,7 +27,7 @@ public interface FunctionOptimization<T : Any> : Optimization<T> {
|
|||||||
/**
|
/**
|
||||||
* Set a differentiable expression as objective function as function and gradient provider
|
* Set a differentiable expression as objective function as function and gradient provider
|
||||||
*/
|
*/
|
||||||
public fun diffFunction(expression: DifferentiableExpression<T, Expression<T>>)
|
public fun diffFunction(expression: DifferentiableExpression<T>)
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
@ -39,7 +39,7 @@ public interface FunctionOptimization<T : Any> : Optimization<T> {
|
|||||||
y: Buffer<T>,
|
y: Buffer<T>,
|
||||||
yErr: Buffer<T>,
|
yErr: Buffer<T>,
|
||||||
model: A.(I) -> I,
|
model: A.(I) -> I,
|
||||||
): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ public fun <T: Any, I : Any, A> FunctionOptimization<T>.chiSquared(
|
|||||||
/**
|
/**
|
||||||
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
public fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T>.optimizeWith(
|
||||||
factory: OptimizationProblemFactory<T, F>,
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: F.() -> Unit,
|
configuration: F.() -> Unit,
|
||||||
|
@ -27,7 +27,7 @@ public interface XYFit<T : Any> : Optimization<T> {
|
|||||||
yErrSymbol: Symbol? = null,
|
yErrSymbol: Symbol? = null,
|
||||||
)
|
)
|
||||||
|
|
||||||
public fun model(model: (T) -> DifferentiableExpression<T, *>)
|
public fun model(model: (T) -> DifferentiableExpression<T>)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the differentiable model for this fit
|
* Set the differentiable model for this fit
|
||||||
|
Loading…
Reference in New Issue
Block a user