Feature/diff api #154

Merged
altavir merged 20 commits from feature/diff-api into dev 2020-10-28 13:25:24 +03:00
4 changed files with 73 additions and 81 deletions
Showing only changes of commit ae07652d9e - Show all commits

View File

@ -25,13 +25,13 @@ public class DerivativeStructureField(
*/
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
override val identity: Any = symbol.identity
override val identity: String = symbol.identity
}
/**
* Identity-based symbol bindings map
*/
private val variables: Map<Any?, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
key.identity to DerivativeStructureSymbol(key, value)
}

View File

@ -12,7 +12,7 @@ public interface Symbol {
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
* By default uses object identity
*/
public val identity: Any get() = this
public val identity: String
}
/**

View File

@ -12,20 +12,7 @@ import kotlin.contracts.contract
*/
/**
* A [Symbol] with bound value
*/
public interface BoundSymbol<out T> : Symbol {
public val value: T
}
/**
* Bind a [Symbol] to a [value] and produce [BoundSymbol]
*/
public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
override val identity = this@bind.identity
override val value: T = value
}
public open class AutoDiffValue<out T>(public val value: T)
CommanderTvis commented 2020-10-26 22:52:50 +03:00 (Migrated from github.com)
Review

I am sure that it may be replaced with sealed class to prevent one to extend AutoDiffValue with irrelevant object.

I am sure that it may be replaced with sealed class to prevent one to extend AutoDiffValue with irrelevant object.
altavir commented 2020-10-28 09:51:39 +03:00 (Migrated from github.com)
Review

The idea is that it could be extended anytime. Here AutoDiffValue is just a marker interface. It is possible to even replace it with an inline class, but we need performance measurements to make that change.

The idea is that it could be extended anytime. Here `AutoDiffValue` is just a marker interface. It is possible to even replace it with an inline class, but we need performance measurements to make that change.
/**
* Represents result of [withAutoDiff] call.
@ -36,10 +23,10 @@ public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
* @property context The field over [T].
*/
public class DerivationResult<T : Any>(
override val value: T,
private val derivativeValues: Map<Any, T>,
public val value: T,
private val derivativeValues: Map<String, T>,
public val context: Field<T>,
) : BoundSymbol<T> {
) {
/**
* Returns derivative of [variable] or returns [Ring.zero] in [context].
*/
@ -76,8 +63,8 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
* @return the result of differentiation.
*/
public fun <T : Any, F : Field<T>> F.withAutoDiff(
bindings: Collection<BoundSymbol<T>>,
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
bindings: Map<Symbol, T>,
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
@ -86,14 +73,14 @@ public fun <T : Any, F : Field<T>> F.withAutoDiff(
public fun <T : Any, F : Field<T>> F.withAutoDiff(
vararg bindings: Pair<Symbol, T>,
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
): DerivationResult<T> = withAutoDiff(bindings.map { it.first.bind(it.second) }, body)
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> = withAutoDiff(bindings.toMap(), body)
/**
* Represents field in context of which functions can be derived.
*/
public abstract class AutoDiffField<T : Any, F : Field<T>>
: Field<BoundSymbol<T>>, ExpressionAlgebra<T, BoundSymbol<T>> {
: Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
public abstract val context: F
@ -101,7 +88,7 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
* A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings.
*/
public abstract var BoundSymbol<T>.d: T
public abstract var AutoDiffValue<T>.d: T
/**
* Performs update of derivative after the rest of the formula in the back-pass.
@ -116,21 +103,21 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
*/
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
public inline fun const(block: F.() -> T): BoundSymbol<T> = const(context.block())
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
// Overloads for Double constants
override operator fun Number.plus(b: BoundSymbol<T>): BoundSymbol<T> =
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d
}
override operator fun BoundSymbol<T>.plus(b: Number): BoundSymbol<T> = b.plus(this)
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
override operator fun Number.minus(b: BoundSymbol<T>): BoundSymbol<T> =
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun BoundSymbol<T>.minus(b: Number): BoundSymbol<T> =
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
}
@ -139,14 +126,14 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
*/
private class AutoDiffContext<T : Any, F : Field<T>>(
override val context: F,
bindings: Collection<BoundSymbol<T>>,
bindings: Map<Symbol, T>,
) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
private val derivatives: MutableMap<Any, T> = hashMapOf()
override val zero: BoundSymbol<T> get() = const(context.zero)
override val one: BoundSymbol<T> get() = const(context.one)
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
/**
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
@ -155,17 +142,23 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
private class AutoDiffVariableWithDeriv<T : Any>(override val value: T, var d: T) : BoundSymbol<T>
private class AutoDiffVariableWithDeriv<T : Any>(
override val identity: String,
value: T,
var d: T,
) : AutoDiffValue<T>(value), Symbol
private val bindings: Map<Any, BoundSymbol<T>> = bindings.associateBy { it.identity }
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): BoundSymbol<T>? = bindings[symbol.identity]
override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv<T>? = bindings[symbol.identity]
override fun const(value: T): BoundSymbol<T> = AutoDiffVariableWithDeriv(value, context.zero)
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
override var BoundSymbol<T>.d: T
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value
override var AutoDiffValue<T>.d: T
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value
@Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
@ -187,34 +180,34 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
// Basic math (+, -, *, /)
override fun add(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value + b.value }) { z ->
a.d += z.d
b.d += z.d
}
override fun multiply(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value * b.value }) { z ->
a.d += z.d * b.value
b.d += z.d * a.value
}
override fun divide(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value / b.value }) { z ->
a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value)
}
override fun multiply(a: BoundSymbol<T>, k: Number): BoundSymbol<T> =
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble()
}
inline fun derivate(function: AutoDiffField<T, F>.() -> BoundSymbol<T>): DerivationResult<T> {
inline fun derivate(function: AutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
return DerivationResult(result.value, derivatives, context)
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
}
}
@ -223,11 +216,11 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
*/
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F,
public val function: AutoDiffField<T, F>.() -> BoundSymbol<T>,
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : DifferentiableExpression<T> {
public override operator fun invoke(arguments: Map<Symbol, T>): T {
val bindings = arguments.entries.map { it.key.bind(it.value) }
return AutoDiffContext(field, bindings).function().value
//val bindings = arguments.entries.map { it.key.bind(it.value) }
return AutoDiffContext(field, arguments).function().value
}
/**
@ -237,8 +230,8 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
val dSymbol = orders.entries.singleOrNull { it.value == 1 }
?: error("SimpleAutoDiff supports only first order derivatives")
return Expression { arguments ->
val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = AutoDiffContext(field, bindings).derivate(function)
//val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = AutoDiffContext(field, arguments).derivate(function)
derivationResult.derivative(dSymbol.key)
}
}
@ -248,82 +241,81 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
// Extensions for differentiation of various basic mathematical functions
// x ^ 2
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>,
x: AutoDiffValue<T>,
y: Double,
): BoundSymbol<T> =
): AutoDiffValue<T> =
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>,
x: AutoDiffValue<T>,
y: Int,
): BoundSymbol<T> =
pow(x, y.toDouble())
): AutoDiffValue<T> = pow(x, y.toDouble())
// exp(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>,
y: BoundSymbol<T>,
): BoundSymbol<T> =
x: AutoDiffValue<T>,
y: AutoDiffValue<T>,
): AutoDiffValue<T> =
exp(y * ln(x))
// sin(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z ->
val c = cos(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z ->
val c = cosh(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: BoundSymbol<T>): BoundSymbol<T> =
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }

View File

@ -12,23 +12,23 @@ import kotlin.test.assertTrue
class SimpleAutoDiffTest {
fun d(
vararg bindings: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>,
body: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
fun dx(
xBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>) -> BoundSymbol<Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
fun dxy(
xBinding: Pair<Symbol, Double>,
yBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>, y: BoundSymbol<Double>) -> BoundSymbol<Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
body(bind(xBinding.first), bind(yBinding.first))
}
fun diff(block: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>): SimpleAutoDiffExpression<Double, RealField> {
fun diff(block: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
return SimpleAutoDiffExpression(RealField, block)
}