forked from kscience/kmath
MPP Univariate Autodiff from Roman Elizarov
This commit is contained in:
parent
c77d9cbb08
commit
d0a724771d
@ -0,0 +1,182 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implementation of backward-mode automatic differentiation.
|
||||||
|
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||||
|
* with respect to this variable.
|
||||||
|
*/
|
||||||
|
data class ValueWithDeriv(var x: Double, var d: Double = 0.0) {
|
||||||
|
constructor(x: Number) : this(x.toDouble())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs differentiation and establishes [Field<ValueWithDeriv>] context inside the block of code.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* ```
|
||||||
|
* val x = ValueWithDeriv(2) // define variable(s) and their values
|
||||||
|
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||||
|
* assertEquals(17.0, y.x) // the value of result (y)
|
||||||
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv =
|
||||||
|
ValueWithDerivField().run {
|
||||||
|
val result = body()
|
||||||
|
result.d = 1.0 // computing derivative w.r.t result
|
||||||
|
runBackwardPass()
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
abstract class AutoDiffField : Field<ValueWithDeriv> {
|
||||||
|
/**
|
||||||
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
|
*
|
||||||
|
* For example, implementation of `sin` function is:
|
||||||
|
*
|
||||||
|
* ```
|
||||||
|
* fun AD.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sin(x.x)) { z -> // call derive with function result
|
||||||
|
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
abstract fun <R> derive(value: R, block: (R) -> Unit): R
|
||||||
|
|
||||||
|
// Overloads for Double constants
|
||||||
|
|
||||||
|
operator fun Number.plus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() + that.x)) { z ->
|
||||||
|
that.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
operator fun ValueWithDeriv.plus(b: Number): ValueWithDeriv = b.plus(this)
|
||||||
|
|
||||||
|
operator fun Number.minus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() - that.x)) { z ->
|
||||||
|
that.d -= z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z ->
|
||||||
|
this.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Number.times(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() * that.x)) { z ->
|
||||||
|
that.d += z.d * this.toDouble()
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun ValueWithDeriv.times(b: Number): ValueWithDeriv = b.times(this)
|
||||||
|
|
||||||
|
override operator fun Number.div(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() / that.x)) { z ->
|
||||||
|
that.d -= z.d * this.toDouble() / (that.x * that.x)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun ValueWithDeriv.div(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x / that.toDouble())) { z ->
|
||||||
|
this.d += z.d / that.toDouble()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatic Differentiation context class.
|
||||||
|
*/
|
||||||
|
private class ValueWithDerivField : AutoDiffField() {
|
||||||
|
|
||||||
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
|
private var stack = arrayOfNulls<Any?>(8)
|
||||||
|
private var sp = 0
|
||||||
|
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
override fun <R> derive(value: R, block: (R) -> Unit): R {
|
||||||
|
// save block to stack for backward pass
|
||||||
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
|
stack[sp++] = block
|
||||||
|
stack[sp++] = value
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
fun runBackwardPass() {
|
||||||
|
while (sp > 0) {
|
||||||
|
val value = stack[--sp]
|
||||||
|
val block = stack[--sp] as (Any?) -> Unit
|
||||||
|
block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
|
|
||||||
|
override fun add(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv =
|
||||||
|
derive(ValueWithDeriv(a.x + b.x)) { z ->
|
||||||
|
a.d += z.d
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv =
|
||||||
|
derive(ValueWithDeriv(a.x * b.x)) { z ->
|
||||||
|
a.d += z.d * b.x
|
||||||
|
b.d += z.d * a.x
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun divide(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv =
|
||||||
|
derive(ValueWithDeriv(a.x / b.x)) { z ->
|
||||||
|
a.d += z.d / b.x
|
||||||
|
b.d -= z.d * a.x / (b.x * b.x)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: ValueWithDeriv, k: Number): ValueWithDeriv =
|
||||||
|
derive(ValueWithDeriv(k.toDouble() * a.x)) { z ->
|
||||||
|
a.d += z.d * k.toDouble()
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0, 0.0)
|
||||||
|
override val one: ValueWithDeriv get() = ValueWithDeriv(1.0, 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
|
// x ^ 2
|
||||||
|
fun AutoDiffField.sqr(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(x.x * x.x)) { z ->
|
||||||
|
x.d += z.d * 2 * x.x
|
||||||
|
}
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
fun AutoDiffField.sqrt(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sqrt(x.x))) { z ->
|
||||||
|
x.d += z.d * 0.5 / z.x
|
||||||
|
}
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
fun AutoDiffField.pow(x: ValueWithDeriv, y: Double): ValueWithDeriv = derive(ValueWithDeriv(x.x.pow(y))) { z ->
|
||||||
|
x.d += z.d * y * x.x.pow(y - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun AutoDiffField.pow(x: ValueWithDeriv, y: Int): ValueWithDeriv = pow(x, y.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
fun AutoDiffField.exp(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.exp(x.x))) { z ->
|
||||||
|
x.d += z.d * z.x
|
||||||
|
}
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
fun AutoDiffField.ln(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.ln(x.x))) { z ->
|
||||||
|
x.d += z.d / x.x
|
||||||
|
}
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
fun AutoDiffField.pow(x: ValueWithDeriv, y: ValueWithDeriv): ValueWithDeriv = exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
fun AutoDiffField.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.sin(x.x))) { z ->
|
||||||
|
x.d += z.d * kotlin.math.cos(x.x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
fun AutoDiffField.cos(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.cos(x.x))) { z ->
|
||||||
|
x.d -= z.d * kotlin.math.sin(x.x)
|
||||||
|
}
|
@ -0,0 +1,155 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class AutoDiffTest {
|
||||||
|
@Test
|
||||||
|
fun testPlusX2() {
|
||||||
|
val x = ValueWithDeriv(3) // diff w.r.t this x at 3
|
||||||
|
val y = deriv { x + x }
|
||||||
|
assertEquals(6.0, y.x) // y = x + x = 6
|
||||||
|
assertEquals(2.0, x.d) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlus() {
|
||||||
|
// two variables
|
||||||
|
val x = ValueWithDeriv(2)
|
||||||
|
val y = ValueWithDeriv(3)
|
||||||
|
val z = deriv { x + y }
|
||||||
|
assertEquals(5.0, z.x) // z = x + y = 5
|
||||||
|
assertEquals(1.0, x.d) // dz/dx = 1
|
||||||
|
assertEquals(1.0, y.d) // dz/dy = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
// two variables
|
||||||
|
val x = ValueWithDeriv(7)
|
||||||
|
val y = ValueWithDeriv(3)
|
||||||
|
val z = deriv { x - y }
|
||||||
|
assertEquals(4.0, z.x) // z = x - y = 4
|
||||||
|
assertEquals(1.0, x.d) // dz/dx = 1
|
||||||
|
assertEquals(-1.0, y.d) // dz/dy = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMulX2() {
|
||||||
|
val x = ValueWithDeriv(3) // diff w.r.t this x at 3
|
||||||
|
val y = deriv { x * x }
|
||||||
|
assertEquals(9.0, y.x) // y = x * x = 9
|
||||||
|
assertEquals(6.0, x.d) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqr() {
|
||||||
|
val x = ValueWithDeriv(3)
|
||||||
|
val y = deriv { sqr(x) }
|
||||||
|
assertEquals(9.0, y.x) // y = x ^ 2 = 9
|
||||||
|
assertEquals(6.0, x.d) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrSqr() {
|
||||||
|
val x = ValueWithDeriv(2)
|
||||||
|
val y = deriv { sqr(sqr(x)) }
|
||||||
|
assertEquals(16.0, y.x) // y = x ^ 4 = 16
|
||||||
|
assertEquals(32.0, x.d) // dy/dx = 4 * x^3 = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testX3() {
|
||||||
|
val x = ValueWithDeriv(2) // diff w.r.t this x at 2
|
||||||
|
val y = deriv { x * x * x }
|
||||||
|
assertEquals(8.0, y.x) // y = x * x * x = 8
|
||||||
|
assertEquals(12.0, x.d) // dy/dx = 3 * x * x = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDiv() {
|
||||||
|
val x = ValueWithDeriv(5)
|
||||||
|
val y = ValueWithDeriv(2)
|
||||||
|
val z = deriv { x / y }
|
||||||
|
assertEquals(2.5, z.x) // z = x / y = 2.5
|
||||||
|
assertEquals(0.5, x.d) // dz/dx = 1 / y = 0.5
|
||||||
|
assertEquals(-1.25, y.d) // dz/dy = -x / y^2 = -1.25
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPow3() {
|
||||||
|
val x = ValueWithDeriv(2) // diff w.r.t this x at 2
|
||||||
|
val y = deriv { pow(x, 3) }
|
||||||
|
assertEquals(8.0, y.x) // y = x ^ 3 = 8
|
||||||
|
assertEquals(12.0, x.d) // dy/dx = 3 * x ^ 2 = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPowFull() {
|
||||||
|
val x = ValueWithDeriv(2)
|
||||||
|
val y = ValueWithDeriv(3)
|
||||||
|
val z = deriv { pow(x, y) }
|
||||||
|
assertApprox(8.0, z.x) // z = x ^ y = 8
|
||||||
|
assertApprox(12.0, x.d) // dz/dx = y * x ^ (y - 1) = 12
|
||||||
|
assertApprox(8.0 * kotlin.math.ln(2.0), y.d) // dz/dy = x ^ y * ln(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testFromPaper() {
|
||||||
|
val x = ValueWithDeriv(3)
|
||||||
|
val y = deriv { 2 * x + x * x * x }
|
||||||
|
assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33
|
||||||
|
assertEquals(29.0, x.d) // dy/dx = 2 + 3 * x * x = 29
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLongChain() {
|
||||||
|
val n = 10_000
|
||||||
|
val x = ValueWithDeriv(1)
|
||||||
|
val y = deriv {
|
||||||
|
var pow = ValueWithDeriv(1)
|
||||||
|
for (i in 1..n) pow *= x
|
||||||
|
pow
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.x) // y = x ^ n = 1
|
||||||
|
assertEquals(n.toDouble(), x.d) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testExample() {
|
||||||
|
val x = ValueWithDeriv(2)
|
||||||
|
val y = deriv { sqr(x) + 5 * x + 3*one }
|
||||||
|
assertEquals(17.0, y.x) // the value of result (y)
|
||||||
|
assertEquals(9.0, x.d) // dy/dx
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrt() {
|
||||||
|
val x = ValueWithDeriv(16)
|
||||||
|
val y = deriv { sqrt(x) }
|
||||||
|
assertEquals(4.0, y.x) // y = x ^ 1/2 = 4
|
||||||
|
assertEquals(1.0 / 8, x.d) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSin() {
|
||||||
|
val x = ValueWithDeriv(PI / 6)
|
||||||
|
val y = deriv { sin(x) }
|
||||||
|
assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5
|
||||||
|
assertApprox(kotlin.math.sqrt(3.0) / 2, x.d) // dy/dx = cos(PI/6) = sqrt(3)/2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCos() {
|
||||||
|
val x = ValueWithDeriv(PI / 6)
|
||||||
|
val y = deriv { cos(x) }
|
||||||
|
assertApprox(kotlin.math.sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2
|
||||||
|
assertApprox(-0.5, x.d) // dy/dx = -sin(PI/6) = -0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun assertApprox(a: Double, b: Double) {
|
||||||
|
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user