Merge dev into master #59

Merged
altavir merged 37 commits from dev into master 2019-05-31 12:35:58 +03:00
2 changed files with 13 additions and 21 deletions
Showing only changes of commit 765097cbbe - Show all commits

View File

@ -12,12 +12,18 @@ import kotlin.math.sqrt
* Differentiable variable with value and derivative of differentiation ([deriv]) result * Differentiable variable with value and derivative of differentiation ([deriv]) result
* with respect to this variable. * with respect to this variable.
*/ */
data class ValueWithDeriv(var x: Double, var d: Double = 0.0) { data class ValueWithDeriv(var x: Double) {
constructor(x: Number) : this(x.toDouble()) constructor(x: Number) : this(x.toDouble())
//TODO move set accessor inside AutoDiffField
var d: Double = 0.0
internal set
} }
/** /**
* Runs differentiation and establishes [Field<ValueWithDeriv>] context inside the block of code. * Runs differentiation and establishes [AutoDiffField] context inside the block of code.
*
* The partial derivatives are placed in argument `d` variable
* *
* Example: * Example:
* ``` * ```
@ -28,7 +34,7 @@ data class ValueWithDeriv(var x: Double, var d: Double = 0.0) {
* ``` * ```
*/ */
fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv = fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv =
ValueWithDerivField().run { AutoDiffFieldImpl().run {
val result = body() val result = body()
result.d = 1.0 // computing derivative w.r.t result result.d = 1.0 // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
@ -65,26 +71,12 @@ abstract class AutoDiffField : Field<ValueWithDeriv> {
operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z -> operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z ->
this.d += z.d this.d += z.d
} }
override operator fun Number.times(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() * that.x)) { z ->
that.d += z.d * this.toDouble()
}
override operator fun ValueWithDeriv.times(b: Number): ValueWithDeriv = b.times(this)
override operator fun Number.div(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() / that.x)) { z ->
that.d -= z.d * this.toDouble() / (that.x * that.x)
}
override operator fun ValueWithDeriv.div(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x / that.toDouble())) { z ->
this.d += z.d / that.toDouble()
}
} }
/** /**
* Automatic Differentiation context class. * Automatic Differentiation context class.
*/ */
private class ValueWithDerivField : AutoDiffField() { private class AutoDiffFieldImpl : AutoDiffField() {
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack = arrayOfNulls<Any?>(8) private var stack = arrayOfNulls<Any?>(8)
@ -135,8 +127,8 @@ private class ValueWithDerivField : AutoDiffField() {
a.d += z.d * k.toDouble() a.d += z.d * k.toDouble()
} }
override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0, 0.0) override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0)
override val one: ValueWithDeriv get() = ValueWithDeriv(1.0, 0.0) override val one: ValueWithDeriv get() = ValueWithDeriv(1.0)
} }
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions

View File

@ -119,7 +119,7 @@ class AutoDiffTest {
@Test @Test
fun testExample() { fun testExample() {
val x = ValueWithDeriv(2) val x = ValueWithDeriv(2)
val y = deriv { sqr(x) + 5 * x + 3*one } val y = deriv { sqr(x) + 5 * x + 3 }
assertEquals(17.0, y.x) // the value of result (y) assertEquals(17.0, y.x) // the value of result (y)
assertEquals(9.0, x.d) // dy/dx assertEquals(9.0, x.d) // dy/dx
} }