Generalized Elizarov's algorithm for ExtendedField

This commit is contained in:
Alexander Nozik 2019-06-06 18:50:40 +03:00
parent fcc6269ee8
commit cee39e0666
5 changed files with 243 additions and 222 deletions

View File

@ -0,0 +1,233 @@
package scientifik.kmath.misc
import scientifik.kmath.linear.Point
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.sum
import scientifik.kmath.structures.asBuffer
/*
* Implementation of backward-mode automatic differentiation.
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
*/
/**
 * A value that takes part in automatic differentiation.
 *
 * The class itself only stores [value]; it is `open` so that subclasses such as
 * [DerivationResult] can carry extra data next to the value.
 */
open class Variable<T : Any>(
    val value: T
)
/**
 * The outcome of a differentiation run: the resulting [value] together with a
 * [deriv] map from variables to the derivative values recorded for them.
 *
 * @property deriv derivative values keyed by variable.
 * @property context the field whose `zero` and summation are used by the helpers below.
 */
class DerivationResult<T : Any>(
    value: T,
    val deriv: Map<Variable<T>, T>,
    val context: ExtendedField<T>
) : Variable<T>(value) {

    /**
     * Returns the derivative stored for [variable], or the zero of [context]
     * when the map has no entry for it.
     */
    fun deriv(variable: Variable<T>): T = deriv.getOrElse(variable) { context.zero }

    /**
     * Computes the divergence, i.e. the sum of all values stored in [deriv].
     */
    fun div() = with(context) { sum(deriv.values) }

    /**
     * Builds a gradient whose components follow the order of [variables].
     *
     * Fails with an error when no variables are passed, since the component
     * order would then be undefined.
     */
    fun grad(vararg variables: Variable<T>): Point<T> {
        if (variables.isEmpty()) error("Variable order is not provided for gradient construction")
        return variables.map { deriv(it) }.asBuffer()
    }
}
/**
 * Runs differentiation and establishes an [AutoDiffField] context inside [body].
 *
 * The receiver field supplies the arithmetic on `T`. The derivative of the result
 * with respect to itself is seeded with the field's `one`, the backward pass is run,
 * and the collected derivatives are returned in a [DerivationResult].
 *
 * Example (using `RealField` as the receiver):
 * ```
 * val x = Variable(2.0)                            // define variable(s) and their values
 * val y = RealField.deriv { sqr(x) + 5 * x + 3 }   // write the formula in deriv context
 * y.value      // 17.0, the value of the result
 * y.deriv(x)   // 9.0, dy/dx
 * ```
 *
 * @param body the formula to differentiate; it returns the result variable.
 * @return the value of the result together with the recorded derivatives.
 */
fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>): DerivationResult<T> {
    val autoDiff = AutoDiffContext<T>(this)
    return with(autoDiff) {
        val output = body()
        // computing derivative w.r.t. the result itself
        output.d = context.one
        runBackwardPass()
        DerivationResult(output.value, derivatives, this@deriv)
    }
}
/**
 * A [Field] over [Variable]s in which formulas for automatic differentiation are written.
 * Arithmetic on the underlying values is delegated to [context].
 */
abstract class AutoDiffField<T : Any> : Field<Variable<T>> {

    /** The field that provides operations on raw `T` values. */
    abstract val context: ExtendedField<T>

    /**
     * Performs update of derivative after the rest of the formula in the back-pass.
     *
     * For example, the implementation of the `sin` function is:
     *
     * ```
     * fun <T : Any> AutoDiffField<T>.sin(x: Variable<T>): Variable<T> =
     *     derive(variable { sin(x.value) }) { z -> // call derive with the function result
     *         x.d += z.d * cos(x.value)            // chain rule with the function's derivative
     *     }
     * ```
     *
     * @param value the function result handed back to [block] as its argument.
     * @param block the derivative update, executed with [ExtendedField] as receiver.
     * @return [value].
     */
    abstract fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R

    /**
     * A variable accessing inner state of derivatives.
     * Use this in inner builders to avoid creating additional derivative bindings.
     */
    abstract var Variable<T>.d: T

    /** Creates a variable of this context holding [value]. */
    abstract fun variable(value: T): Variable<T>

    /** Creates a variable whose value is computed by [block] inside [context]. */
    inline fun variable(block: ExtendedField<T>.() -> T): Variable<T> = variable(block(context))

    // Overloads for numeric constants

    /** `number + variable`; the derivative of the sum is added to `that.d`. */
    operator fun Number.plus(that: Variable<T>): Variable<T> {
        val shifted = variable { this@plus.toDouble() * one + that.value }
        return derive(shifted) { z -> that.d += z.d }
    }

    /** `variable + number`; delegates to the `number + variable` overload. */
    operator fun Variable<T>.plus(b: Number): Variable<T> = b + this

    /** `number - variable`; the derivative of the difference is subtracted from `that.d`. */
    operator fun Number.minus(that: Variable<T>): Variable<T> {
        val difference = variable { this@minus.toDouble() * one - that.value }
        return derive(difference) { z -> that.d -= z.d }
    }

    /** `variable - number`; the derivative of the difference is added to the receiver's `d`. */
    operator fun Variable<T>.minus(that: Number): Variable<T> {
        val difference = variable { this@minus.value - one * that.toDouble() }
        return derive(difference) { z -> this@minus.d += z.d }
    }
}
/**
 * Automatic Differentiation context class.
 *
 * Every call to [derive] records a (block, value) pair; [runBackwardPass] replays
 * the recorded pairs in reverse order.
 */
private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) : AutoDiffField<T>() {

    // Flat storage of recorded pairs: the block sits at an even index, its value right after it.
    private var tape = arrayOfNulls<Any?>(8)

    // Index of the next free slot in the tape.
    private var top = 0

    /** Derivatives of variables that do not carry their own derivative field. */
    internal val derivatives = HashMap<Variable<T>, T>()

    /**
     * A variable coupled with its derivative. For internal use only.
     */
    inner class VariableWithDeriv(x: T, var d: T = context.zero) : Variable<T>(x)

    override fun variable(value: T): Variable<T> = VariableWithDeriv(value)

    override var Variable<T>.d: T
        get() {
            // Prefer the derivative stored in the variable itself, then the map, then zero.
            val tracked = this as? VariableWithDeriv
            return tracked?.d ?: derivatives[this] ?: context.zero
        }
        set(value) {
            when (this) {
                is VariableWithDeriv -> {
                    d = value
                }
                else -> {
                    derivatives[this] = value
                }
            }
        }

    @Suppress("UNCHECKED_CAST")
    override fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R {
        // save block and value for the backward pass, doubling the storage when it is full
        if (top >= tape.size) tape = tape.copyOf(tape.size * 2)
        tape[top] = block
        tape[top + 1] = value
        top += 2
        return value
    }

    /** Replays all recorded blocks, most recently recorded first. */
    @Suppress("UNCHECKED_CAST")
    fun runBackwardPass() {
        while (top > 0) {
            top -= 2
            val argument = tape[top + 1]
            val action = tape[top] as ExtendedField<T>.(Any?) -> Unit
            context.action(argument)
        }
    }

    // Basic math (+, -, *, /)

    override fun add(a: Variable<T>, b: Variable<T>): Variable<T> {
        val total = variable { a.value + b.value }
        return derive(total) { z ->
            a.d += z.d
            b.d += z.d
        }
    }

    override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> {
        val product = variable { a.value * b.value }
        return derive(product) { z ->
            a.d += z.d * b.value
            b.d += z.d * a.value
        }
    }

    override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> {
        val quotient = variable { a.value / b.value }
        return derive(quotient) { z ->
            a.d += z.d / b.value
            b.d -= z.d * a.value / (b.value * b.value)
        }
    }

    override fun multiply(a: Variable<T>, k: Number): Variable<T> {
        val scaled = variable { k.toDouble() * a.value }
        return derive(scaled) { z ->
            a.d += z.d * k.toDouble()
        }
    }

    override val zero: Variable<T>
        get() = Variable(context.zero)

    override val one: Variable<T>
        get() = Variable(context.one)
}
// Extensions for differentiation of various basic mathematical functions
/**
 * Square of [x] (x ^ 2).
 *
 * The block passed to `derive` adds `z.d * 2 * x.value` to `x.d`,
 * where `z` is the returned variable.
 */
fun <T : Any> AutoDiffField<T>.sqr(x: Variable<T>): Variable<T> {
    val squared = variable { x.value * x.value }
    return derive(squared) { z ->
        x.d += z.d * 2 * x.value
    }
}
/**
 * Square root of [x] (x ^ 1/2).
 *
 * The block passed to `derive` adds `z.d * 0.5 / z.value` to `x.d`,
 * reusing the already computed root `z.value`.
 */
fun <T : Any> AutoDiffField<T>.sqrt(x: Variable<T>): Variable<T> {
    val root = variable { sqrt(x.value) }
    return derive(root) { z ->
        x.d += z.d * 0.5 / z.value
    }
}
/**
 * [x] raised to the constant power [y] (x ^ y).
 *
 * The block passed to `derive` adds `z.d * y * power(x.value, y - 1)` to `x.d`.
 */
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Double): Variable<T> {
    val raised = variable { power(x.value, y) }
    return derive(raised) { z ->
        x.d += z.d * y * power(x.value, y - 1)
    }
}

/**
 * [x] raised to the constant integer power [y]; converts [y] to `Double`
 * and delegates to the `Double` overload.
 */
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Int): Variable<T> {
    val exponent = y.toDouble()
    return pow(x, exponent)
}
/**
 * Exponential of [x].
 *
 * The block passed to `derive` adds `z.d * z.value` to `x.d`,
 * reusing the already computed exponential `z.value`.
 */
fun <T : Any> AutoDiffField<T>.exp(x: Variable<T>): Variable<T> {
    val exponential = variable { exp(x.value) }
    return derive(exponential) { z ->
        x.d += z.d * z.value
    }
}
/**
 * Natural logarithm of [x].
 *
 * The block passed to `derive` adds `z.d / x.value` to `x.d`.
 */
fun <T : Any> AutoDiffField<T>.ln(x: Variable<T>): Variable<T> {
    val logarithm = variable { ln(x.value) }
    return derive(logarithm) { z ->
        x.d += z.d / x.value
    }
}
/**
 * [x] raised to the variable power [y], computed as `exp(y * ln(x))`.
 */
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Variable<T>): Variable<T> {
    val logarithm = ln(x)
    return exp(y * logarithm)
}
/**
 * Sine of [x].
 *
 * The block passed to `derive` adds `z.d * cos(x.value)` to `x.d`.
 */
fun <T : Any> AutoDiffField<T>.sin(x: Variable<T>): Variable<T> {
    val sine = variable { sin(x.value) }
    return derive(sine) { z ->
        x.d += z.d * cos(x.value)
    }
}
/**
 * Cosine of [x].
 *
 * The block passed to `derive` subtracts `z.d * sin(x.value)` from `x.d`.
 */
fun <T : Any> AutoDiffField<T>.cos(x: Variable<T>): Variable<T> {
    val cosine = variable { cos(x.value) }
    return derive(cosine) { z ->
        x.d -= z.d * sin(x.value)
    }
}

View File

@ -1,218 +0,0 @@
package scientifik.kmath.operations
import scientifik.kmath.linear.Point
import scientifik.kmath.structures.asBuffer
import kotlin.math.pow
import kotlin.math.sqrt
/*
* Implementation of backward-mode automatic differentiation.
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
*/
/**
 * A `Double` value that takes part in automatic differentiation.
 *
 * The class only stores [value]; it is `open` so that subclasses such as
 * [DerivationResult] can carry extra data next to the value.
 */
open class Variable(
    val value: Double
) {
    /** Convenience constructor accepting any [Number]; it is converted with `toDouble()`. */
    constructor(x: Number) : this(x.toDouble())
}
/**
 * The outcome of a differentiation run: the resulting [value] together with a
 * [deriv] map from variables to the derivative values recorded for them.
 */
class DerivationResult(value: Double, val deriv: Map<Variable, Double>) : Variable(value) {

    /**
     * Returns the derivative stored for [variable], or `0.0` when the map has no entry for it.
     */
    fun deriv(variable: Variable): Double = deriv.getOrElse(variable) { 0.0 }

    /**
     * Computes the divergence, i.e. the sum of all values stored in [deriv].
     */
    fun div(): Double = deriv.values.fold(0.0) { total, partial -> total + partial }

    /**
     * Builds a gradient whose components follow the order of [variables].
     *
     * Fails with an error when no variables are passed, since the component
     * order would then be undefined.
     */
    fun grad(vararg variables: Variable): Point<Double> {
        if (variables.isEmpty()) error("Variable order is not provided for gradient construction")
        return variables.map { deriv(it) }.toDoubleArray().asBuffer()
    }
}
/**
 * Runs differentiation and establishes an [AutoDiffField] context inside [body].
 *
 * The derivative of the result with respect to itself is seeded with `1.0`,
 * the backward pass is run, and the collected derivatives are returned in a
 * [DerivationResult].
 *
 * Example:
 * ```
 * val x = Variable(2)                     // define variable(s) and their values
 * val y = deriv { sqr(x) + 5 * x + 3 }    // write the formula in deriv context
 * y.value      // 17.0, the value of the result
 * y.deriv(x)   // 9.0, dy/dx
 * ```
 *
 * @param body the formula to differentiate; it returns the result variable.
 * @return the value of the result together with the recorded derivatives.
 */
fun deriv(body: AutoDiffField.() -> Variable): DerivationResult {
    val autoDiff = AutoDiffContext()
    return with(autoDiff) {
        val output = body()
        // computing derivative w.r.t. the result itself
        output.d = 1.0
        runBackwardPass()
        DerivationResult(output.value, derivatives)
    }
}
/**
 * A [Field] over [Variable]s in which formulas for automatic differentiation are written.
 */
abstract class AutoDiffField : Field<Variable> {

    /**
     * Performs update of derivative after the rest of the formula in the back-pass.
     *
     * For example, the implementation of the `sin` function is:
     *
     * ```
     * fun AutoDiffField.sin(x: Variable): Variable =
     *     derive(variable(kotlin.math.sin(x.value))) { z -> // call derive with the function result
     *         x.d += z.d * kotlin.math.cos(x.value)         // chain rule with the function's derivative
     *     }
     * ```
     *
     * @param value the function result handed back to [block] as its argument.
     * @param block the derivative update.
     * @return [value].
     */
    abstract fun <R> derive(value: R, block: (R) -> Unit): R

    /**
     * A variable accessing inner state of derivatives.
     * Use this in inner builders to avoid creating additional derivative bindings.
     */
    abstract var Variable.d: Double

    /** Creates a variable of this context holding [value]. */
    abstract fun variable(value: Double): Variable

    // Overloads for numeric constants

    /** `number + variable`; the derivative of the sum is added to `that.d`. */
    operator fun Number.plus(that: Variable): Variable {
        val shifted = variable(this.toDouble() + that.value)
        return derive(shifted) { z -> that.d += z.d }
    }

    /** `variable + number`; delegates to the `number + variable` overload. */
    operator fun Variable.plus(b: Number): Variable = b + this

    /** `number - variable`; the derivative of the difference is subtracted from `that.d`. */
    operator fun Number.minus(that: Variable): Variable {
        val difference = variable(this.toDouble() - that.value)
        return derive(difference) { z -> that.d -= z.d }
    }

    /** `variable - number`; the derivative of the difference is added to the receiver's `d`. */
    operator fun Variable.minus(that: Number): Variable {
        val difference = variable(this.value - that.toDouble())
        return derive(difference) { z -> this@minus.d += z.d }
    }
}
/**
 * Automatic Differentiation context class.
 *
 * Every call to [derive] records a (block, value) pair on an internal stack;
 * [runBackwardPass] replays the recorded pairs in reverse order.
 */
private class AutoDiffContext : AutoDiffField() {

    // this stack contains pairs of blocks and values to apply them to
    private var stack = arrayOfNulls<Any?>(8)
    private var sp = 0

    /** Derivatives of variables that do not carry their own derivative field. */
    internal val derivatives = HashMap<Variable, Double>()

    /**
     * A variable coupled with its derivative. For internal use only
     */
    class VariableWithDeriv(x: Double, var d: Double = 0.0) : Variable(x)

    override fun variable(value: Double): Variable = VariableWithDeriv(value)

    override var Variable.d: Double
        // Prefer the derivative stored in the variable itself, then the map, then zero.
        get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: 0.0
        set(value) {
            if (this is VariableWithDeriv) {
                d = value
            } else {
                derivatives[this] = value
            }
        }

    @Suppress("UNCHECKED_CAST")
    override fun <R> derive(value: R, block: (R) -> Unit): R {
        // save block to stack for backward pass; sp and the size are always even,
        // so a single capacity check covers both pushes
        if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
        stack[sp++] = block
        stack[sp++] = value
        return value
    }

    /** Replays all recorded blocks, most recently recorded first. */
    @Suppress("UNCHECKED_CAST")
    fun runBackwardPass() {
        while (sp > 0) {
            val value = stack[--sp]
            val block = stack[--sp] as (Any?) -> Unit
            block(value)
        }
    }

    // Basic math (+, -, *, /)

    override fun add(a: Variable, b: Variable): Variable =
        derive(variable(a.value + b.value)) { z ->
            a.d += z.d
            b.d += z.d
        }

    override fun multiply(a: Variable, b: Variable): Variable =
        derive(variable(a.value * b.value)) { z ->
            a.d += z.d * b.value
            b.d += z.d * a.value
        }

    // The quotient is created with variable(...) like every other operation. A plain
    // Variable(...) would keep the intermediate derivative in the `derivatives` map,
    // leaking it into the DerivationResult (and into its div() sum).
    override fun divide(a: Variable, b: Variable): Variable =
        derive(variable(a.value / b.value)) { z ->
            a.d += z.d / b.value
            b.d -= z.d * a.value / (b.value * b.value)
        }

    override fun multiply(a: Variable, k: Number): Variable =
        derive(variable(k.toDouble() * a.value)) { z ->
            a.d += z.d * k.toDouble()
        }

    override val zero: Variable get() = Variable(0.0)

    override val one: Variable get() = Variable(1.0)
}
// Extensions for differentiation of various basic mathematical functions
/**
 * Square of [x] (x ^ 2).
 *
 * The block passed to `derive` adds `z.d * 2 * x.value` to `x.d`,
 * where `z` is the returned variable.
 */
fun AutoDiffField.sqr(x: Variable): Variable {
    val squared = variable(x.value * x.value)
    return derive(squared) { z ->
        x.d += z.d * 2 * x.value
    }
}
/**
 * Square root of [x] (x ^ 1/2).
 *
 * The block passed to `derive` adds `z.d * 0.5 / z.value` to `x.d`,
 * reusing the already computed root `z.value`.
 */
fun AutoDiffField.sqrt(x: Variable): Variable {
    val root = variable(sqrt(x.value))
    return derive(root) { z ->
        x.d += z.d * 0.5 / z.value
    }
}
/**
 * [x] raised to the constant power [y] (x ^ y).
 *
 * The block passed to `derive` adds `z.d * y * x.value.pow(y - 1)` to `x.d`.
 */
fun AutoDiffField.pow(x: Variable, y: Double): Variable {
    val raised = variable(x.value.pow(y))
    return derive(raised) { z ->
        x.d += z.d * y * x.value.pow(y - 1)
    }
}

/**
 * [x] raised to the constant integer power [y]; converts [y] to `Double`
 * and delegates to the `Double` overload.
 */
fun AutoDiffField.pow(x: Variable, y: Int): Variable {
    val exponent = y.toDouble()
    return pow(x, exponent)
}
/**
 * Exponential of [x].
 *
 * The block passed to `derive` adds `z.d * z.value` to `x.d`,
 * reusing the already computed exponential `z.value`.
 */
fun AutoDiffField.exp(x: Variable): Variable {
    val exponential = variable(kotlin.math.exp(x.value))
    return derive(exponential) { z ->
        x.d += z.d * z.value
    }
}
/**
 * Natural logarithm of [x].
 *
 * The block passed to `derive` adds `z.d / x.value` to `x.d`.
 *
 * The result is created with `variable(...)`, as in the sibling functions
 * (`exp`, `sin`, `cos`, ...), rather than with the plain `Variable(...)` constructor.
 */
fun AutoDiffField.ln(x: Variable): Variable = derive(variable(kotlin.math.ln(x.value))) { z ->
    x.d += z.d / x.value
}
/**
 * [x] raised to the variable power [y], computed as `exp(y * ln(x))`.
 */
fun AutoDiffField.pow(x: Variable, y: Variable): Variable {
    val logarithm = ln(x)
    return exp(y * logarithm)
}
/**
 * Sine of [x].
 *
 * The block passed to `derive` adds `z.d * cos(x.value)` to `x.d`.
 */
fun AutoDiffField.sin(x: Variable): Variable {
    val sine = variable(kotlin.math.sin(x.value))
    return derive(sine) { z ->
        x.d += z.d * kotlin.math.cos(x.value)
    }
}
/**
 * Cosine of [x].
 *
 * The block passed to `derive` subtracts `z.d * sin(x.value)` from `x.d`.
 */
fun AutoDiffField.cos(x: Variable): Variable {
    val cosine = variable(kotlin.math.cos(x.value))
    return derive(cosine) { z ->
        x.d -= z.d * kotlin.math.sin(x.value)
    }
}

View File

@ -33,7 +33,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
* A field for double without boxing. Does not produce appropriate field element
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0
override inline fun add(a: Double, b: Double) = a + b
override inline fun multiply(a: Double, b: Double) = a * b
@ -64,7 +64,7 @@ object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double,
}
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : Field<Float>, ExtendedFieldOperations<Float>, Norm<Float, Float> {
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float = 0f
override inline fun add(a: Float, b: Float) = a + b
override inline fun multiply(a: Float, b: Float) = a * b

View File

@ -123,7 +123,7 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
override fun iterator(): Iterator<T> = list.iterator()
}
fun <T> List<T>.asBuffer() = ListBuffer(this)
fun <T> List<T>.asBuffer() = ListBuffer<T>(this)
/**
 * Creates a [ListBuffer] of [size] elements, where the element at each index
 * is produced by calling [init] with that index.
 */
@Suppress("FunctionName")
inline fun <T> ListBuffer(size: Int, init: (Int) -> T): ListBuffer<T> = ListBuffer(List(size, init))

View File

@ -1,5 +1,6 @@
package scientifik.kmath.operations
package scientifik.kmath.misc
import scientifik.kmath.operations.RealField
import scientifik.kmath.structures.asBuffer
import kotlin.math.PI
import kotlin.test.Test
@ -7,6 +8,11 @@ import kotlin.test.assertEquals
import kotlin.test.assertTrue
class AutoDiffTest {
fun Variable(int: Int) = Variable(int.toDouble())
fun deriv(body: AutoDiffField<Double>.() -> Variable<Double>) = RealField.deriv(body)
@Test
fun testPlusX2() {
val x = Variable(3) // diff w.r.t this x at 3