Safe modification of autodiff

This commit is contained in:
Alexander Nozik 2019-05-05 09:39:51 +03:00
parent 765097cbbe
commit 6f9b704aa7
3 changed files with 100 additions and 78 deletions

View File

@ -1,4 +1,4 @@
val kmathVersion by extra("0.1.2-dev-3") val kmathVersion by extra("0.1.2-dev-4")
allprojects { allprojects {
repositories { repositories {

View File

@ -12,12 +12,12 @@ import kotlin.math.sqrt
* Differentiable variable with value and derivative of differentiation ([deriv]) result * Differentiable variable with value and derivative of differentiation ([deriv]) result
* with respect to this variable. * with respect to this variable.
*/ */
data class ValueWithDeriv(var x: Double) { open class Variable(val x: Double) {
constructor(x: Number) : this(x.toDouble()) constructor(x: Number) : this(x.toDouble())
}
//TODO move set accessor inside AutoDiffField class DerivationResult(x: Double, val deriv: Map<Variable, Double>): Variable(x) {
var d: Double = 0.0 fun deriv(variable: Variable) = deriv[variable] ?: 0.0
internal set
} }
/** /**
@ -27,48 +27,53 @@ data class ValueWithDeriv(var x: Double) {
* *
* Example: * Example:
* ``` * ```
* val x = ValueWithDeriv(2) // define variable(s) and their values * val x = Variable(2) // define variable(s) and their values
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context * val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
* assertEquals(17.0, y.x) // the value of result (y) * assertEquals(17.0, y.x) // the value of result (y)
* assertEquals(9.0, x.d) // dy/dx * assertEquals(9.0, x.d) // dy/dx
* ``` * ```
*/ */
fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv = fun deriv(body: AutoDiffField.() -> Variable): DerivationResult =
AutoDiffFieldImpl().run { AutoDiffContext().run {
val result = body() val result = body()
result.d = 1.0 // computing derivative w.r.t result result.d = 1.0 // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
result DerivationResult(result.x, derivatives)
} }
abstract class AutoDiffField : Field<ValueWithDeriv> { abstract class AutoDiffField : Field<Variable> {
/** /**
* Performs update of derivative after the rest of the formula in the back-pass. * Performs update of derivative after the rest of the formula in the back-pass.
* *
* For example, implementation of `sin` function is: * For example, implementation of `sin` function is:
* *
* ``` * ```
* fun AD.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sin(x.x)) { z -> // call derive with function result * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
* } * }
* ``` * ```
*/ */
abstract fun <R> derive(value: R, block: (R) -> Unit): R abstract fun <R> derive(value: R, block: (R) -> Unit): R
/**
* A variable accessing inner state of derivatives. Use only in extensions
*/
abstract var Variable.d: Double
// Overloads for Double constants // Overloads for Double constants
operator fun Number.plus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() + that.x)) { z -> operator fun Number.plus(that: Variable): Variable = derive(Variable(this.toDouble() + that.x)) { z ->
that.d += z.d that.d += z.d
} }
operator fun ValueWithDeriv.plus(b: Number): ValueWithDeriv = b.plus(this) operator fun Variable.plus(b: Number): Variable = b.plus(this)
operator fun Number.minus(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() - that.x)) { z -> operator fun Number.minus(that: Variable): Variable = derive(Variable(this.toDouble() - that.x)) { z ->
that.d -= z.d that.d -= z.d
} }
operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z -> operator fun Variable.minus(that: Number): Variable = derive(Variable(this.x - that.toDouble())) { z ->
this.d += z.d this.d += z.d
} }
} }
@ -76,12 +81,19 @@ abstract class AutoDiffField : Field<ValueWithDeriv> {
/** /**
* Automatic Differentiation context class. * Automatic Differentiation context class.
*/ */
private class AutoDiffFieldImpl : AutoDiffField() { private class AutoDiffContext : AutoDiffField() {
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack = arrayOfNulls<Any?>(8) private var stack = arrayOfNulls<Any?>(8)
private var sp = 0 private var sp = 0
internal val derivatives = HashMap<Variable, Double>()
override var Variable.d: Double
get() = derivatives[this] ?: 0.0
set(value) {
derivatives[this] = value
}
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: (R) -> Unit): R { override fun <R> derive(value: R, block: (R) -> Unit): R {
@ -104,71 +116,71 @@ private class AutoDiffFieldImpl : AutoDiffField() {
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = override fun add(a: Variable, b: Variable): Variable =
derive(ValueWithDeriv(a.x + b.x)) { z -> derive(Variable(a.x + b.x)) { z ->
a.d += z.d a.d += z.d
b.d += z.d b.d += z.d
} }
override fun multiply(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = override fun multiply(a: Variable, b: Variable): Variable =
derive(ValueWithDeriv(a.x * b.x)) { z -> derive(Variable(a.x * b.x)) { z ->
a.d += z.d * b.x a.d += z.d * b.x
b.d += z.d * a.x b.d += z.d * a.x
} }
override fun divide(a: ValueWithDeriv, b: ValueWithDeriv): ValueWithDeriv = override fun divide(a: Variable, b: Variable): Variable =
derive(ValueWithDeriv(a.x / b.x)) { z -> derive(Variable(a.x / b.x)) { z ->
a.d += z.d / b.x a.d += z.d / b.x
b.d -= z.d * a.x / (b.x * b.x) b.d -= z.d * a.x / (b.x * b.x)
} }
override fun multiply(a: ValueWithDeriv, k: Number): ValueWithDeriv = override fun multiply(a: Variable, k: Number): Variable =
derive(ValueWithDeriv(k.toDouble() * a.x)) { z -> derive(Variable(k.toDouble() * a.x)) { z ->
a.d += z.d * k.toDouble() a.d += z.d * k.toDouble()
} }
override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0) override val zero: Variable get() = Variable(0.0)
override val one: ValueWithDeriv get() = ValueWithDeriv(1.0) override val one: Variable get() = Variable(1.0)
} }
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
// x ^ 2 // x ^ 2
fun AutoDiffField.sqr(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(x.x * x.x)) { z -> fun AutoDiffField.sqr(x: Variable): Variable = derive(Variable(x.x * x.x)) { z ->
x.d += z.d * 2 * x.x x.d += z.d * 2 * x.x
} }
// x ^ 1/2 // x ^ 1/2
fun AutoDiffField.sqrt(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(sqrt(x.x))) { z -> fun AutoDiffField.sqrt(x: Variable): Variable = derive(Variable(sqrt(x.x))) { z ->
x.d += z.d * 0.5 / z.x x.d += z.d * 0.5 / z.x
} }
// x ^ y (const) // x ^ y (const)
fun AutoDiffField.pow(x: ValueWithDeriv, y: Double): ValueWithDeriv = derive(ValueWithDeriv(x.x.pow(y))) { z -> fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(Variable(x.x.pow(y))) { z ->
x.d += z.d * y * x.x.pow(y - 1) x.d += z.d * y * x.x.pow(y - 1)
} }
fun AutoDiffField.pow(x: ValueWithDeriv, y: Int): ValueWithDeriv = pow(x, y.toDouble()) fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble())
// exp(x) // exp(x)
fun AutoDiffField.exp(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.exp(x.x))) { z -> fun AutoDiffField.exp(x: Variable): Variable = derive(Variable(kotlin.math.exp(x.x))) { z ->
x.d += z.d * z.x x.d += z.d * z.x
} }
// ln(x) // ln(x)
fun AutoDiffField.ln(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.ln(x.x))) { z -> fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.x))) { z ->
x.d += z.d / x.x x.d += z.d / x.x
} }
// x ^ y (any) // x ^ y (any)
fun AutoDiffField.pow(x: ValueWithDeriv, y: ValueWithDeriv): ValueWithDeriv = exp(y * ln(x)) fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x))
// sin(x) // sin(x)
fun AutoDiffField.sin(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.sin(x.x))) { z -> fun AutoDiffField.sin(x: Variable): Variable = derive(Variable(kotlin.math.sin(x.x))) { z ->
x.d += z.d * kotlin.math.cos(x.x) x.d += z.d * kotlin.math.cos(x.x)
} }
// cos(x) // cos(x)
fun AutoDiffField.cos(x: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(kotlin.math.cos(x.x))) { z -> fun AutoDiffField.cos(x: Variable): Variable = derive(Variable(kotlin.math.cos(x.x))) { z ->
x.d -= z.d * kotlin.math.sin(x.x) x.d -= z.d * kotlin.math.sin(x.x)
} }

View File

@ -7,145 +7,155 @@ import kotlin.test.assertEquals
class AutoDiffTest { class AutoDiffTest {
@Test @Test
fun testPlusX2() { fun testPlusX2() {
val x = ValueWithDeriv(3) // diff w.r.t this x at 3 val x = Variable(3) // diff w.r.t this x at 3
val y = deriv { x + x } val y = deriv { x + x }
assertEquals(6.0, y.x) // y = x + x = 6 assertEquals(6.0, y.x) // y = x + x = 6
assertEquals(2.0, x.d) // dy/dx = 2 assertEquals(2.0, y.deriv(x)) // dy/dx = 2
} }
@Test @Test
fun testPlus() { fun testPlus() {
// two variables // two variables
val x = ValueWithDeriv(2) val x = Variable(2)
val y = ValueWithDeriv(3) val y = Variable(3)
val z = deriv { x + y } val z = deriv { x + y }
assertEquals(5.0, z.x) // z = x + y = 5 assertEquals(5.0, z.x) // z = x + y = 5
assertEquals(1.0, x.d) // dz/dx = 1 assertEquals(1.0, z.deriv(x)) // dz/dx = 1
assertEquals(1.0, y.d) // dz/dy = 1 assertEquals(1.0, z.deriv(y)) // dz/dy = 1
} }
@Test @Test
fun testMinus() { fun testMinus() {
// two variables // two variables
val x = ValueWithDeriv(7) val x = Variable(7)
val y = ValueWithDeriv(3) val y = Variable(3)
val z = deriv { x - y } val z = deriv { x - y }
assertEquals(4.0, z.x) // z = x - y = 4 assertEquals(4.0, z.x) // z = x - y = 4
assertEquals(1.0, x.d) // dz/dx = 1 assertEquals(1.0, z.deriv(x)) // dz/dx = 1
assertEquals(-1.0, y.d) // dz/dy = -1 assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
} }
@Test @Test
fun testMulX2() { fun testMulX2() {
val x = ValueWithDeriv(3) // diff w.r.t this x at 3 val x = Variable(3) // diff w.r.t this x at 3
val y = deriv { x * x } val y = deriv { x * x }
assertEquals(9.0, y.x) // y = x * x = 9 assertEquals(9.0, y.x) // y = x * x = 9
assertEquals(6.0, x.d) // dy/dx = 2 * x = 7 assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
} }
@Test @Test
fun testSqr() { fun testSqr() {
val x = ValueWithDeriv(3) val x = Variable(3)
val y = deriv { sqr(x) } val y = deriv { sqr(x) }
assertEquals(9.0, y.x) // y = x ^ 2 = 9 assertEquals(9.0, y.x) // y = x ^ 2 = 9
assertEquals(6.0, x.d) // dy/dx = 2 * x = 7 assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
} }
@Test @Test
fun testSqrSqr() { fun testSqrSqr() {
val x = ValueWithDeriv(2) val x = Variable(2)
val y = deriv { sqr(sqr(x)) } val y = deriv { sqr(sqr(x)) }
assertEquals(16.0, y.x) // y = x ^ 4 = 16 assertEquals(16.0, y.x) // y = x ^ 4 = 16
assertEquals(32.0, x.d) // dy/dx = 4 * x^3 = 32 assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
} }
@Test @Test
fun testX3() { fun testX3() {
val x = ValueWithDeriv(2) // diff w.r.t this x at 2 val x = Variable(2) // diff w.r.t this x at 2
val y = deriv { x * x * x } val y = deriv { x * x * x }
assertEquals(8.0, y.x) // y = x * x * x = 8 assertEquals(8.0, y.x) // y = x * x * x = 8
assertEquals(12.0, x.d) // dy/dx = 3 * x * x = 12 assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
} }
@Test @Test
fun testDiv() { fun testDiv() {
val x = ValueWithDeriv(5) val x = Variable(5)
val y = ValueWithDeriv(2) val y = Variable(2)
val z = deriv { x / y } val z = deriv { x / y }
assertEquals(2.5, z.x) // z = x / y = 2.5 assertEquals(2.5, z.x) // z = x / y = 2.5
assertEquals(0.5, x.d) // dz/dx = 1 / y = 0.5 assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
assertEquals(-1.25, y.d) // dz/dy = -x / y^2 = -1.25 assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
} }
@Test @Test
fun testPow3() { fun testPow3() {
val x = ValueWithDeriv(2) // diff w.r.t this x at 2 val x = Variable(2) // diff w.r.t this x at 2
val y = deriv { pow(x, 3) } val y = deriv { pow(x, 3) }
assertEquals(8.0, y.x) // y = x ^ 3 = 8 assertEquals(8.0, y.x) // y = x ^ 3 = 8
assertEquals(12.0, x.d) // dy/dx = 3 * x ^ 2 = 12 assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
} }
@Test @Test
fun testPowFull() { fun testPowFull() {
val x = ValueWithDeriv(2) val x = Variable(2)
val y = ValueWithDeriv(3) val y = Variable(3)
val z = deriv { pow(x, y) } val z = deriv { pow(x, y) }
assertApprox(8.0, z.x) // z = x ^ y = 8 assertApprox(8.0, z.x) // z = x ^ y = 8
assertApprox(12.0, x.d) // dz/dx = y * x ^ (y - 1) = 12 assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
assertApprox(8.0 * kotlin.math.ln(2.0), y.d) // dz/dy = x ^ y * ln(x) assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
} }
@Test @Test
fun testFromPaper() { fun testFromPaper() {
val x = ValueWithDeriv(3) val x = Variable(3)
val y = deriv { 2 * x + x * x * x } val y = deriv { 2 * x + x * x * x }
assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33 assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33
assertEquals(29.0, x.d) // dy/dx = 2 + 3 * x * x = 29 assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
}
@Test
fun testInnerVariable() {
val x = Variable(1)
val y = deriv {
Variable(1) * x
}
assertEquals(1.0, y.x) // y = x ^ n = 1
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
} }
@Test @Test
fun testLongChain() { fun testLongChain() {
val n = 10_000 val n = 10_000
val x = ValueWithDeriv(1) val x = Variable(1)
val y = deriv { val y = deriv {
var pow = ValueWithDeriv(1) var res = Variable(1)
for (i in 1..n) pow *= x for (i in 1..n) res *= x
pow res
} }
assertEquals(1.0, y.x) // y = x ^ n = 1 assertEquals(1.0, y.x) // y = x ^ n = 1
assertEquals(n.toDouble(), x.d) // dy/dx = n * x ^ (n - 1) = n - 1 assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
} }
@Test @Test
fun testExample() { fun testExample() {
val x = ValueWithDeriv(2) val x = Variable(2)
val y = deriv { sqr(x) + 5 * x + 3 } val y = deriv { sqr(x) + 5 * x + 3 }
assertEquals(17.0, y.x) // the value of result (y) assertEquals(17.0, y.x) // the value of result (y)
assertEquals(9.0, x.d) // dy/dx assertEquals(9.0, y.deriv(x)) // dy/dx
} }
@Test @Test
fun testSqrt() { fun testSqrt() {
val x = ValueWithDeriv(16) val x = Variable(16)
val y = deriv { sqrt(x) } val y = deriv { sqrt(x) }
assertEquals(4.0, y.x) // y = x ^ 1/2 = 4 assertEquals(4.0, y.x) // y = x ^ 1/2 = 4
assertEquals(1.0 / 8, x.d) // dy/dx = 1/2 / x ^ 1/4 = 1/8 assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
} }
@Test @Test
fun testSin() { fun testSin() {
val x = ValueWithDeriv(PI / 6) val x = Variable(PI / 6)
val y = deriv { sin(x) } val y = deriv { sin(x) }
assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5 assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5
assertApprox(kotlin.math.sqrt(3.0) / 2, x.d) // dy/dx = cos(PI/6) = sqrt(3)/2 assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2
} }
@Test @Test
fun testCos() { fun testCos() {
val x = ValueWithDeriv(PI / 6) val x = Variable(PI / 6)
val y = deriv { cos(x) } val y = deriv { cos(x) }
assertApprox(kotlin.math.sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2 assertApprox(kotlin.math.sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2
assertApprox(-0.5, x.d) // dy/dx = -sin(PI/6) = -0.5 assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5
} }
private fun assertApprox(a: Double, b: Double) { private fun assertApprox(a: Double, b: Double) {