Feature/diff api #154

Merged
altavir merged 20 commits from feature/diff-api into dev 2020-10-28 13:25:24 +03:00
4 changed files with 73 additions and 81 deletions
Showing only changes of commit ae07652d9e - Show all commits

View File

@ -25,13 +25,13 @@ public class DerivativeStructureField(
*/ */
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
override val identity: Any = symbol.identity override val identity: String = symbol.identity
} }
/** /**
* Identity-based symbol bindings map * Identity-based symbol bindings map
*/ */
private val variables: Map<Any?, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) -> private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
key.identity to DerivativeStructureSymbol(key, value) key.identity to DerivativeStructureSymbol(key, value)
} }

View File

@ -12,7 +12,7 @@ public interface Symbol {
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol. * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
* By default uses object identity * By default uses object identity
*/ */
public val identity: Any get() = this public val identity: String
} }
/** /**

View File

@ -12,20 +12,7 @@ import kotlin.contracts.contract
*/ */
/** public open class AutoDiffValue<out T>(public val value: T)
CommanderTvis commented 2020-10-26 22:52:50 +03:00 (Migrated from github.com)
Review

I am sure that it may be replaced with sealed class to prevent one to extend AutoDiffValue with irrelevant object.

I am sure that it may be replaced with sealed class to prevent one to extend AutoDiffValue with irrelevant object.
altavir commented 2020-10-28 09:51:39 +03:00 (Migrated from github.com)
Review

The idea is that it could be extended anytime. Here AutoDiffValue is just a marker interface. It is possible to even replace it with an inline class, but we need performance measurements to make that change.

The idea is that it could be extended anytime. Here `AutoDiffValue` is just a marker interface. It is possible to even replace it with an inline class, but we need performance measurements to make that change.
* A [Symbol] with bound value
*/
public interface BoundSymbol<out T> : Symbol {
public val value: T
}
/**
* Bind a [Symbol] to a [value] and produce [BoundSymbol]
*/
public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
override val identity = this@bind.identity
override val value: T = value
}
/** /**
* Represents result of [withAutoDiff] call. * Represents result of [withAutoDiff] call.
@ -36,10 +23,10 @@ public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
* @property context The field over [T]. * @property context The field over [T].
*/ */
public class DerivationResult<T : Any>( public class DerivationResult<T : Any>(
override val value: T, public val value: T,
private val derivativeValues: Map<Any, T>, private val derivativeValues: Map<String, T>,
public val context: Field<T>, public val context: Field<T>,
) : BoundSymbol<T> { ) {
/** /**
* Returns derivative of [variable] or returns [Ring.zero] in [context]. * Returns derivative of [variable] or returns [Ring.zero] in [context].
*/ */
@ -76,8 +63,8 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
* @return the result of differentiation. * @return the result of differentiation.
*/ */
public fun <T : Any, F : Field<T>> F.withAutoDiff( public fun <T : Any, F : Field<T>> F.withAutoDiff(
bindings: Collection<BoundSymbol<T>>, bindings: Map<Symbol, T>,
body: AutoDiffField<T, F>.() -> BoundSymbol<T>, body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> { ): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
@ -86,14 +73,14 @@ public fun <T : Any, F : Field<T>> F.withAutoDiff(
public fun <T : Any, F : Field<T>> F.withAutoDiff( public fun <T : Any, F : Field<T>> F.withAutoDiff(
vararg bindings: Pair<Symbol, T>, vararg bindings: Pair<Symbol, T>,
body: AutoDiffField<T, F>.() -> BoundSymbol<T>, body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> = withAutoDiff(bindings.map { it.first.bind(it.second) }, body) ): DerivationResult<T> = withAutoDiff(bindings.toMap(), body)
/** /**
* Represents field in context of which functions can be derived. * Represents field in context of which functions can be derived.
*/ */
public abstract class AutoDiffField<T : Any, F : Field<T>> public abstract class AutoDiffField<T : Any, F : Field<T>>
: Field<BoundSymbol<T>>, ExpressionAlgebra<T, BoundSymbol<T>> { : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
public abstract val context: F public abstract val context: F
@ -101,7 +88,7 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
* A variable accessing inner state of derivatives. * A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings. * Use this value in inner builders to avoid creating additional derivative bindings.
*/ */
public abstract var BoundSymbol<T>.d: T public abstract var AutoDiffValue<T>.d: T
/** /**
* Performs update of derivative after the rest of the formula in the back-pass. * Performs update of derivative after the rest of the formula in the back-pass.
@ -116,21 +103,21 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
*/ */
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
public inline fun const(block: F.() -> T): BoundSymbol<T> = const(context.block()) public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
// Overloads for Double constants // Overloads for Double constants
override operator fun Number.plus(b: BoundSymbol<T>): BoundSymbol<T> = override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@plus.toDouble() * one + b.value }) { z -> derive(const { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d b.d += z.d
} }
override operator fun BoundSymbol<T>.plus(b: Number): BoundSymbol<T> = b.plus(this) override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
override operator fun Number.minus(b: BoundSymbol<T>): BoundSymbol<T> = override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun BoundSymbol<T>.minus(b: Number): BoundSymbol<T> = override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
} }
@ -139,14 +126,14 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
*/ */
private class AutoDiffContext<T : Any, F : Field<T>>( private class AutoDiffContext<T : Any, F : Field<T>>(
override val context: F, override val context: F,
bindings: Collection<BoundSymbol<T>>, bindings: Map<Symbol, T>,
) : AutoDiffField<T, F>() { ) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8) private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0 private var sp: Int = 0
private val derivatives: MutableMap<Any, T> = hashMapOf() private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
override val zero: BoundSymbol<T> get() = const(context.zero) override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: BoundSymbol<T> get() = const(context.one) override val one: AutoDiffValue<T> get() = const(context.one)
/** /**
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result * Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
@ -155,17 +142,23 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
* @param T the non-nullable type of value. * @param T the non-nullable type of value.
* @property value The value of this variable. * @property value The value of this variable.
*/ */
private class AutoDiffVariableWithDeriv<T : Any>(override val value: T, var d: T) : BoundSymbol<T> private class AutoDiffVariableWithDeriv<T : Any>(
override val identity: String,
value: T,
var d: T,
) : AutoDiffValue<T>(value), Symbol
private val bindings: Map<Any, BoundSymbol<T>> = bindings.associateBy { it.identity } private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): BoundSymbol<T>? = bindings[symbol.identity] override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv<T>? = bindings[symbol.identity]
override fun const(value: T): BoundSymbol<T> = AutoDiffVariableWithDeriv(value, context.zero) override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
override var BoundSymbol<T>.d: T override var AutoDiffValue<T>.d: T
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R { override fun <R> derive(value: R, block: F.(R) -> Unit): R {
@ -187,34 +180,34 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> = override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value + b.value }) { z -> derive(const { a.value + b.value }) { z ->
a.d += z.d a.d += z.d
b.d += z.d b.d += z.d
} }
override fun multiply(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> = override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value * b.value }) { z -> derive(const { a.value * b.value }) { z ->
a.d += z.d * b.value a.d += z.d * b.value
b.d += z.d * a.value b.d += z.d * a.value
} }
override fun divide(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> = override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value / b.value }) { z -> derive(const { a.value / b.value }) { z ->
a.d += z.d / b.value a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value) b.d -= z.d * a.value / (b.value * b.value)
} }
override fun multiply(a: BoundSymbol<T>, k: Number): BoundSymbol<T> = override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
derive(const { k.toDouble() * a.value }) { z -> derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble() a.d += z.d * k.toDouble()
} }
inline fun derivate(function: AutoDiffField<T, F>.() -> BoundSymbol<T>): DerivationResult<T> { inline fun derivate(function: AutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function() val result = function()
result.d = context.one // computing derivative w.r.t result result.d = context.one // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
return DerivationResult(result.value, derivatives, context) return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
} }
} }
@ -223,11 +216,11 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
*/ */
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>( public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F, public val field: F,
public val function: AutoDiffField<T, F>.() -> BoundSymbol<T>, public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : DifferentiableExpression<T> { ) : DifferentiableExpression<T> {
public override operator fun invoke(arguments: Map<Symbol, T>): T { public override operator fun invoke(arguments: Map<Symbol, T>): T {
val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
return AutoDiffContext(field, bindings).function().value return AutoDiffContext(field, arguments).function().value
} }
/** /**
@ -237,8 +230,8 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
val dSymbol = orders.entries.singleOrNull { it.value == 1 } val dSymbol = orders.entries.singleOrNull { it.value == 1 }
?: error("SimpleAutoDiff supports only first order derivatives") ?: error("SimpleAutoDiff supports only first order derivatives")
return Expression { arguments -> return Expression { arguments ->
val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = AutoDiffContext(field, bindings).derivate(function) val derivationResult = AutoDiffContext(field, arguments).derivate(function)
derivationResult.derivative(dSymbol.key) derivationResult.derivative(dSymbol.key)
} }
} }
@ -248,82 +241,81 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
// x ^ 2 // x ^ 2
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2 // x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const) // x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>, x: AutoDiffValue<T>,
y: Double, y: Double,
): BoundSymbol<T> = ): AutoDiffValue<T> =
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>, x: AutoDiffValue<T>,
y: Int, y: Int,
): BoundSymbol<T> = ): AutoDiffValue<T> = pow(x, y.toDouble())
pow(x, y.toDouble())
// exp(x) // exp(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x) // ln(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any) // x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>, x: AutoDiffValue<T>,
y: BoundSymbol<T>, y: AutoDiffValue<T>,
): BoundSymbol<T> = ): AutoDiffValue<T> =
exp(y * ln(x)) exp(y * ln(x))
// sin(x) // sin(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x) // cos(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z -> derive(const { tan(x.value) }) { z ->
val c = cos(x.value) val c = cos(x.value)
x.d += z.d / (c * c) x.d += z.d / (c * c)
} }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z -> derive(const { tan(x.value) }) { z ->
val c = cosh(x.value) val c = cosh(x.value)
x.d += z.d / (c * c) x.d += z.d / (c * c)
} }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: BoundSymbol<T>): BoundSymbol<T> = public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }

View File

@ -12,23 +12,23 @@ import kotlin.test.assertTrue
class SimpleAutoDiffTest { class SimpleAutoDiffTest {
fun d( fun d(
vararg bindings: Pair<Symbol, Double>, vararg bindings: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>, body: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body) ): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
fun dx( fun dx(
xBinding: Pair<Symbol, Double>, xBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>) -> BoundSymbol<Double>, body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) } ): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
fun dxy( fun dxy(
xBinding: Pair<Symbol, Double>, xBinding: Pair<Symbol, Double>,
yBinding: Pair<Symbol, Double>, yBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>, y: BoundSymbol<Double>) -> BoundSymbol<Double>, body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) { ): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
body(bind(xBinding.first), bind(yBinding.first)) body(bind(xBinding.first), bind(yBinding.first))
} }
fun diff(block: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>): SimpleAutoDiffExpression<Double, RealField> { fun diff(block: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
return SimpleAutoDiffExpression(RealField, block) return SimpleAutoDiffExpression(RealField, block)
} }