Another refactor of SimpleAutoDiff

This commit is contained in:
Alexander Nozik 2020-10-27 17:57:17 +03:00
parent 4450c0fcc7
commit 9a147d033e
8 changed files with 214 additions and 175 deletions

View File

@ -1,9 +1,6 @@
package kscience.kmath.commons.expressions
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.ExpressionAlgebra
import kscience.kmath.expressions.Symbol
import kscience.kmath.expressions.*
import kscience.kmath.operations.ExtendedField
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
@ -92,6 +89,12 @@ public class DerivativeStructureField(
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
return DerivativeStructureExpression(function)
}
}
}
/**

View File

@ -71,12 +71,8 @@ public object CMFit {
public fun Expression<Double>.optimize(
vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.expression(this)
return problem.optimize()
}
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
/**
* Optimize differentiable expression
@ -84,12 +80,7 @@ public fun Expression<Double>.optimize(
public fun DifferentiableExpression<Double>.optimize(
vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.diffExpression(this)
return problem.optimize()
}
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
public fun DifferentiableExpression<Double>.minimize(
vararg startPoint: Pair<Symbol, Double>,

View File

@ -33,7 +33,7 @@ public class CMOptimizationProblem(
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
public fun initialGuess(map: Map<Symbol, Double>): Unit {
public override fun initialGuess(map: Map<Symbol, Double>): Unit {
addOptimizationData(InitialGuess(map.toDoubleArray()))
}
@ -94,10 +94,12 @@ public class CMOptimizationProblem(
return OptimizationResult(point.toMap(), value, setOf(this))
}
public companion object {
public companion object : OptimizationProblemFactory<Double, CMOptimizationProblem> {
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
public const val DEFAULT_MAX_ITER: Int = 1000
override fun build(symbols: List<Symbol>): CMOptimizationProblem = CMOptimizationProblem(symbols)
}
}

View File

@ -1,37 +0,0 @@
package kscience.kmath.commons.optimization
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol
public interface OptimizationFeature
//TODO move to prob/stat
public class OptimizationResult<T>(
public val point: Map<Symbol, T>,
public val value: T,
public val features: Set<OptimizationFeature> = emptySet(),
){
override fun toString(): String {
return "OptimizationResult(point=$point, value=$value)"
}
}
/**
* A configuration builder for optimization problem
*/
public interface OptimizationProblem<T : Any> {
/**
* Set an objective function expression
*/
public fun expression(expression: Expression<Double>): Unit
/**
*
*/
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
public fun update(result: OptimizationResult<T>)
public fun optimize(): OptimizationResult<T>
}

View File

@ -23,10 +23,9 @@ public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expressio
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
derivative(StringSymbol(name) to 1)
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
//}
/**
* A [DifferentiableExpression] that defines only first derivatives
*/
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
@ -35,4 +34,11 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
return derivativeOrNull(dSymbol)
}
}
/**
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
*/
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
public fun process(function: A.() -> I): DifferentiableExpression<T>
}

View File

@ -47,7 +47,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
}
/**
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
*
* The partial derivatives are placed in argument `d` variable
*
@ -59,36 +59,90 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
* assertEquals(9.0, x.d) // dy/dx
* ```
*
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
* @return the result of differentiation.
*/
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
bindings: Map<Symbol, T>,
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return AutoDiffContext(this, bindings).derivate(body)
return SimpleAutoDiffField(this, bindings).derivate(body)
}
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
vararg bindings: Pair<Symbol, T>,
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
/**
* Represents field in context of which functions can be derived.
*/
public abstract class AutoDiffField<T : Any, F : Field<T>>
: Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F,
bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
public abstract val context: F
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
/**
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
private class AutoDiffVariableWithDerivative<T : Any>(
override val identity: String,
value: T,
var d: T,
) : AutoDiffValue<T>(value), Symbol {
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
}
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
private fun getDerivative(variable: AutoDiffValue<T>): T =
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
}
@Suppress("UNCHECKED_CAST")
private fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
/**
* A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings.
*/
public abstract var AutoDiffValue<T>.d: T
public var AutoDiffValue<T>.d: T
get() = getDerivative(this)
set(value) = setDerivative(this, value)
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
/**
* Performs update of derivative after the rest of the formula in the back-pass.
@ -101,9 +155,22 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
* }
* ```
*/
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
@Suppress("UNCHECKED_CAST")
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
}
// Overloads for Double constants
@ -119,68 +186,7 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
}
/**
* Automatic Differentiation context class.
*/
private class AutoDiffContext<T : Any, F : Field<T>>(
override val context: F,
bindings: Map<Symbol, T>,
) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
/**
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
private class AutoDiffVariableWithDeriv<T : Any>(
override val identity: String,
value: T,
var d: T,
) : AutoDiffValue<T>(value), Symbol{
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
}
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv<T>? = bindings[symbol.identity]
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
override var AutoDiffValue<T>.d: T
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value
@Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
@Suppress("UNCHECKED_CAST")
fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
// Basic math (+, -, *, /)
@ -206,13 +212,6 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble()
}
inline fun derivate(function: AutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
}
}
/**
@ -220,99 +219,178 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
*/
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F,
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T>() {
public override operator fun invoke(arguments: Map<Symbol, T>): T {
//val bindings = arguments.entries.map { it.key.bind(it.value) }
return AutoDiffContext(field, arguments).function().value
return SimpleAutoDiffField(field, arguments).function().value
}
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
//val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
derivationResult.derivative(symbol)
}
}
/**
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
*/
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
return SimpleAutoDiffExpression(field, function)
}
}
}
// Extensions for differentiation of various basic mathematical functions
// x ^ 2
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>,
y: Double,
): AutoDiffValue<T> =
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>,
y: Int,
): AutoDiffValue<T> = pow(x, y.toDouble())
// exp(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>,
y: AutoDiffValue<T>,
): AutoDiffValue<T> =
exp(y * ln(x))
// sin(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z ->
val c = cos(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z ->
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tanh(x.value) }) { z ->
val c = cosh(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
context: F,
bindings: Map<Symbol, T>,
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
// x ^ 2
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sqr(x)
// x ^ 1/2
public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sqrt(arg)
// x ^ y (const)
public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())
// exp(x)
public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).exp(arg)
// ln(x)
public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).ln(arg)
// x ^ y (any)
public fun pow(
x: AutoDiffValue<T>,
y: AutoDiffValue<T>,
): AutoDiffValue<T> = exp(y * ln(x))
// sin(x)
public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sin(arg)
// cos(x)
public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).cos(arg)
public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).tan(arg)
public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).asin(arg)
public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).acos(arg)
public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atan(arg)
public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sinh(arg)
public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).cosh(arg)
public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).tanh(arg)
public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).asinh(arg)
public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).acosh(arg)
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atanh(arg)
}

View File

@ -8,11 +8,6 @@ import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
// public fun expression(block: A.() -> E): Expression<T>
//}
/**
* Creates a functional expression with this [Space].
*/

View File

@ -2,6 +2,7 @@ package kscience.kmath.expressions
import kscience.kmath.operations.RealField
import kscience.kmath.structures.asBuffer
import kotlin.math.E
import kotlin.math.PI
import kotlin.math.pow
import kotlin.math.sqrt
@ -13,18 +14,18 @@ class SimpleAutoDiffTest {
fun dx(
xBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
fun dxy(
xBinding: Pair<Symbol, Double>,
yBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
body(bind(xBinding.first), bind(yBinding.first))
}
fun diff(block: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
fun diff(block: SimpleAutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
return SimpleAutoDiffExpression(RealField, block)
}
@ -45,7 +46,7 @@ class SimpleAutoDiffTest {
@Test
fun testPlusX2Expr() {
val expr = diff{
val expr = diff {
val x = bind(x)
x + x
}
@ -245,9 +246,9 @@ class SimpleAutoDiffTest {
@Test
fun testTanh() {
val y = dx(x to PI / 6) { x -> tanh(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
val y = dx(x to 1.0) { x -> tanh(x) }
assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6)
assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
}
@Test