diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt index 4b0070545..ab007b976 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt @@ -12,12 +12,18 @@ import kotlin.math.sqrt * Differentiable variable with value and derivative of differentiation ([deriv]) result * with respect to this variable. */ -data class ValueWithDeriv(var x: Double, var d: Double = 0.0) { +data class ValueWithDeriv(var x: Double) { constructor(x: Number) : this(x.toDouble()) + + //TODO move set accessor inside AutoDiffField + var d: Double = 0.0 + internal set } /** - * Runs differentiation and establishes [Field] context inside the block of code. + * Runs differentiation and establishes [AutoDiffField] context inside the block of code. + * + * The partial derivatives are placed in argument `d` variable * * Example: * ``` @@ -28,7 +34,7 @@ data class ValueWithDeriv(var x: Double, var d: Double = 0.0) { * ``` */ fun deriv(body: AutoDiffField.() -> ValueWithDeriv): ValueWithDeriv = - ValueWithDerivField().run { + AutoDiffFieldImpl().run { val result = body() result.d = 1.0 // computing derivative w.r.t result runBackwardPass() @@ -65,26 +71,12 @@ abstract class AutoDiffField : Field { operator fun ValueWithDeriv.minus(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x - that.toDouble())) { z -> this.d += z.d } - - override operator fun Number.times(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() * that.x)) { z -> - that.d += z.d * this.toDouble() - } - - override operator fun ValueWithDeriv.times(b: Number): ValueWithDeriv = b.times(this) - - override operator fun Number.div(that: ValueWithDeriv): ValueWithDeriv = derive(ValueWithDeriv(this.toDouble() / that.x)) { z -> - that.d -= z.d * this.toDouble() / (that.x * that.x) - } - - override operator fun ValueWithDeriv.div(that: Number): ValueWithDeriv = derive(ValueWithDeriv(this.x / that.toDouble())) { z -> - this.d += z.d / that.toDouble() - } } /** * Automatic Differentiation context class. */ -private class ValueWithDerivField : AutoDiffField() { +private class AutoDiffFieldImpl : AutoDiffField() { // this stack contains pairs of blocks and values to apply them to private var stack = arrayOfNulls(8) @@ -135,8 +127,8 @@ private class ValueWithDerivField : AutoDiffField() { a.d += z.d * k.toDouble() } - override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0, 0.0) - override val one: ValueWithDeriv get() = ValueWithDeriv(1.0, 0.0) + override val zero: ValueWithDeriv get() = ValueWithDeriv(0.0) + override val one: ValueWithDeriv get() = ValueWithDeriv(1.0) } // Extensions for differentiation of various basic mathematical functions diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt index 6f7299e9e..736c301fd 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt @@ -119,7 +119,7 @@ class AutoDiffTest { @Test fun testExample() { val x = ValueWithDeriv(2) - val y = deriv { sqr(x) + 5 * x + 3*one } + val y = deriv { sqr(x) + 5 * x + 3 } assertEquals(17.0, y.x) // the value of result (y) assertEquals(9.0, x.d) // dy/dx }