forked from kscience/kmath
Another refactor of SimpleAutoDiff
This commit is contained in:
parent
4450c0fcc7
commit
9a147d033e
@ -1,9 +1,6 @@
|
|||||||
package kscience.kmath.commons.expressions
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
import kscience.kmath.expressions.DifferentiableExpression
|
import kscience.kmath.expressions.*
|
||||||
import kscience.kmath.expressions.Expression
|
|
||||||
import kscience.kmath.expressions.ExpressionAlgebra
|
|
||||||
import kscience.kmath.expressions.Symbol
|
|
||||||
import kscience.kmath.operations.ExtendedField
|
import kscience.kmath.operations.ExtendedField
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
|
|
||||||
@ -92,6 +89,12 @@ public class DerivativeStructureField(
|
|||||||
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||||
|
|
||||||
|
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
||||||
|
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
||||||
|
return DerivativeStructureExpression(function)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -71,12 +71,8 @@ public object CMFit {
|
|||||||
public fun Expression<Double>.optimize(
|
public fun Expression<Double>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
|
||||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
|
||||||
problem.expression(this)
|
|
||||||
return problem.optimize()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize differentiable expression
|
* Optimize differentiable expression
|
||||||
@ -84,12 +80,7 @@ public fun Expression<Double>.optimize(
|
|||||||
public fun DifferentiableExpression<Double>.optimize(
|
public fun DifferentiableExpression<Double>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
|
||||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
|
||||||
problem.diffExpression(this)
|
|
||||||
return problem.optimize()
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DifferentiableExpression<Double>.minimize(
|
public fun DifferentiableExpression<Double>.minimize(
|
||||||
vararg startPoint: Pair<Symbol, Double>,
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
|
@ -33,7 +33,7 @@ public class CMOptimizationProblem(
|
|||||||
|
|
||||||
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
||||||
|
|
||||||
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
public override fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,10 +94,12 @@ public class CMOptimizationProblem(
|
|||||||
return OptimizationResult(point.toMap(), value, setOf(this))
|
return OptimizationResult(point.toMap(), value, setOf(this))
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object {
|
public companion object : OptimizationProblemFactory<Double, CMOptimizationProblem> {
|
||||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||||
|
|
||||||
|
override fun build(symbols: List<Symbol>): CMOptimizationProblem = CMOptimizationProblem(symbols)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,37 +0,0 @@
|
|||||||
package kscience.kmath.commons.optimization
|
|
||||||
|
|
||||||
import kscience.kmath.expressions.DifferentiableExpression
|
|
||||||
import kscience.kmath.expressions.Expression
|
|
||||||
import kscience.kmath.expressions.Symbol
|
|
||||||
|
|
||||||
public interface OptimizationFeature
|
|
||||||
|
|
||||||
//TODO move to prob/stat
|
|
||||||
|
|
||||||
public class OptimizationResult<T>(
|
|
||||||
public val point: Map<Symbol, T>,
|
|
||||||
public val value: T,
|
|
||||||
public val features: Set<OptimizationFeature> = emptySet(),
|
|
||||||
){
|
|
||||||
override fun toString(): String {
|
|
||||||
return "OptimizationResult(point=$point, value=$value)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A configuration builder for optimization problem
|
|
||||||
*/
|
|
||||||
public interface OptimizationProblem<T : Any> {
|
|
||||||
/**
|
|
||||||
* Set an objective function expression
|
|
||||||
*/
|
|
||||||
public fun expression(expression: Expression<Double>): Unit
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
|
|
||||||
public fun update(result: OptimizationResult<T>)
|
|
||||||
public fun optimize(): OptimizationResult<T>
|
|
||||||
}
|
|
||||||
|
|
@ -23,10 +23,9 @@ public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expressio
|
|||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||||
derivative(StringSymbol(name) to 1)
|
derivative(StringSymbol(name) to 1)
|
||||||
|
|
||||||
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
/**
|
||||||
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
//}
|
*/
|
||||||
|
|
||||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||||
|
|
||||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||||
@ -35,4 +34,11 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
|||||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
||||||
return derivativeOrNull(dSymbol)
|
return derivativeOrNull(dSymbol)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||||
|
*/
|
||||||
|
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
||||||
|
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
||||||
}
|
}
|
@ -47,7 +47,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
|
||||||
*
|
*
|
||||||
* The partial derivatives are placed in argument `d` variable
|
* The partial derivatives are placed in argument `d` variable
|
||||||
*
|
*
|
||||||
@ -59,36 +59,90 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
|||||||
* assertEquals(9.0, x.d) // dy/dx
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
* ```
|
* ```
|
||||||
*
|
*
|
||||||
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||||
* @return the result of differentiation.
|
* @return the result of differentiation.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> {
|
): DerivationResult<T> {
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
return AutoDiffContext(this, bindings).derivate(body)
|
return SimpleAutoDiffField(this, bindings).derivate(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
vararg bindings: Pair<Symbol, T>,
|
vararg bindings: Pair<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents field in context of which functions can be derived.
|
* Represents field in context of which functions can be derived.
|
||||||
*/
|
*/
|
||||||
public abstract class AutoDiffField<T : Any, F : Field<T>>
|
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||||
: Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
public val context: F,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||||
|
|
||||||
public abstract val context: F
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
|
private var sp: Int = 0
|
||||||
|
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
|
* with respect to this variable.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @property value The value of this variable.
|
||||||
|
*/
|
||||||
|
private class AutoDiffVariableWithDerivative<T : Any>(
|
||||||
|
override val identity: String,
|
||||||
|
value: T,
|
||||||
|
var d: T,
|
||||||
|
) : AutoDiffValue<T>(value), Symbol {
|
||||||
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||||
|
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||||
|
|
||||||
|
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||||
|
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||||
|
|
||||||
|
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
|
||||||
|
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
private fun runBackwardPass() {
|
||||||
|
while (sp > 0) {
|
||||||
|
val value = stack[--sp]
|
||||||
|
val block = stack[--sp] as F.(Any?) -> Unit
|
||||||
|
context.block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||||
|
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||||
|
|
||||||
|
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A variable accessing inner state of derivatives.
|
* A variable accessing inner state of derivatives.
|
||||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||||
*/
|
*/
|
||||||
public abstract var AutoDiffValue<T>.d: T
|
public var AutoDiffValue<T>.d: T
|
||||||
|
get() = getDerivative(this)
|
||||||
|
set(value) = setDerivative(this, value)
|
||||||
|
|
||||||
|
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
@ -101,9 +155,22 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
|
|||||||
* }
|
* }
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
|
// save block to stack for backward pass
|
||||||
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
|
stack[sp++] = block
|
||||||
|
stack[sp++] = value
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
|
|
||||||
|
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||||
|
val result = function()
|
||||||
|
result.d = context.one // computing derivative w.r.t result
|
||||||
|
runBackwardPass()
|
||||||
|
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
||||||
|
}
|
||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
@ -119,68 +186,7 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
|
|||||||
|
|
||||||
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic Differentiation context class.
|
|
||||||
*/
|
|
||||||
private class AutoDiffContext<T : Any, F : Field<T>>(
|
|
||||||
override val context: F,
|
|
||||||
bindings: Map<Symbol, T>,
|
|
||||||
) : AutoDiffField<T, F>() {
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
|
||||||
private var sp: Int = 0
|
|
||||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
|
||||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
|
||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
|
||||||
* with respect to this variable.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @property value The value of this variable.
|
|
||||||
*/
|
|
||||||
private class AutoDiffVariableWithDeriv<T : Any>(
|
|
||||||
override val identity: String,
|
|
||||||
value: T,
|
|
||||||
var d: T,
|
|
||||||
) : AutoDiffValue<T>(value), Symbol{
|
|
||||||
override fun toString(): String = identity
|
|
||||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
|
||||||
override fun hashCode(): Int = identity.hashCode()
|
|
||||||
}
|
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
|
||||||
it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv<T>? = bindings[symbol.identity]
|
|
||||||
|
|
||||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
|
||||||
|
|
||||||
override var AutoDiffValue<T>.d: T
|
|
||||||
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
|
||||||
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
|
||||||
// save block to stack for backward pass
|
|
||||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
|
||||||
stack[sp++] = block
|
|
||||||
stack[sp++] = value
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun runBackwardPass() {
|
|
||||||
while (sp > 0) {
|
|
||||||
val value = stack[--sp]
|
|
||||||
val block = stack[--sp] as F.(Any?) -> Unit
|
|
||||||
context.block(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
@ -206,13 +212,6 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
derive(const { k.toDouble() * a.value }) { z ->
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
a.d += z.d * k.toDouble()
|
a.d += z.d * k.toDouble()
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun derivate(function: AutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
|
||||||
val result = function()
|
|
||||||
result.d = context.one // computing derivative w.r.t result
|
|
||||||
runBackwardPass()
|
|
||||||
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -220,99 +219,178 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
*/
|
*/
|
||||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : FirstDerivativeExpression<T>() {
|
) : FirstDerivativeExpression<T>() {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return AutoDiffContext(field, arguments).function().value
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
|
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
||||||
derivationResult.derivative(symbol)
|
derivationResult.derivative(symbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||||
|
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||||
|
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
||||||
|
return SimpleAutoDiffExpression(field, function)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
// x ^ 2
|
// x ^ 2
|
||||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||||
|
|
||||||
// x ^ 1/2
|
// x ^ 1/2
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||||
|
|
||||||
// x ^ y (const)
|
// x ^ y (const)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: Double,
|
y: Double,
|
||||||
): AutoDiffValue<T> =
|
): AutoDiffValue<T> =
|
||||||
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: Int,
|
y: Int,
|
||||||
): AutoDiffValue<T> = pow(x, y.toDouble())
|
): AutoDiffValue<T> = pow(x, y.toDouble())
|
||||||
|
|
||||||
// exp(x)
|
// exp(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||||
|
|
||||||
// ln(x)
|
// ln(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||||
|
|
||||||
// x ^ y (any)
|
// x ^ y (any)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: AutoDiffValue<T>,
|
y: AutoDiffValue<T>,
|
||||||
): AutoDiffValue<T> =
|
): AutoDiffValue<T> =
|
||||||
exp(y * ln(x))
|
exp(y * ln(x))
|
||||||
|
|
||||||
// sin(x)
|
// sin(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||||
|
|
||||||
// cos(x)
|
// cos(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { tan(x.value) }) { z ->
|
derive(const { tan(x.value) }) { z ->
|
||||||
val c = cos(x.value)
|
val c = cos(x.value)
|
||||||
x.d += z.d / (c * c)
|
x.d += z.d / (c * c)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { tan(x.value) }) { z ->
|
derive(const { tanh(x.value) }) { z ->
|
||||||
val c = cosh(x.value)
|
val c = cosh(x.value)
|
||||||
x.d += z.d / (c * c)
|
x.d += z.d / (c * c)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||||
|
|
||||||
|
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
||||||
|
context: F,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
|
||||||
|
// x ^ 2
|
||||||
|
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sqr(x)
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sqrt(arg)
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).exp(arg)
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).ln(arg)
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
public fun pow(
|
||||||
|
x: AutoDiffValue<T>,
|
||||||
|
y: AutoDiffValue<T>,
|
||||||
|
): AutoDiffValue<T> = exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sin(arg)
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).cos(arg)
|
||||||
|
|
||||||
|
public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).tan(arg)
|
||||||
|
|
||||||
|
public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).asin(arg)
|
||||||
|
|
||||||
|
public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).acos(arg)
|
||||||
|
|
||||||
|
public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).atan(arg)
|
||||||
|
|
||||||
|
public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sinh(arg)
|
||||||
|
|
||||||
|
public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).cosh(arg)
|
||||||
|
|
||||||
|
public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).tanh(arg)
|
||||||
|
|
||||||
|
public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).asinh(arg)
|
||||||
|
|
||||||
|
public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).acosh(arg)
|
||||||
|
|
||||||
|
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).atanh(arg)
|
||||||
|
}
|
@ -8,11 +8,6 @@ import kotlin.contracts.InvocationKind
|
|||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
|
||||||
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
|
|
||||||
// public fun expression(block: A.() -> E): Expression<T>
|
|
||||||
//}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a functional expression with this [Space].
|
* Creates a functional expression with this [Space].
|
||||||
*/
|
*/
|
||||||
|
@ -2,6 +2,7 @@ package kscience.kmath.expressions
|
|||||||
|
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.structures.asBuffer
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.E
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
@ -13,18 +14,18 @@ class SimpleAutoDiffTest {
|
|||||||
|
|
||||||
fun dx(
|
fun dx(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||||
|
|
||||||
fun dxy(
|
fun dxy(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
yBinding: Pair<Symbol, Double>,
|
yBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
||||||
body(bind(xBinding.first), bind(yBinding.first))
|
body(bind(xBinding.first), bind(yBinding.first))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun diff(block: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
fun diff(block: SimpleAutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||||
return SimpleAutoDiffExpression(RealField, block)
|
return SimpleAutoDiffExpression(RealField, block)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,7 +46,7 @@ class SimpleAutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPlusX2Expr() {
|
fun testPlusX2Expr() {
|
||||||
val expr = diff{
|
val expr = diff {
|
||||||
val x = bind(x)
|
val x = bind(x)
|
||||||
x + x
|
x + x
|
||||||
}
|
}
|
||||||
@ -245,9 +246,9 @@ class SimpleAutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testTanh() {
|
fun testTanh() {
|
||||||
val y = dx(x to PI / 6) { x -> tanh(x) }
|
val y = dx(x to 1.0) { x -> tanh(x) }
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6)
|
||||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
Reference in New Issue
Block a user