From d0a724771d80e60ae3e8025a8e2db20cff885563 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 4 May 2019 22:21:22 +0300 Subject: [PATCH] MPP Univariate Autodiff from Roman Elizarov --- .../scientifik/kmath/operations/AutoDiff.kt | 182 ++++++++++++++++++ .../kmath/operations/AutoDiffTest.kt | 155 +++++++++++++++ 2 files changed, 337 insertions(+) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt create mode 100644 kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt new file mode 100644 index 000000000..4b0070545 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt @@ -0,0 +1,182 @@ +package scientifik.kmath.operations + +import kotlin.math.pow +import kotlin.math.sqrt + +/* + * Implementation of backward-mode automatic differentiation. + * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d + */ + +/** + * Differentiable variable with value and derivative of differentiation ([deriv]) result + * with respect to this variable. + */ +data class ValueWithDeriv(var x: Double, var d: Double = 0.0) { + constructor(x: Number) : this(x.toDouble()) +} + +/** + * Runs differentiation and establishes [Field] context inside the block of code. + * + * Example: + * ``` + * val x = ValueWithDeriv(2) // define variable(s) and their values + * val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context + * assertEquals(17.0, y.x) // the value of result (y) + * assertEquals(9.0, x.d) // dy/dx + * ``` + */ +fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv = + ValueWithDerivField().run { + val result = body() + result.d = 1.0 // computing derivative w.r.t result + runBackwardPass() + result + } + + +abstract class AutoDiffField : Field { + /** + * Performs update of derivative after the rest of the formula in the back-pass. + * + * For example, implementation of `sin` function is: + * + * ``` + * fun AD.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sin(x.x)) { z -> // call derive with function result + * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function + * } + * ``` + */ + abstract fun derive(value: R, block: (R) -> Unit): R + + // Overloads for Double constants + + operator fun Number.plus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() + that.x)) { z -> + that.d += z.d + } + + operator fun ValueWithDeriv.plus(b: Number): ValueWithDeriv = b.plus(this) + + operator fun Number.minus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() - that.x)) { z -> + that.d -= z.d + } + + operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z -> + this.d += z.d + } + + override operator fun Number.times(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() * that.x)) { z -> + that.d += z.d * this.toDouble() + } + + override operator fun ValueWithDeriv.times(b: Number): ValueWithDeriv = b.times(this) + + override operator fun Number.div(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() / that.x)) { z -> + that.d -= z.d * this.toDouble() / (that.x * that.x) + } + + override operator fun ValueWithDeriv.div(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x / that.toDouble())) { z -> + this.d += z.d / that.toDouble() + } +} + +/** + * Automatic Differentiation context class. + */ +private class ValueWithDerivField : AutoDiffField() { + + // this stack contains pairs of blocks and values to apply them to + private var stack = arrayOfNulls(8) + private var sp = 0 + + + @Suppress("UNCHECKED_CAST") + override fun derive(value: R, block: (R) -> Unit): R { + // save block to stack for backward pass + if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) + stack[sp++] = block + stack[sp++] = value + return value + } + + @Suppress("UNCHECKED_CAST") + fun runBackwardPass() { + while (sp > 0) { + val value = stack[--sp] + val block = stack[--sp] as (Any?) -> Unit + block(value) + } + } + + // Basic math (+, -, *, /) + + + override fun add(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = + derive(ValueWithDeriv(a.x + b.x)) { z -> + a.d += z.d + b.d += z.d + } + + override fun multiply(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = + derive(ValueWithDeriv(a.x * b.x)) { z -> + a.d += z.d * b.x + b.d += z.d * a.x + } + + override fun divide(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = + derive(ValueWithDeriv(a.x / b.x)) { z -> + a.d += z.d / b.x + b.d -= z.d * a.x / (b.x * b.x) + } + + override fun multiply(a: ValueWithDeriv, k: Number): ValueWithDeriv = + derive(ValueWithDeriv(k.toDouble() * a.x)) { z -> + a.d += z.d * k.toDouble() + } + + override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0, 0.0) + override val one: ValueWithDeriv get() = ValueWithDeriv(1.0, 0.0) +} + +// Extensions for differentiation of various basic mathematical functions + +// x ^ 2 +fun AutoDiffField.sqr(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(x.x * x.x)) { z -> + x.d += z.d * 2 * x.x +} + +// x ^ 1/2 +fun AutoDiffField.sqrt(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sqrt(x.x))) { z -> + x.d += z.d * 0.5 / z.x +} + +// x ^ y (const) +fun AutoDiffField.pow(x: ValueWithDeriv, y: Double): ValueWithDeriv = derive(ValueWithDeriv(x.x.pow(y))) { z -> + x.d += z.d * y * x.x.pow(y - 1) +} + +fun AutoDiffField.pow(x: ValueWithDeriv, y: Int): ValueWithDeriv = pow(x, y.toDouble()) + +// exp(x) +fun AutoDiffField.exp(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.exp(x.x))) { z -> + x.d += z.d * z.x +} + +// ln(x) +fun AutoDiffField.ln(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.ln(x.x))) { z -> + x.d += z.d / x.x +} + +// x ^ y (any) +fun AutoDiffField.pow(x: ValueWithDeriv, y: ValueWithDeriv): ValueWithDeriv = exp(y * ln(x)) + +// sin(x) +fun AutoDiffField.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.sin(x.x))) { z -> + x.d += z.d * kotlin.math.cos(x.x) +} + +// cos(x) +fun AutoDiffField.cos(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.cos(x.x))) { z -> + x.d -= z.d * kotlin.math.sin(x.x) +} \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt new file mode 100644 index 000000000..6f7299e9e --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt @@ -0,0 +1,155 @@ +package scientifik.kmath.operations + +import kotlin.math.PI +import kotlin.test.Test +import kotlin.test.assertEquals + +class AutoDiffTest { + @Test + fun testPlusX2() { + val x = ValueWithDeriv(3) // diff w.r.t this x at 3 + val y = deriv { x + x } + assertEquals(6.0, y.x) // y = x + x = 6 + assertEquals(2.0, x.d) // dy/dx = 2 + } + + @Test + fun testPlus() { + // two variables + val x = ValueWithDeriv(2) + val y = ValueWithDeriv(3) + val z = deriv { x + y } + assertEquals(5.0, z.x) // z = x + y = 5 + assertEquals(1.0, x.d) // dz/dx = 1 + assertEquals(1.0, y.d) // dz/dy = 1 + } + + @Test + fun testMinus() { + // two variables + val x = ValueWithDeriv(7) + val y = ValueWithDeriv(3) + val z = deriv { x - y } + assertEquals(4.0, z.x) // z = x - y = 4 + assertEquals(1.0, x.d) // dz/dx = 1 + assertEquals(-1.0, y.d) // dz/dy = -1 + } + + @Test + fun testMulX2() { + val x = ValueWithDeriv(3) // diff w.r.t this x at 3 + val y = deriv { x * x } + assertEquals(9.0, y.x) // y = x * x = 9 + assertEquals(6.0, x.d) // dy/dx = 2 * x = 7 + } + + @Test + fun testSqr() { + val x = ValueWithDeriv(3) + val y = deriv { sqr(x) } + assertEquals(9.0, y.x) // y = x ^ 2 = 9 + assertEquals(6.0, x.d) // dy/dx = 2 * x = 7 + } + + @Test + fun testSqrSqr() { + val x = ValueWithDeriv(2) + val y = deriv { sqr(sqr(x)) } + assertEquals(16.0, y.x) // y = x ^ 4 = 16 + assertEquals(32.0, x.d) // dy/dx = 4 * x^3 = 32 + } + + @Test + fun testX3() { + val x = ValueWithDeriv(2) // diff w.r.t this x at 2 + val y = deriv { x * x * x } + assertEquals(8.0, y.x) // y = x * x * x = 8 + assertEquals(12.0, x.d) // dy/dx = 3 * x * x = 12 + } + + @Test + fun testDiv() { + val x = ValueWithDeriv(5) + val y = ValueWithDeriv(2) + val z = deriv { x / y } + assertEquals(2.5, z.x) // z = x / y = 2.5 + assertEquals(0.5, x.d) // dz/dx = 1 / y = 0.5 + assertEquals(-1.25, y.d) // dz/dy = -x / y^2 = -1.25 + } + + @Test + fun testPow3() { + val x = ValueWithDeriv(2) // diff w.r.t this x at 2 + val y = deriv { pow(x, 3) } + assertEquals(8.0, y.x) // y = x ^ 3 = 8 + assertEquals(12.0, x.d) // dy/dx = 3 * x ^ 2 = 12 + } + + @Test + fun testPowFull() { + val x = ValueWithDeriv(2) + val y = ValueWithDeriv(3) + val z = deriv { pow(x, y) } + assertApprox(8.0, z.x) // z = x ^ y = 8 + assertApprox(12.0, x.d) // dz/dx = y * x ^ (y - 1) = 12 + assertApprox(8.0 * kotlin.math.ln(2.0), y.d) // dz/dy = x ^ y * ln(x) + } + + @Test + fun testFromPaper() { + val x = ValueWithDeriv(3) + val y = deriv { 2 * x + x * x * x } + assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33 + assertEquals(29.0, x.d) // dy/dx = 2 + 3 * x * x = 29 + } + + @Test + fun testLongChain() { + val n = 10_000 + val x = ValueWithDeriv(1) + val y = deriv { + var pow = ValueWithDeriv(1) + for (i in 1..n) pow *= x + pow + } + assertEquals(1.0, y.x) // y = x ^ n = 1 + assertEquals(n.toDouble(), x.d) // dy/dx = n * x ^ (n - 1) = n - 1 + } + + @Test + fun testExample() { + val x = ValueWithDeriv(2) + val y = deriv { sqr(x) + 5 * x + 3*one } + assertEquals(17.0, y.x) // the value of result (y) + assertEquals(9.0, x.d) // dy/dx + } + + @Test + fun testSqrt() { + val x = ValueWithDeriv(16) + val y = deriv { sqrt(x) } + assertEquals(4.0, y.x) // y = x ^ 1/2 = 4 + assertEquals(1.0 / 8, x.d) // dy/dx = 1/2 / x ^ 1/4 = 1/8 + } + + @Test + fun testSin() { + val x = ValueWithDeriv(PI / 6) + val y = deriv { sin(x) } + assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5 + assertApprox(kotlin.math.sqrt(3.0) / 2, x.d) // dy/dx = cos(PI/6) = sqrt(3)/2 + } + + @Test + fun testCos() { + val x = ValueWithDeriv(PI / 6) + val y = deriv { cos(x) } + assertApprox(kotlin.math.sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2 + assertApprox(-0.5, x.d) // dy/dx = -sin(PI/6) = -0.5 + } + + private fun assertApprox(a: Double, b: Double) { + if ((a - b) > 1e-10) assertEquals(a, b) + } + +} \ No newline at end of file