MPP Univariate Autodiff from Roman Elizarov
This commit is contained in:
parent
d0a724771d
commit
765097cbbe
@ -12,12 +12,18 @@ import kotlin.math.sqrt
|
|||||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
*/
|
*/
|
||||||
data class ValueWithDeriv(var x: Double, var d: Double = 0.0) {
|
data class ValueWithDeriv(var x: Double) {
|
||||||
constructor(x: Number) : this(x.toDouble())
|
constructor(x: Number) : this(x.toDouble())
|
||||||
|
|
||||||
|
//TODO move set accessor inside AutoDiffField
|
||||||
|
var d: Double = 0.0
|
||||||
|
internal set
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Runs differentiation and establishes [Field<ValueWithDeriv>] context inside the block of code.
|
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||||
|
*
|
||||||
|
* The partial derivatives are placed in argument `d` variable
|
||||||
*
|
*
|
||||||
* Example:
|
* Example:
|
||||||
* ```
|
* ```
|
||||||
@ -28,7 +34,7 @@ data class ValueWithDeriv(var x: Double, var d: Double = 0.0) {
|
|||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv =
|
fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv =
|
||||||
ValueWithDerivField().run {
|
AutoDiffFieldImpl().run {
|
||||||
val result = body()
|
val result = body()
|
||||||
result.d = 1.0 // computing derivative w.r.t result
|
result.d = 1.0 // computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
@ -65,26 +71,12 @@ abstract class AutoDiffField : Field<ValueWithDeriv> {
|
|||||||
operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z ->
|
operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z ->
|
||||||
this.d += z.d
|
this.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Number.times(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() * that.x)) { z ->
|
|
||||||
that.d += z.d * this.toDouble()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun ValueWithDeriv.times(b: Number): ValueWithDeriv = b.times(this)
|
|
||||||
|
|
||||||
override operator fun Number.div(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() / that.x)) { z ->
|
|
||||||
that.d -= z.d * this.toDouble() / (that.x * that.x)
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun ValueWithDeriv.div(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x / that.toDouble())) { z ->
|
|
||||||
this.d += z.d / that.toDouble()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Automatic Differentiation context class.
|
* Automatic Differentiation context class.
|
||||||
*/
|
*/
|
||||||
private class ValueWithDerivField : AutoDiffField() {
|
private class AutoDiffFieldImpl : AutoDiffField() {
|
||||||
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
private var stack = arrayOfNulls<Any?>(8)
|
private var stack = arrayOfNulls<Any?>(8)
|
||||||
@ -135,8 +127,8 @@ private class ValueWithDerivField : AutoDiffField() {
|
|||||||
a.d += z.d * k.toDouble()
|
a.d += z.d * k.toDouble()
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0, 0.0)
|
override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0)
|
||||||
override val one: ValueWithDeriv get() = ValueWithDeriv(1.0, 0.0)
|
override val one: ValueWithDeriv get() = ValueWithDeriv(1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
@ -119,7 +119,7 @@ class AutoDiffTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testExample() {
|
fun testExample() {
|
||||||
val x = ValueWithDeriv(2)
|
val x = ValueWithDeriv(2)
|
||||||
val y = deriv { sqr(x) + 5 * x + 3*one }
|
val y = deriv { sqr(x) + 5 * x + 3 }
|
||||||
assertEquals(17.0, y.x) // the value of result (y)
|
assertEquals(17.0, y.x) // the value of result (y)
|
||||||
assertEquals(9.0, x.d) // dy/dx
|
assertEquals(9.0, x.d) // dy/dx
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user