Feature/diff api #154
@ -1,9 +1,6 @@
|
|||||||
package kscience.kmath.commons.expressions
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
import kscience.kmath.expressions.DifferentiableExpression
|
import kscience.kmath.expressions.*
|
||||||
import kscience.kmath.expressions.Expression
|
|
||||||
import kscience.kmath.expressions.ExpressionAlgebra
|
|
||||||
import kscience.kmath.expressions.Symbol
|
|
||||||
import kscience.kmath.operations.ExtendedField
|
import kscience.kmath.operations.ExtendedField
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
|
|
||||||
@ -92,6 +89,12 @@ public class DerivativeStructureField(
|
|||||||
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||||
|
|
||||||
|
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
||||||
|
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
||||||
|
return DerivativeStructureExpression(function)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -71,12 +71,8 @@ public object CMFit {
|
|||||||
public fun Expression<Double>.optimize(
|
public fun Expression<Double>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
|
||||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
|
||||||
problem.expression(this)
|
|
||||||
return problem.optimize()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize differentiable expression
|
* Optimize differentiable expression
|
||||||
@ -84,12 +80,7 @@ public fun Expression<Double>.optimize(
|
|||||||
public fun DifferentiableExpression<Double>.optimize(
|
public fun DifferentiableExpression<Double>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
|
||||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
|
||||||
problem.diffExpression(this)
|
|
||||||
return problem.optimize()
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DifferentiableExpression<Double>.minimize(
|
public fun DifferentiableExpression<Double>.minimize(
|
||||||
vararg startPoint: Pair<Symbol, Double>,
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
|
@ -33,7 +33,7 @@ public class CMOptimizationProblem(
|
|||||||
|
|
||||||
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
||||||
|
|
||||||
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
public override fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,10 +94,12 @@ public class CMOptimizationProblem(
|
|||||||
return OptimizationResult(point.toMap(), value, setOf(this))
|
return OptimizationResult(point.toMap(), value, setOf(this))
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object {
|
public companion object : OptimizationProblemFactory<Double, CMOptimizationProblem> {
|
||||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||||
|
|
||||||
|
override fun build(symbols: List<Symbol>): CMOptimizationProblem = CMOptimizationProblem(symbols)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,37 +0,0 @@
|
|||||||
package kscience.kmath.commons.optimization
|
|
||||||
|
|
||||||
import kscience.kmath.expressions.DifferentiableExpression
|
|
||||||
import kscience.kmath.expressions.Expression
|
|
||||||
import kscience.kmath.expressions.Symbol
|
|
||||||
|
|
||||||
public interface OptimizationFeature
|
|
||||||
|
|
||||||
//TODO move to prob/stat
|
|
||||||
|
|
||||||
public class OptimizationResult<T>(
|
|
||||||
public val point: Map<Symbol, T>,
|
|
||||||
public val value: T,
|
|
||||||
public val features: Set<OptimizationFeature> = emptySet(),
|
|
||||||
){
|
|
||||||
override fun toString(): String {
|
|
||||||
return "OptimizationResult(point=$point, value=$value)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A configuration builder for optimization problem
|
|
||||||
*/
|
|
||||||
public interface OptimizationProblem<T : Any> {
|
|
||||||
/**
|
|
||||||
* Set an objective function expression
|
|
||||||
*/
|
|
||||||
public fun expression(expression: Expression<Double>): Unit
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
|
|
||||||
public fun update(result: OptimizationResult<T>)
|
|
||||||
public fun optimize(): OptimizationResult<T>
|
|
||||||
}
|
|
||||||
|
|
@ -23,10 +23,9 @@ public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expressio
|
|||||||
|
|||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||||
derivative(StringSymbol(name) to 1)
|
derivative(StringSymbol(name) to 1)
|
||||||
|
|
||||||
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
/**
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
//}
|
*/
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||||
|
|
||||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||||
@ -36,3 +35,10 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
|||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
Removed Removed
|
|||||||
return derivativeOrNull(dSymbol)
|
return derivativeOrNull(dSymbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
SAM SAM
I see no reason for that. This interface is mostly implemented by companion objects. Also I plan to add additional methods later. I see no reason for that. This interface is mostly implemented by companion objects. Also I plan to add additional methods later.
|
|||||||
|
/**
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
*/
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
}
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
@ -47,7 +47,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
|
||||||
*
|
*
|
||||||
* The partial derivatives are placed in argument `d` variable
|
* The partial derivatives are placed in argument `d` variable
|
||||||
*
|
*
|
||||||
@ -59,36 +59,90 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
|||||||
* assertEquals(9.0, x.d) // dy/dx
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
* ```
|
* ```
|
||||||
*
|
*
|
||||||
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||||
* @return the result of differentiation.
|
* @return the result of differentiation.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> {
|
): DerivationResult<T> {
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
return AutoDiffContext(this, bindings).derivate(body)
|
return SimpleAutoDiffField(this, bindings).derivate(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
vararg bindings: Pair<Symbol, T>,
|
vararg bindings: Pair<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents field in context of which functions can be derived.
|
* Represents field in context of which functions can be derived.
|
||||||
*/
|
*/
|
||||||
public abstract class AutoDiffField<T : Any, F : Field<T>>
|
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||||
: Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
public val context: F,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||||
|
|
||||||
public abstract val context: F
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
|
private var sp: Int = 0
|
||||||
|
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
|
* with respect to this variable.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @property value The value of this variable.
|
||||||
|
*/
|
||||||
|
private class AutoDiffVariableWithDerivative<T : Any>(
|
||||||
|
override val identity: String,
|
||||||
|
value: T,
|
||||||
|
var d: T,
|
||||||
|
) : AutoDiffValue<T>(value), Symbol {
|
||||||
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||||
|
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||||
|
|
||||||
|
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||||
|
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||||
|
|
||||||
|
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
|
||||||
|
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
private fun runBackwardPass() {
|
||||||
|
while (sp > 0) {
|
||||||
|
val value = stack[--sp]
|
||||||
|
val block = stack[--sp] as F.(Any?) -> Unit
|
||||||
|
context.block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||||
|
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||||
|
|
||||||
|
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A variable accessing inner state of derivatives.
|
* A variable accessing inner state of derivatives.
|
||||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||||
*/
|
*/
|
||||||
public abstract var AutoDiffValue<T>.d: T
|
public var AutoDiffValue<T>.d: T
|
||||||
|
get() = getDerivative(this)
|
||||||
|
set(value) = setDerivative(this, value)
|
||||||
|
|
||||||
|
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
@ -101,9 +155,22 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
|
|||||||
* }
|
* }
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
|
// save block to stack for backward pass
|
||||||
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
|
stack[sp++] = block
|
||||||
|
stack[sp++] = value
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
|
|
||||||
|
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||||
|
val result = function()
|
||||||
|
result.d = context.one // computing derivative w.r.t result
|
||||||
|
runBackwardPass()
|
||||||
|
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
||||||
|
}
|
||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
@ -119,68 +186,7 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
|
|||||||
|
|
||||||
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic Differentiation context class.
|
|
||||||
*/
|
|
||||||
private class AutoDiffContext<T : Any, F : Field<T>>(
|
|
||||||
override val context: F,
|
|
||||||
bindings: Map<Symbol, T>,
|
|
||||||
) : AutoDiffField<T, F>() {
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
|
||||||
private var sp: Int = 0
|
|
||||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
|
||||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
|
||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
|
||||||
* with respect to this variable.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @property value The value of this variable.
|
|
||||||
*/
|
|
||||||
private class AutoDiffVariableWithDeriv<T : Any>(
|
|
||||||
override val identity: String,
|
|
||||||
value: T,
|
|
||||||
var d: T,
|
|
||||||
) : AutoDiffValue<T>(value), Symbol{
|
|
||||||
override fun toString(): String = identity
|
|
||||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
|
||||||
override fun hashCode(): Int = identity.hashCode()
|
|
||||||
}
|
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
|
||||||
it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv<T>? = bindings[symbol.identity]
|
|
||||||
|
|
||||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
|
||||||
|
|
||||||
override var AutoDiffValue<T>.d: T
|
|
||||||
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
|
||||||
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
|
||||||
// save block to stack for backward pass
|
|
||||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
|
||||||
stack[sp++] = block
|
|
||||||
stack[sp++] = value
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun runBackwardPass() {
|
|
||||||
while (sp > 0) {
|
|
||||||
val value = stack[--sp]
|
|
||||||
val block = stack[--sp] as F.(Any?) -> Unit
|
|
||||||
context.block(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
@ -206,13 +212,6 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
derive(const { k.toDouble() * a.value }) { z ->
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
a.d += z.d * k.toDouble()
|
a.d += z.d * k.toDouble()
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun derivate(function: AutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
|
||||||
val result = function()
|
|
||||||
result.d = context.one // computing derivative w.r.t result
|
|
||||||
runBackwardPass()
|
|
||||||
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -220,99 +219,178 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
*/
|
*/
|
||||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : FirstDerivativeExpression<T>() {
|
) : FirstDerivativeExpression<T>() {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return AutoDiffContext(field, arguments).function().value
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
|
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
||||||
derivationResult.derivative(symbol)
|
derivationResult.derivative(symbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||||
|
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||||
|
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
||||||
|
return SimpleAutoDiffExpression(field, function)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
// x ^ 2
|
// x ^ 2
|
||||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||||
|
|
||||||
// x ^ 1/2
|
// x ^ 1/2
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||||
|
|
||||||
// x ^ y (const)
|
// x ^ y (const)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: Double,
|
y: Double,
|
||||||
): AutoDiffValue<T> =
|
): AutoDiffValue<T> =
|
||||||
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: Int,
|
y: Int,
|
||||||
): AutoDiffValue<T> = pow(x, y.toDouble())
|
): AutoDiffValue<T> = pow(x, y.toDouble())
|
||||||
|
|
||||||
// exp(x)
|
// exp(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||||
|
|
||||||
// ln(x)
|
// ln(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||||
|
|
||||||
// x ^ y (any)
|
// x ^ y (any)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
x: AutoDiffValue<T>,
|
x: AutoDiffValue<T>,
|
||||||
y: AutoDiffValue<T>,
|
y: AutoDiffValue<T>,
|
||||||
): AutoDiffValue<T> =
|
): AutoDiffValue<T> =
|
||||||
exp(y * ln(x))
|
exp(y * ln(x))
|
||||||
|
|
||||||
// sin(x)
|
// sin(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||||
|
|
||||||
// cos(x)
|
// cos(x)
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { tan(x.value) }) { z ->
|
derive(const { tan(x.value) }) { z ->
|
||||||
val c = cos(x.value)
|
val c = cos(x.value)
|
||||||
x.d += z.d / (c * c)
|
x.d += z.d / (c * c)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { tan(x.value) }) { z ->
|
derive(const { tanh(x.value) }) { z ->
|
||||||
val c = cosh(x.value)
|
val c = cosh(x.value)
|
||||||
x.d += z.d / (c * c)
|
x.d += z.d / (c * c)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||||
|
|
||||||
|
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
||||||
|
context: F,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
|
||||||
|
// x ^ 2
|
||||||
|
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sqr(x)
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sqrt(arg)
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).exp(arg)
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).ln(arg)
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
public fun pow(
|
||||||
|
x: AutoDiffValue<T>,
|
||||||
|
y: AutoDiffValue<T>,
|
||||||
|
): AutoDiffValue<T> = exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sin(arg)
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).cos(arg)
|
||||||
|
|
||||||
|
public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).tan(arg)
|
||||||
|
|
||||||
|
public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).asin(arg)
|
||||||
|
|
||||||
|
public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).acos(arg)
|
||||||
|
|
||||||
|
public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).atan(arg)
|
||||||
|
|
||||||
|
public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sinh(arg)
|
||||||
|
|
||||||
|
public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).cosh(arg)
|
||||||
|
|
||||||
|
public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).tanh(arg)
|
||||||
|
|
||||||
|
public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).asinh(arg)
|
||||||
|
|
||||||
|
public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).acosh(arg)
|
||||||
|
|
||||||
|
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).atanh(arg)
|
||||||
|
}
|
@ -8,11 +8,6 @@ import kotlin.contracts.InvocationKind
|
|||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
|
||||||
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
|
|
||||||
// public fun expression(block: A.() -> E): Expression<T>
|
|
||||||
//}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a functional expression with this [Space].
|
* Creates a functional expression with this [Space].
|
||||||
*/
|
*/
|
||||||
|
@ -2,6 +2,7 @@ package kscience.kmath.expressions
|
|||||||
|
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.structures.asBuffer
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.E
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
@ -13,18 +14,18 @@ class SimpleAutoDiffTest {
|
|||||||
|
|
||||||
fun dx(
|
fun dx(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||||
|
|
||||||
fun dxy(
|
fun dxy(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
yBinding: Pair<Symbol, Double>,
|
yBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
||||||
body(bind(xBinding.first), bind(yBinding.first))
|
body(bind(xBinding.first), bind(yBinding.first))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun diff(block: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
fun diff(block: SimpleAutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||||
return SimpleAutoDiffExpression(RealField, block)
|
return SimpleAutoDiffExpression(RealField, block)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,9 +246,9 @@ class SimpleAutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testTanh() {
|
fun testTanh() {
|
||||||
val y = dx(x to PI / 6) { x -> tanh(x) }
|
val y = dx(x to 1.0) { x -> tanh(x) }
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6)
|
||||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
Reference in New Issue
Block a user
Removed
Removed