Feature/diff api #154
@ -1 +1,3 @@
|
|||||||
job("Build") { gradlew("openjdk:11", "build") }
|
job("Build") {
|
||||||
|
gradlew("openjdk:11", "build")
|
||||||
|
}
|
||||||
|
@ -106,7 +106,7 @@ public class DerivativeStructureExpression(
|
|||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
*/
|
*/
|
||||||
public override fun derivative(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||||
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,100 @@
|
|||||||
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.*
|
||||||
|
import org.apache.commons.math3.optim.*
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||||
|
public operator fun PointValuePair.component2(): Double = value
|
||||||
|
|
||||||
|
public class CMOptimizationProblem(
|
||||||
|
override val symbols: List<Symbol>,
|
||||||
|
) : OptimizationProblem<Double>, SymbolIndexer {
|
||||||
|
protected val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||||
|
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||||
|
|
||||||
|
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||||
|
DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
|
||||||
|
|
||||||
|
private fun addOptimizationData(data: OptimizationData) {
|
||||||
|
optimizationData[data::class] = data
|
||||||
|
}
|
||||||
|
|
||||||
|
init {
|
||||||
|
addOptimizationData(MaxEval.unlimited())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||||
|
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun expression(expression: Expression<Double>): Unit {
|
||||||
|
val objectiveFunction = ObjectiveFunction {
|
||||||
|
val args = it.toMap()
|
||||||
|
expression(args)
|
||||||
|
}
|
||||||
|
addOptimizationData(objectiveFunction)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun derivatives(expression: DifferentiableExpression<Double>): Unit {
|
||||||
|
expression(expression)
|
||||||
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
|
val args = it.toMap()
|
||||||
|
DoubleArray(symbols.size) { index ->
|
||||||
|
expression.derivative(symbols[index])(args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
addOptimizationData(gradientFunction)
|
||||||
|
if (optimizatorBuilder == null) {
|
||||||
|
optimizatorBuilder = {
|
||||||
|
NonLinearConjugateGradientOptimizer(
|
||||||
|
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
||||||
|
convergenceChecker
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun simplex(simplex: AbstractSimplex) {
|
||||||
|
addOptimizationData(simplex)
|
||||||
|
//Set optimization builder to simplex if it is not present
|
||||||
|
if (optimizatorBuilder == null) {
|
||||||
|
optimizatorBuilder = { SimplexOptimizer(convergenceChecker) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun simplexSteps(steps: Map<Symbol, Double>) {
|
||||||
|
simplex(NelderMeadSimplex(steps.toDoubleArray()))
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun goal(goalType: GoalType) {
|
||||||
|
addOptimizationData(goalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun optimizer(block: () -> MultivariateOptimizer) {
|
||||||
|
optimizatorBuilder = block
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun optimize(): OptimizationResult<Double> {
|
||||||
|
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
||||||
|
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
||||||
|
return OptimizationResult(point.toMap(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||||
|
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||||
|
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun CMOptimizationProblem.initialGuess(vararg pairs: Pair<Symbol, Double>): Unit = initialGuess(pairs.toMap())
|
||||||
|
public fun CMOptimizationProblem.simplexSteps(vararg pairs: Pair<Symbol, Double>): Unit = simplexSteps(pairs.toMap())
|
@ -0,0 +1,17 @@
|
|||||||
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
public typealias ParameterSpacePoint<T> = Map<Symbol, T>
|
||||||
|
|
||||||
|
public class OptimizationResult<T>(
|
||||||
|
public val point: ParameterSpacePoint<T>,
|
||||||
|
public val value: T,
|
||||||
|
public val extra: Map<KClass<*>, Any> = emptyMap()
|
||||||
|
)
|
||||||
|
|
||||||
|
public interface OptimizationProblem<T : Any> {
|
||||||
|
public fun optimize(): OptimizationResult<T>
|
||||||
|
}
|
||||||
|
|
@ -1,103 +1,32 @@
|
|||||||
package kscience.kmath.commons.optimization
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
import kscience.kmath.expressions.*
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
import org.apache.commons.math3.optim.*
|
import kscience.kmath.expressions.Expression
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
import kscience.kmath.expressions.Symbol
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
|
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
|
||||||
|
|
||||||
public typealias ParameterSpacePoint = Map<Symbol, Double>
|
|
||||||
|
|
||||||
public class OptimizationResult(public val point: ParameterSpacePoint, public val value: Double)
|
|
||||||
|
|
||||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
|
||||||
public operator fun PointValuePair.component2(): Double = value
|
|
||||||
|
|
||||||
public object Optimization {
|
|
||||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
|
||||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
|
||||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private fun SymbolIndexer.objectiveFunction(expression: Expression<Double>) = ObjectiveFunction {
|
|
||||||
val args = it.toMap()
|
|
||||||
expression(args)
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun SymbolIndexer.objectiveFunctionGradient(
|
|
||||||
expression: DifferentiableExpression<Double>,
|
|
||||||
) = ObjectiveFunctionGradient {
|
|
||||||
val args = it.toMap()
|
|
||||||
DoubleArray(symbols.size) { index ->
|
|
||||||
expression.derivative(symbols[index])(args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun SymbolIndexer.initialGuess(point: ParameterSpacePoint) = InitialGuess(point.toArray())
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize expression without derivatives
|
* Optimize expression without derivatives
|
||||||
*/
|
*/
|
||||||
public fun Expression<Double>.optimize(
|
public fun Expression<Double>.optimize(
|
||||||
startingPoint: ParameterSpacePoint,
|
vararg symbols: Symbol,
|
||||||
goalType: GoalType = GoalType.MAXIMIZE,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
vararg additionalArguments: OptimizationData,
|
): OptimizationResult<Double> {
|
||||||
optimizerBuilder: () -> MultivariateOptimizer = {
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
SimplexOptimizer(
|
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
|
||||||
SimpleValueChecker(
|
problem.expression(this)
|
||||||
Optimization.DEFAULT_RELATIVE_TOLERANCE,
|
return problem.optimize()
|
||||||
Optimization.DEFAULT_ABSOLUTE_TOLERANCE,
|
|
||||||
Optimization.DEFAULT_MAX_ITER
|
|
||||||
)
|
|
||||||
)
|
|
||||||
},
|
|
||||||
): OptimizationResult = withSymbols(startingPoint.keys) {
|
|
||||||
val optimizer = optimizerBuilder()
|
|
||||||
val objectiveFunction = objectiveFunction(this@optimize)
|
|
||||||
val (point, value) = optimizer.optimize(
|
|
||||||
objectiveFunction,
|
|
||||||
initialGuess(startingPoint),
|
|
||||||
goalType,
|
|
||||||
MaxEval.unlimited(),
|
|
||||||
NelderMeadSimplex(symbols.size, 1.0),
|
|
||||||
*additionalArguments
|
|
||||||
)
|
|
||||||
OptimizationResult(point.toMap(), value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize differentiable expression
|
* Optimize differentiable expression
|
||||||
*/
|
*/
|
||||||
public fun DifferentiableExpression<Double>.optimize(
|
public fun DifferentiableExpression<Double>.optimize(
|
||||||
startingPoint: ParameterSpacePoint,
|
vararg symbols: Symbol,
|
||||||
goalType: GoalType = GoalType.MAXIMIZE,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
vararg additionalArguments: OptimizationData,
|
): OptimizationResult<Double> {
|
||||||
optimizerBuilder: () -> NonLinearConjugateGradientOptimizer = {
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
NonLinearConjugateGradientOptimizer(
|
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration).apply(configuration)
|
||||||
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
problem.derivatives(this)
|
||||||
SimpleValueChecker(
|
return problem.optimize()
|
||||||
Optimization.DEFAULT_RELATIVE_TOLERANCE,
|
|
||||||
Optimization.DEFAULT_ABSOLUTE_TOLERANCE,
|
|
||||||
Optimization.DEFAULT_MAX_ITER
|
|
||||||
)
|
|
||||||
)
|
|
||||||
},
|
|
||||||
): OptimizationResult = withSymbols(startingPoint.keys) {
|
|
||||||
val optimizer = optimizerBuilder()
|
|
||||||
val objectiveFunction = objectiveFunction(this@optimize)
|
|
||||||
val objectiveGradient = objectiveFunctionGradient(this@optimize)
|
|
||||||
val (point, value) = optimizer.optimize(
|
|
||||||
objectiveFunction,
|
|
||||||
objectiveGradient,
|
|
||||||
initialGuess(startingPoint),
|
|
||||||
goalType,
|
|
||||||
MaxEval.unlimited(),
|
|
||||||
*additionalArguments
|
|
||||||
)
|
|
||||||
OptimizationResult(point.toMap(), value)
|
|
||||||
}
|
}
|
@ -1,10 +1,7 @@
|
|||||||
package kscience.kmath.commons.optimization
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
||||||
import kscience.kmath.expressions.Expression
|
|
||||||
import kscience.kmath.expressions.Symbol
|
|
||||||
import kscience.kmath.expressions.symbol
|
import kscience.kmath.expressions.symbol
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
|
|
||||||
internal class OptimizeTest {
|
internal class OptimizeTest {
|
||||||
@ -17,19 +14,22 @@ internal class OptimizeTest {
|
|||||||
exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2)
|
exp(-x.pow(2) / 2) + exp(-y.pow(2) / 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
val startingPoint: Map<Symbol, Double> = mapOf(x to 1.0, y to 1.0)
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testOptimization() {
|
fun testOptimization() {
|
||||||
val result = normal.optimize(startingPoint)
|
val result = normal.optimize(x, y) {
|
||||||
|
initialGuess(x to 1.0, y to 1.0)
|
||||||
|
//no need to select optimizer. Gradient optimizer is used by default
|
||||||
|
}
|
||||||
println(result.point)
|
println(result.point)
|
||||||
println(result.value)
|
println(result.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSimplexOptimization() {
|
fun testSimplexOptimization() {
|
||||||
val result = (normal as Expression<Double>).optimize(startingPoint){
|
val result = normal.optimize(x, y) {
|
||||||
SimplexOptimizer(1e-4,1e-4)
|
initialGuess(x to 1.0, y to 1.0)
|
||||||
|
simplexSteps(x to 2.0, y to 0.5)
|
||||||
|
//this sets simplex optimizer
|
||||||
}
|
}
|
||||||
println(result.point)
|
println(result.point)
|
||||||
println(result.value)
|
println(result.value)
|
||||||
|
@ -4,9 +4,15 @@ package kscience.kmath.expressions
|
|||||||
|
|||||||
* And object that could be differentiated
|
* And object that could be differentiated
|
||||||
*/
|
*/
|
||||||
public interface Differentiable<T> {
|
public interface Differentiable<T> {
|
||||||
public fun derivative(orders: Map<Symbol, Int>): T
|
public fun derivativeOrNull(orders: Map<Symbol, Int>): T?
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
public fun <T> Differentiable<T>.derivative(orders: Map<Symbol, Int>): T =
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
/**
|
||||||
Overload with Overload with `vararg symbols: Symbol` for order 1 can be added, too.
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
No sense in that. People will use first, maximum second derivatives. There is an extension for the first. The second one could be added any moment. No sense in that. People will use first, maximum second derivatives. There is an extension for the first. The second one could be added any moment.
|
|||||||
|
* An expression that provid
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
*/
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||||
@ -14,8 +20,19 @@ public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol
|
|||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
Removed Removed
|
|||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
derivative(StringSymbol(name) to 1)
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
||||||
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
||||||
//}
|
//}
|
||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? {
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
return derivativeOrNull(dSymbol)
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
}
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
}
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
SAM SAM
I see no reason for that. This interface is mostly implemented by companion objects. Also I plan to add additional methods later. I see no reason for that. This interface is mostly implemented by companion objects. Also I plan to add additional methods later.
|
@ -221,23 +221,16 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : DifferentiableExpression<T> {
|
) : FirstDerivativeExpression<T>() {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return AutoDiffContext(field, arguments).function().value
|
return AutoDiffContext(field, arguments).function().value
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||||
* Get the derivative expression with given orders
|
|
||||||
*/
|
|
||||||
public override fun derivative(orders: Map<Symbol, Int>): Expression<T> {
|
|
||||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }
|
|
||||||
?: error("SimpleAutoDiff supports only first order derivatives")
|
|
||||||
return Expression { arguments ->
|
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
|
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
|
||||||
derivationResult.derivative(dSymbol.key)
|
derivationResult.derivative(symbol)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,12 @@
|
|||||||
package kscience.kmath.expressions
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.structures.BufferFactory
|
||||||
|
import kscience.kmath.structures.Structure2D
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An environment to easy transform indexed variables to symbols and back.
|
* An environment to easy transform indexed variables to symbols and back.
|
||||||
|
* TODO requires multi-receivers to be beutiful
|
||||||
*/
|
*/
|
||||||
public interface SymbolIndexer {
|
public interface SymbolIndexer {
|
||||||
public val symbols: List<Symbol>
|
public val symbols: List<Symbol>
|
||||||
@ -22,15 +27,26 @@ public interface SymbolIndexer {
|
|||||||
return get(this@SymbolIndexer.indexOf(symbol))
|
return get(this@SymbolIndexer.indexOf(symbol))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public operator fun <T> Point<T>.get(symbol: Symbol): T {
|
||||||
|
require(size == symbols.size) { "The input buffer size for indexer should be ${symbols.size} but $size found" }
|
||||||
|
return get(this@SymbolIndexer.indexOf(symbol))
|
||||||
|
}
|
||||||
|
|
||||||
public fun DoubleArray.toMap(): Map<Symbol, Double> {
|
public fun DoubleArray.toMap(): Map<Symbol, Double> {
|
||||||
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||||
return symbols.indices.associate { symbols[it] to get(it) }
|
return symbols.indices.associate { symbols[it] to get(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public operator fun <T> Structure2D<T>.get(rowSymbol: Symbol, columnSymbol: Symbol): T =
|
||||||
|
get(indexOf(rowSymbol), indexOf(columnSymbol))
|
||||||
|
|
||||||
|
|
||||||
public fun <T> Map<Symbol, T>.toList(): List<T> = symbols.map { getValue(it) }
|
public fun <T> Map<Symbol, T>.toList(): List<T> = symbols.map { getValue(it) }
|
||||||
|
|
||||||
public fun Map<Symbol, Double>.toArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
|
public fun <T> Map<Symbol, T>.toPoint(bufferFactory: BufferFactory<T>): Point<T> =
|
||||||
|
bufferFactory(symbols.size) { getValue(symbols[it]) }
|
||||||
|
|
||||||
|
public fun Map<Symbol, Double>.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline class SimpleSymbolIndexer(override val symbols: List<Symbol>) : SymbolIndexer
|
public inline class SimpleSymbolIndexer(override val symbols: List<Symbol>) : SymbolIndexer
|
||||||
|
Loading…
Reference in New Issue
Block a user
Removed
Removed