From 6f9b704aa703799666aa9342266ba2eb9a7a386d Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 5 May 2019 09:39:51 +0300 Subject: [PATCH] Safe modification of autodiff --- build.gradle.kts | 2 +- .../scientifik/kmath/operations/AutoDiff.kt | 80 +++++++++------- .../kmath/operations/AutoDiffTest.kt | 96 ++++++++++--------- 3 files changed, 100 insertions(+), 78 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 5f60bc2f0..4ec4695f5 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,4 +1,4 @@ -val kmathVersion by extra("0.1.2-dev-3") +val kmathVersion by extra("0.1.2-dev-4") allprojects { repositories { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt index ab007b976..454d386b0 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt @@ -12,12 +12,12 @@ import kotlin.math.sqrt * Differentiable variable with value and derivative of differentiation ([deriv]) result * with respect to this variable. 
*/ -data class ValueWithDeriv(var x: Double) { +open class Variable(val x: Double) { constructor(x: Number) : this(x.toDouble()) +} - //TODO move set accessor inside AutoDiffField - var d: Double = 0.0 - internal set +class DerivationResult(x: Double, val deriv: Map): Variable(x) { + fun deriv(variable: Variable) = deriv[variable] ?: 0.0 } /** @@ -27,48 +27,53 @@ data class ValueWithDeriv(var x: Double) { * * Example: * ``` - * val x = ValueWithDeriv(2) // define variable(s) and their values + * val x = Variable(2) // define variable(s) and their values * val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context * assertEquals(17.0, y.x) // the value of result (y) * assertEquals(9.0, x.d) // dy/dx * ``` */ -fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv = - AutoDiffFieldImpl().run { +fun deriv(body: AutoDiffField.() -> Variable): DerivationResult = + AutoDiffContext().run { val result = body() result.d = 1.0 // computing derivative w.r.t result runBackwardPass() - result + DerivationResult(result.x, derivatives) } -abstract class AutoDiffField : Field { +abstract class AutoDiffField : Field { /** * Performs update of derivative after the rest of the formula in the back-pass. * * For example, implementation of `sin` function is: * * ``` - * fun AD.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sin(x.x)) { z -> // call derive with function result + * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function * } * ``` */ abstract fun derive(value: R, block: (R) -> Unit): R + /** + * A variable accessing inner state of derivatives. 
Use only in extensions + */ + abstract var Variable.d: Double + // Overloads for Double constants - operator fun Number.plus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() + that.x)) { z -> + operator fun Number.plus(that: Variable): Variable = derive(Variable(this.toDouble() + that.x)) { z -> that.d += z.d } - operator fun ValueWithDeriv.plus(b: Number): ValueWithDeriv = b.plus(this) + operator fun Variable.plus(b: Number): Variable = b.plus(this) - operator fun Number.minus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() - that.x)) { z -> + operator fun Number.minus(that: Variable): Variable = derive(Variable(this.toDouble() - that.x)) { z -> that.d -= z.d } - operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z -> + operator fun Variable.minus(that: Number): Variable = derive(Variable(this.x - that.toDouble())) { z -> this.d += z.d } } @@ -76,12 +81,19 @@ abstract class AutoDiffField : Field { /** * Automatic Differentiation context class. 
*/ -private class AutoDiffFieldImpl : AutoDiffField() { +private class AutoDiffContext : AutoDiffField() { // this stack contains pairs of blocks and values to apply them to private var stack = arrayOfNulls(8) private var sp = 0 + internal val derivatives = HashMap() + + override var Variable.d: Double + get() = derivatives[this] ?: 0.0 + set(value) { + derivatives[this] = value + } @Suppress("UNCHECKED_CAST") override fun derive(value: R, block: (R) -> Unit): R { @@ -104,71 +116,71 @@ private class AutoDiffFieldImpl : AutoDiffField() { // Basic math (+, -, *, /) - override fun add(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = - derive(ValueWithDeriv(a.x + b.x)) { z -> + override fun add(a: Variable, b: Variable): Variable = + derive(Variable(a.x + b.x)) { z -> a.d += z.d b.d += z.d } - override fun multiply(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = - derive(ValueWithDeriv(a.x * b.x)) { z -> + override fun multiply(a: Variable, b: Variable): Variable = + derive(Variable(a.x * b.x)) { z -> a.d += z.d * b.x b.d += z.d * a.x } - override fun divide(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = - derive(ValueWithDeriv(a.x / b.x)) { z -> + override fun divide(a: Variable, b: Variable): Variable = + derive(Variable(a.x / b.x)) { z -> a.d += z.d / b.x b.d -= z.d * a.x / (b.x * b.x) } - override fun multiply(a: ValueWithDeriv, k: Number): ValueWithDeriv = - derive(ValueWithDeriv(k.toDouble() * a.x)) { z -> + override fun multiply(a: Variable, k: Number): Variable = + derive(Variable(k.toDouble() * a.x)) { z -> a.d += z.d * k.toDouble() } - override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0) - override val one: ValueWithDeriv get() = ValueWithDeriv(1.0) + override val zero: Variable get() = Variable(0.0) + override val one: Variable get() = Variable(1.0) } // Extensions for differentiation of various basic mathematical functions // x ^ 2 -fun AutoDiffField.sqr(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(x.x * x.x)) 
{ z -> +fun AutoDiffField.sqr(x: Variable): Variable = derive(Variable(x.x * x.x)) { z -> x.d += z.d * 2 * x.x } // x ^ 1/2 -fun AutoDiffField.sqrt(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sqrt(x.x))) { z -> +fun AutoDiffField.sqrt(x: Variable): Variable = derive(Variable(sqrt(x.x))) { z -> x.d += z.d * 0.5 / z.x } // x ^ y (const) -fun AutoDiffField.pow(x: ValueWithDeriv, y: Double): ValueWithDeriv = derive(ValueWithDeriv(x.x.pow(y))) { z -> +fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(Variable(x.x.pow(y))) { z -> x.d += z.d * y * x.x.pow(y - 1) } -fun AutoDiffField.pow(x: ValueWithDeriv, y: Int): ValueWithDeriv = pow(x, y.toDouble()) +fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) // exp(x) -fun AutoDiffField.exp(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.exp(x.x))) { z -> +fun AutoDiffField.exp(x: Variable): Variable = derive(Variable(kotlin.math.exp(x.x))) { z -> x.d += z.d * z.x } // ln(x) -fun AutoDiffField.ln(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.ln(x.x))) { z -> +fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.x))) { z -> x.d += z.d / x.x } // x ^ y (any) -fun AutoDiffField.pow(x: ValueWithDeriv, y: ValueWithDeriv): ValueWithDeriv = exp(y * ln(x)) +fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x)) // sin(x) -fun AutoDiffField.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.sin(x.x))) { z -> +fun AutoDiffField.sin(x: Variable): Variable = derive(Variable(kotlin.math.sin(x.x))) { z -> x.d += z.d * kotlin.math.cos(x.x) } // cos(x) -fun AutoDiffField.cos(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.cos(x.x))) { z -> +fun AutoDiffField.cos(x: Variable): Variable = derive(Variable(kotlin.math.cos(x.x))) { z -> x.d -= z.d * kotlin.math.sin(x.x) } \ No newline at end of file diff --git 
a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt index 736c301fd..9adc2f977 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt @@ -7,145 +7,155 @@ import kotlin.test.assertEquals class AutoDiffTest { @Test fun testPlusX2() { - val x = ValueWithDeriv(3) // diff w.r.t this x at 3 + val x = Variable(3) // diff w.r.t this x at 3 val y = deriv { x + x } assertEquals(6.0, y.x) // y = x + x = 6 - assertEquals(2.0, x.d) // dy/dx = 2 + assertEquals(2.0, y.deriv(x)) // dy/dx = 2 } @Test fun testPlus() { // two variables - val x = ValueWithDeriv(2) - val y = ValueWithDeriv(3) + val x = Variable(2) + val y = Variable(3) val z = deriv { x + y } assertEquals(5.0, z.x) // z = x + y = 5 - assertEquals(1.0, x.d) // dz/dx = 1 - assertEquals(1.0, y.d) // dz/dy = 1 + assertEquals(1.0, z.deriv(x)) // dz/dx = 1 + assertEquals(1.0, z.deriv(y)) // dz/dy = 1 } @Test fun testMinus() { // two variables - val x = ValueWithDeriv(7) - val y = ValueWithDeriv(3) + val x = Variable(7) + val y = Variable(3) val z = deriv { x - y } assertEquals(4.0, z.x) // z = x - y = 4 - assertEquals(1.0, x.d) // dz/dx = 1 - assertEquals(-1.0, y.d) // dz/dy = -1 + assertEquals(1.0, z.deriv(x)) // dz/dx = 1 + assertEquals(-1.0, z.deriv(y)) // dz/dy = -1 } @Test fun testMulX2() { - val x = ValueWithDeriv(3) // diff w.r.t this x at 3 + val x = Variable(3) // diff w.r.t this x at 3 val y = deriv { x * x } assertEquals(9.0, y.x) // y = x * x = 9 - assertEquals(6.0, x.d) // dy/dx = 2 * x = 7 + assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 6 } @Test fun testSqr() { - val x = ValueWithDeriv(3) + val x = Variable(3) val y = deriv { sqr(x) } assertEquals(9.0, y.x) // y = x ^ 2 = 9 - assertEquals(6.0, x.d) // dy/dx = 2 * x = 7 + assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 6 } @Test fun 
testSqrSqr() { - val x = ValueWithDeriv(2) + val x = Variable(2) val y = deriv { sqr(sqr(x)) } assertEquals(16.0, y.x) // y = x ^ 4 = 16 - assertEquals(32.0, x.d) // dy/dx = 4 * x^3 = 32 + assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32 } @Test fun testX3() { - val x = ValueWithDeriv(2) // diff w.r.t this x at 2 + val x = Variable(2) // diff w.r.t this x at 2 val y = deriv { x * x * x } assertEquals(8.0, y.x) // y = x * x * x = 8 - assertEquals(12.0, x.d) // dy/dx = 3 * x * x = 12 + assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12 } @Test fun testDiv() { - val x = ValueWithDeriv(5) - val y = ValueWithDeriv(2) + val x = Variable(5) + val y = Variable(2) val z = deriv { x / y } assertEquals(2.5, z.x) // z = x / y = 2.5 - assertEquals(0.5, x.d) // dz/dx = 1 / y = 0.5 - assertEquals(-1.25, y.d) // dz/dy = -x / y^2 = -1.25 + assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5 + assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25 } @Test fun testPow3() { - val x = ValueWithDeriv(2) // diff w.r.t this x at 2 + val x = Variable(2) // diff w.r.t this x at 2 val y = deriv { pow(x, 3) } assertEquals(8.0, y.x) // y = x ^ 3 = 8 - assertEquals(12.0, x.d) // dy/dx = 3 * x ^ 2 = 12 + assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12 } @Test fun testPowFull() { - val x = ValueWithDeriv(2) - val y = ValueWithDeriv(3) + val x = Variable(2) + val y = Variable(3) val z = deriv { pow(x, y) } assertApprox(8.0, z.x) // z = x ^ y = 8 - assertApprox(12.0, x.d) // dz/dx = y * x ^ (y - 1) = 12 - assertApprox(8.0 * kotlin.math.ln(2.0), y.d) // dz/dy = x ^ y * ln(x) + assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12 + assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x) } @Test fun testFromPaper() { - val x = ValueWithDeriv(3) + val x = Variable(3) val y = deriv { 2 * x + x * x * x } assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33 - assertEquals(29.0, x.d) // dy/dx = 2 + 3 * x * x = 29 + assertEquals(29.0, y.deriv(x)) // 
dy/dx = 2 + 3 * x * x = 29 + } + + @Test + fun testInnerVariable() { + val x = Variable(1) + val y = deriv { + Variable(1) * x + } + assertEquals(1.0, y.x) // y = 1 * x = 1 + assertEquals(1.0, y.deriv(x)) // dy/dx = 1 } @Test fun testLongChain() { val n = 10_000 - val x = ValueWithDeriv(1) + val x = Variable(1) val y = deriv { - var pow = ValueWithDeriv(1) - for (i in 1..n) pow *= x - pow + var res = Variable(1) + for (i in 1..n) res *= x + res } assertEquals(1.0, y.x) // y = x ^ n = 1 - assertEquals(n.toDouble(), x.d) // dy/dx = n * x ^ (n - 1) = n - 1 + assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n } @Test fun testExample() { - val x = ValueWithDeriv(2) + val x = Variable(2) val y = deriv { sqr(x) + 5 * x + 3 } assertEquals(17.0, y.x) // the value of result (y) - assertEquals(9.0, x.d) // dy/dx + assertEquals(9.0, y.deriv(x)) // dy/dx } @Test fun testSqrt() { - val x = ValueWithDeriv(16) + val x = Variable(16) val y = deriv { sqrt(x) } assertEquals(4.0, y.x) // y = x ^ 1/2 = 4 - assertEquals(1.0 / 8, x.d) // dy/dx = 1/2 / x ^ 1/4 = 1/8 + assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/2 = 1/8 } @Test fun testSin() { - val x = ValueWithDeriv(PI / 6) + val x = Variable(PI / 6) val y = deriv { sin(x) } assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5 - assertApprox(kotlin.math.sqrt(3.0) / 2, x.d) // dy/dx = cos(PI/6) = sqrt(3)/2 + assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2 } @Test fun testCos() { - val x = ValueWithDeriv(PI / 6) + val x = Variable(PI / 6) val y = deriv { cos(x) } assertApprox(kotlin.math.sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2 - assertApprox(-0.5, x.d) // dy/dx = -sin(PI/6) = -0.5 + assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5 } private fun assertApprox(a: Double, b: Double) {