forked from kscience/kmath
Another refactor of SimpleAutoDiff
This commit is contained in:
parent
4450c0fcc7
commit
9a147d033e
kmath-commons/src/main/kotlin/kscience/kmath/commons
expressions
optimization
kmath-core/src
commonMain/kotlin/kscience/kmath/expressions
commonTest/kotlin/kscience/kmath/expressions
11
kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt
11
kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt
@ -1,9 +1,6 @@
|
||||
package kscience.kmath.commons.expressions
|
||||
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.ExpressionAlgebra
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.expressions.*
|
||||
import kscience.kmath.operations.ExtendedField
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
|
||||
@ -92,6 +89,12 @@ public class DerivativeStructureField(
|
||||
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||
|
||||
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
||||
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
||||
return DerivativeStructureExpression(function)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -71,12 +71,8 @@ public object CMFit {
|
||||
public fun Expression<Double>.optimize(
|
||||
vararg symbols: Symbol,
|
||||
configuration: CMOptimizationProblem.() -> Unit,
|
||||
): OptimizationResult<Double> {
|
||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
||||
problem.expression(this)
|
||||
return problem.optimize()
|
||||
}
|
||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||
|
||||
|
||||
/**
|
||||
* Optimize differentiable expression
|
||||
@ -84,12 +80,7 @@ public fun Expression<Double>.optimize(
|
||||
public fun DifferentiableExpression<Double>.optimize(
|
||||
vararg symbols: Symbol,
|
||||
configuration: CMOptimizationProblem.() -> Unit,
|
||||
): OptimizationResult<Double> {
|
||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
|
||||
problem.diffExpression(this)
|
||||
return problem.optimize()
|
||||
}
|
||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||
|
||||
public fun DifferentiableExpression<Double>.minimize(
|
||||
vararg startPoint: Pair<Symbol, Double>,
|
||||
|
@ -33,7 +33,7 @@ public class CMOptimizationProblem(
|
||||
|
||||
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
||||
|
||||
public fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||
public override fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||
}
|
||||
|
||||
@ -94,10 +94,12 @@ public class CMOptimizationProblem(
|
||||
return OptimizationResult(point.toMap(), value, setOf(this))
|
||||
}
|
||||
|
||||
public companion object {
|
||||
public companion object : OptimizationProblemFactory<Double, CMOptimizationProblem> {
|
||||
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||
|
||||
override fun build(symbols: List<Symbol>): CMOptimizationProblem = CMOptimizationProblem(symbols)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,37 +0,0 @@
|
||||
package kscience.kmath.commons.optimization
|
||||
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
|
||||
public interface OptimizationFeature
|
||||
|
||||
//TODO move to prob/stat
|
||||
|
||||
public class OptimizationResult<T>(
|
||||
public val point: Map<Symbol, T>,
|
||||
public val value: T,
|
||||
public val features: Set<OptimizationFeature> = emptySet(),
|
||||
){
|
||||
override fun toString(): String {
|
||||
return "OptimizationResult(point=$point, value=$value)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A configuration builder for optimization problem
|
||||
*/
|
||||
public interface OptimizationProblem<T : Any> {
|
||||
/**
|
||||
* Set an objective function expression
|
||||
*/
|
||||
public fun expression(expression: Expression<Double>): Unit
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
|
||||
public fun update(result: OptimizationResult<T>)
|
||||
public fun optimize(): OptimizationResult<T>
|
||||
}
|
||||
|
@ -23,10 +23,9 @@ public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expressio
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||
derivative(StringSymbol(name) to 1)
|
||||
|
||||
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
||||
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
||||
//}
|
||||
|
||||
/**
|
||||
* A [DifferentiableExpression] that defines only first derivatives
|
||||
*/
|
||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||
|
||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||
@ -35,4 +34,11 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
||||
return derivativeOrNull(dSymbol)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||
*/
|
||||
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
||||
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
||||
}
|
@ -47,7 +47,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
|
||||
*
|
||||
* The partial derivatives are placed in argument `d` variable
|
||||
*
|
||||
@ -59,36 +59,90 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
||||
* assertEquals(9.0, x.d) // dy/dx
|
||||
* ```
|
||||
*
|
||||
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||
* @return the result of differentiation.
|
||||
*/
|
||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
bindings: Map<Symbol, T>,
|
||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
): DerivationResult<T> {
|
||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||
|
||||
return AutoDiffContext(this, bindings).derivate(body)
|
||||
return SimpleAutoDiffField(this, bindings).derivate(body)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
vararg bindings: Pair<Symbol, T>,
|
||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
||||
|
||||
/**
|
||||
* Represents field in context of which functions can be derived.
|
||||
*/
|
||||
public abstract class AutoDiffField<T : Any, F : Field<T>>
|
||||
: Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
public val context: F,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||
|
||||
public abstract val context: F
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||
private var sp: Int = 0
|
||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||
* with respect to this variable.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
* @property value The value of this variable.
|
||||
*/
|
||||
private class AutoDiffVariableWithDerivative<T : Any>(
|
||||
override val identity: String,
|
||||
value: T,
|
||||
var d: T,
|
||||
) : AutoDiffValue<T>(value), Symbol {
|
||||
override fun toString(): String = identity
|
||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||
override fun hashCode(): Int = identity.hashCode()
|
||||
}
|
||||
|
||||
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||
}
|
||||
|
||||
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||
|
||||
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||
|
||||
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
|
||||
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||
}
|
||||
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as F.(Any?) -> Unit
|
||||
context.block(value)
|
||||
}
|
||||
}
|
||||
|
||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||
|
||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||
|
||||
/**
|
||||
* A variable accessing inner state of derivatives.
|
||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||
*/
|
||||
public abstract var AutoDiffValue<T>.d: T
|
||||
public var AutoDiffValue<T>.d: T
|
||||
get() = getDerivative(this)
|
||||
set(value) = setDerivative(this, value)
|
||||
|
||||
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
|
||||
|
||||
/**
|
||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||
@ -101,9 +155,22 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
stack[sp++] = value
|
||||
return value
|
||||
}
|
||||
|
||||
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
|
||||
|
||||
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||
val result = function()
|
||||
result.d = context.one // computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
||||
}
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
@ -119,68 +186,7 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
|
||||
|
||||
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||
}
|
||||
|
||||
/**
|
||||
* Automatic Differentiation context class.
|
||||
*/
|
||||
private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||
override val context: F,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : AutoDiffField<T, F>() {
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||
private var sp: Int = 0
|
||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||
* with respect to this variable.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
* @property value The value of this variable.
|
||||
*/
|
||||
private class AutoDiffVariableWithDeriv<T : Any>(
|
||||
override val identity: String,
|
||||
value: T,
|
||||
var d: T,
|
||||
) : AutoDiffValue<T>(value), Symbol{
|
||||
override fun toString(): String = identity
|
||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||
override fun hashCode(): Int = identity.hashCode()
|
||||
}
|
||||
|
||||
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
||||
it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero)
|
||||
}
|
||||
|
||||
override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv<T>? = bindings[symbol.identity]
|
||||
|
||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||
|
||||
override var AutoDiffValue<T>.d: T
|
||||
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
||||
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
stack[sp++] = value
|
||||
return value
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as F.(Any?) -> Unit
|
||||
context.block(value)
|
||||
}
|
||||
}
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
@ -206,13 +212,6 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||
derive(const { k.toDouble() * a.value }) { z ->
|
||||
a.d += z.d * k.toDouble()
|
||||
}
|
||||
|
||||
inline fun derivate(function: AutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||
val result = function()
|
||||
result.d = context.one // computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -220,99 +219,178 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||
*/
|
||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||
public val field: F,
|
||||
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
) : FirstDerivativeExpression<T>() {
|
||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
return AutoDiffContext(field, arguments).function().value
|
||||
return SimpleAutoDiffField(field, arguments).function().value
|
||||
}
|
||||
|
||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
|
||||
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
||||
derivationResult.derivative(symbol)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||
*/
|
||||
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
||||
return SimpleAutoDiffExpression(field, function)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||
|
||||
// x ^ 1/2
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||
|
||||
// x ^ y (const)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||
x: AutoDiffValue<T>,
|
||||
y: Double,
|
||||
): AutoDiffValue<T> =
|
||||
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||
x: AutoDiffValue<T>,
|
||||
y: Int,
|
||||
): AutoDiffValue<T> = pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||
|
||||
// ln(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||
|
||||
// x ^ y (any)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||
x: AutoDiffValue<T>,
|
||||
y: AutoDiffValue<T>,
|
||||
): AutoDiffValue<T> =
|
||||
exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||
|
||||
// cos(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { tan(x.value) }) { z ->
|
||||
val c = cos(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { tan(x.value) }) { z ->
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { tanh(x.value) }) { z ->
|
||||
val c = cosh(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||
|
||||
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
||||
context: F,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
|
||||
// x ^ 2
|
||||
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).sqr(x)
|
||||
|
||||
// x ^ 1/2
|
||||
public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).sqrt(arg)
|
||||
|
||||
// x ^ y (const)
|
||||
public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())
|
||||
|
||||
// exp(x)
|
||||
public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).exp(arg)
|
||||
|
||||
// ln(x)
|
||||
public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).ln(arg)
|
||||
|
||||
// x ^ y (any)
|
||||
public fun pow(
|
||||
x: AutoDiffValue<T>,
|
||||
y: AutoDiffValue<T>,
|
||||
): AutoDiffValue<T> = exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).sin(arg)
|
||||
|
||||
// cos(x)
|
||||
public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).cos(arg)
|
||||
|
||||
public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).tan(arg)
|
||||
|
||||
public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).asin(arg)
|
||||
|
||||
public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).acos(arg)
|
||||
|
||||
public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).atan(arg)
|
||||
|
||||
public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).sinh(arg)
|
||||
|
||||
public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).cosh(arg)
|
||||
|
||||
public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).tanh(arg)
|
||||
|
||||
public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).asinh(arg)
|
||||
|
||||
public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).acosh(arg)
|
||||
|
||||
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
(this as SimpleAutoDiffField<T, F>).atanh(arg)
|
||||
}
|
@ -8,11 +8,6 @@ import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
|
||||
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
|
||||
// public fun expression(block: A.() -> E): Expression<T>
|
||||
//}
|
||||
|
||||
|
||||
/**
|
||||
* Creates a functional expression with this [Space].
|
||||
*/
|
||||
|
@ -2,6 +2,7 @@ package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.math.E
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
@ -13,18 +14,18 @@ class SimpleAutoDiffTest {
|
||||
|
||||
fun dx(
|
||||
xBinding: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||
|
||||
fun dxy(
|
||||
xBinding: Pair<Symbol, Double>,
|
||||
yBinding: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
||||
body(bind(xBinding.first), bind(yBinding.first))
|
||||
}
|
||||
|
||||
fun diff(block: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||
fun diff(block: SimpleAutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||
return SimpleAutoDiffExpression(RealField, block)
|
||||
}
|
||||
|
||||
@ -45,7 +46,7 @@ class SimpleAutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testPlusX2Expr() {
|
||||
val expr = diff{
|
||||
val expr = diff {
|
||||
val x = bind(x)
|
||||
x + x
|
||||
}
|
||||
@ -245,9 +246,9 @@ class SimpleAutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testTanh() {
|
||||
val y = dx(x to PI / 6) { x -> tanh(x) }
|
||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||
val y = dx(x to 1.0) { x -> tanh(x) }
|
||||
assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6)
|
||||
assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||
}
|
||||
|
||||
@Test
|
||||
|
Loading…
x
Reference in New Issue
Block a user