diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt index 454d386b0..cd50815e3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt @@ -1,5 +1,7 @@ package scientifik.kmath.operations +import scientifik.kmath.linear.Point +import scientifik.kmath.structures.asBuffer import kotlin.math.pow import kotlin.math.sqrt @@ -12,12 +14,26 @@ import kotlin.math.sqrt * Differentiable variable with value and derivative of differentiation ([deriv]) result * with respect to this variable. */ -open class Variable(val x: Double) { +open class Variable(val value: Double) { constructor(x: Number) : this(x.toDouble()) } -class DerivationResult(x: Double, val deriv: Map): Variable(x) { +class DerivationResult(value: Double, val deriv: Map) : Variable(value) { fun deriv(variable: Variable) = deriv[variable] ?: 0.0 + + /** + * compute divergence + */ + fun div() = deriv.values.sum() + + /** + * Compute a gradient for variables in given order + */ + fun grad(vararg variables: Variable): Point = if (variables.isEmpty()) { + error("Variable order is not provided for gradient construction") + } else { + variables.map(::deriv).toDoubleArray().asBuffer() + } } /** @@ -38,7 +54,7 @@ fun deriv(body: AutoDiffField.() -> Variable): DerivationResult = val result = body() result.d = 1.0 // computing derivative w.r.t result runBackwardPass() - DerivationResult(result.x, derivatives) + DerivationResult(result.value, derivatives) } @@ -61,19 +77,21 @@ abstract class AutoDiffField : Field { */ abstract var Variable.d: Double + abstract fun variable(value: Double): Variable + // Overloads for Double constants - operator fun Number.plus(that: Variable): Variable = derive(Variable(this.toDouble() + that.x)) { z -> + operator fun Number.plus(that: Variable): Variable = derive(variable(this.toDouble() + that.value)) { z -> that.d += z.d } operator fun Variable.plus(b: Number): Variable = b.plus(this) - operator fun Number.minus(that: Variable): Variable = derive(Variable(this.toDouble() - that.x)) { z -> + operator fun Number.minus(that: Variable): Variable = derive(variable(this.toDouble() - that.value)) { z -> that.d -= z.d } - operator fun Variable.minus(that: Number): Variable = derive(Variable(this.x - that.toDouble())) { z -> + operator fun Variable.minus(that: Number): Variable = derive(variable(this.value - that.toDouble())) { z -> this.d += z.d } } @@ -89,10 +107,22 @@ private class AutoDiffContext : AutoDiffField() { internal val derivatives = HashMap() + + /** + * A variable coupled with its derivative. For internal use only + */ + class VariableWithDeriv(x: Double, var d: Double = 0.0): Variable(x) + + override fun variable(value: Double): Variable = VariableWithDeriv(value) + override var Variable.d: Double - get() = derivatives[this] ?: 0.0 + get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: 0.0 set(value) { - derivatives[this] = value + if(this is VariableWithDeriv){ + d = value + }else { + derivatives[this] = value + } } @Suppress("UNCHECKED_CAST") @@ -117,25 +147,25 @@ private class AutoDiffContext : AutoDiffField() { override fun add(a: Variable, b: Variable): Variable = - derive(Variable(a.x + b.x)) { z -> + derive(variable(a.value + b.value)) { z -> a.d += z.d b.d += z.d } override fun multiply(a: Variable, b: Variable): Variable = - derive(Variable(a.x * b.x)) { z -> - a.d += z.d * b.x - b.d += z.d * a.x + derive(variable(a.value * b.value)) { z -> + a.d += z.d * b.value + b.d += z.d * a.value } override fun divide(a: Variable, b: Variable): Variable = - derive(Variable(a.x / b.x)) { z -> - a.d += z.d / b.x - b.d -= z.d * a.x / (b.x * b.x) + derive(Variable(a.value / b.value)) { z -> + a.d += z.d / b.value + b.d -= z.d * a.value / (b.value * b.value) } override fun multiply(a: Variable, k: Number): Variable = - derive(Variable(k.toDouble() * a.x)) { z -> + derive(variable(k.toDouble() * a.value)) { z -> a.d += z.d * k.toDouble() } @@ -146,41 +176,41 @@ private class AutoDiffContext : AutoDiffField() { // Extensions for differentiation of various basic mathematical functions // x ^ 2 -fun AutoDiffField.sqr(x: Variable): Variable = derive(Variable(x.x * x.x)) { z -> - x.d += z.d * 2 * x.x +fun AutoDiffField.sqr(x: Variable): Variable = derive(variable(x.value * x.value)) { z -> + x.d += z.d * 2 * x.value } // x ^ 1/2 -fun AutoDiffField.sqrt(x: Variable): Variable = derive(Variable(sqrt(x.x))) { z -> - x.d += z.d * 0.5 / z.x +fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable(sqrt(x.value))) { z -> + x.d += z.d * 0.5 / z.value } // x ^ y (const) -fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(Variable(x.x.pow(y))) { z -> - x.d += z.d * y * x.x.pow(y - 1) +fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(variable(x.value.pow(y))) { z -> + x.d += z.d * y * x.value.pow(y - 1) } fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) // exp(x) -fun AutoDiffField.exp(x: Variable): Variable = derive(Variable(kotlin.math.exp(x.x))) { z -> - x.d += z.d * z.x +fun AutoDiffField.exp(x: Variable): Variable = derive(variable(kotlin.math.exp(x.value))) { z -> + x.d += z.d * z.value } // ln(x) -fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.x))) { z -> - x.d += z.d / x.x +fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.value))) { z -> + x.d += z.d / x.value } // x ^ y (any) fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x)) // sin(x) -fun AutoDiffField.sin(x: Variable): Variable = derive(Variable(kotlin.math.sin(x.x))) { z -> - x.d += z.d * kotlin.math.cos(x.x) +fun AutoDiffField.sin(x: Variable): Variable = derive(variable(kotlin.math.sin(x.value))) { z -> + x.d += z.d * kotlin.math.cos(x.value) } // cos(x) -fun AutoDiffField.cos(x: Variable): Variable = derive(Variable(kotlin.math.cos(x.x))) { z -> - x.d -= z.d * kotlin.math.sin(x.x) +fun AutoDiffField.cos(x: Variable): Variable = derive(variable(kotlin.math.cos(x.value))) { z -> + x.d -= z.d * kotlin.math.sin(x.value) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index e521df86e..6aaa2da8f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -37,6 +37,11 @@ interface Buffer { companion object { + inline fun real(size: Int, initializer: (Int) -> Double): DoubleBuffer { + val array = DoubleArray(size) { initializer(it) } + return DoubleBuffer(array) + } + /** * Create a boxing buffer of given type */ diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt index 9adc2f977..e4fad888e 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt @@ -1,15 +1,17 @@ package scientifik.kmath.operations +import scientifik.kmath.structures.asBuffer import kotlin.math.PI import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertTrue class AutoDiffTest { @Test fun testPlusX2() { val x = Variable(3) // diff w.r.t this x at 3 val y = deriv { x + x } - assertEquals(6.0, y.x) // y = x + x = 6 + assertEquals(6.0, y.value) // y = x + x = 6 assertEquals(2.0, y.deriv(x)) // dy/dx = 2 } @@ -19,7 +21,7 @@ class AutoDiffTest { val x = Variable(2) val y = Variable(3) val z = deriv { x + y } - assertEquals(5.0, z.x) // z = x + y = 5 + assertEquals(5.0, z.value) // z = x + y = 5 assertEquals(1.0, z.deriv(x)) // dz/dx = 1 assertEquals(1.0, z.deriv(y)) // dz/dy = 1 } @@ -30,7 +32,7 @@ class AutoDiffTest { val x = Variable(7) val y = Variable(3) val z = deriv { x - y } - assertEquals(4.0, z.x) // z = x - y = 4 + assertEquals(4.0, z.value) // z = x - y = 4 assertEquals(1.0, z.deriv(x)) // dz/dx = 1 assertEquals(-1.0, z.deriv(y)) // dz/dy = -1 } @@ -39,7 +41,7 @@ class AutoDiffTest { fun testMulX2() { val x = Variable(3) // diff w.r.t this x at 3 val y = deriv { x * x } - assertEquals(9.0, y.x) // y = x * x = 9 + assertEquals(9.0, y.value) // y = x * x = 9 assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7 } @@ -47,7 +49,7 @@ class AutoDiffTest { fun testSqr() { val x = Variable(3) val y = deriv { sqr(x) } - assertEquals(9.0, y.x) // y = x ^ 2 = 9 + assertEquals(9.0, y.value) // y = x ^ 2 = 9 assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7 } @@ -55,7 +57,7 @@ class AutoDiffTest { fun testSqrSqr() { val x = Variable(2) val y = deriv { sqr(sqr(x)) } - assertEquals(16.0, y.x) // y = x ^ 4 = 16 + assertEquals(16.0, y.value) // y = x ^ 4 = 16 assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32 } @@ -63,7 +65,7 @@ class AutoDiffTest { fun testX3() { val x = Variable(2) // diff w.r.t this x at 2 val y = deriv { x * x * x } - assertEquals(8.0, y.x) // y = x * x * x = 8 + assertEquals(8.0, y.value) // y = x * x * x = 8 assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12 } @@ -72,7 +74,7 @@ class AutoDiffTest { val x = Variable(5) val y = Variable(2) val z = deriv { x / y } - assertEquals(2.5, z.x) // z = x / y = 2.5 + assertEquals(2.5, z.value) // z = x / y = 2.5 assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5 assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25 } @@ -81,7 +83,7 @@ class AutoDiffTest { fun testPow3() { val x = Variable(2) // diff w.r.t this x at 2 val y = deriv { pow(x, 3) } - assertEquals(8.0, y.x) // y = x ^ 3 = 8 + assertEquals(8.0, y.value) // y = x ^ 3 = 8 assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12 } @@ -90,7 +92,7 @@ class AutoDiffTest { val x = Variable(2) val y = Variable(3) val z = deriv { pow(x, y) } - assertApprox(8.0, z.x) // z = x ^ y = 8 + assertApprox(8.0, z.value) // z = x ^ y = 8 assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12 assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x) } @@ -99,7 +101,7 @@ class AutoDiffTest { fun testFromPaper() { val x = Variable(3) val y = deriv { 2 * x + x * x * x } - assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33 + assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33 assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29 } @@ -109,7 +111,7 @@ class AutoDiffTest { val y = deriv { Variable(1) * x } - assertEquals(1.0, y.x) // y = x ^ n = 1 + assertEquals(1.0, y.value) // y = x ^ n = 1 assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1 } @@ -122,7 +124,7 @@ class AutoDiffTest { for (i in 1..n) res *= x res } - assertEquals(1.0, y.x) // y = x ^ n = 1 + assertEquals(1.0, y.value) // y = x ^ n = 1 assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1 } @@ -130,7 +132,7 @@ class AutoDiffTest { fun testExample() { val x = Variable(2) val y = deriv { sqr(x) + 5 * x + 3 } - assertEquals(17.0, y.x) // the value of result (y) + assertEquals(17.0, y.value) // the value of result (y) assertEquals(9.0, y.deriv(x)) // dy/dx } @@ -138,7 +140,7 @@ class AutoDiffTest { fun testSqrt() { val x = Variable(16) val y = deriv { sqrt(x) } - assertEquals(4.0, y.x) // y = x ^ 1/2 = 4 + assertEquals(4.0, y.value) // y = x ^ 1/2 = 4 assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8 } @@ -146,7 +148,7 @@ class AutoDiffTest { fun testSin() { val x = Variable(PI / 6) val y = deriv { sin(x) } - assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5 + assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5 assertApprox(kotlin.math.sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(PI/6) = sqrt(3)/2 } @@ -154,10 +156,19 @@ class AutoDiffTest { fun testCos() { val x = Variable(PI / 6) val y = deriv { cos(x) } - assertApprox(kotlin.math.sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2 + assertApprox(kotlin.math.sqrt(3.0) / 2, y.value) // y = cos(PI/6) = sqrt(3)/2 assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(PI/6) = -0.5 } + @Test + fun testDivGrad() { + val x = Variable(1.0) + val y = Variable(2.0) + val res = deriv { x * x + y * y } + assertEquals(6.0, res.div()) + assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer())) + } + private fun assertApprox(a: Double, b: Double) { if ((a - b) > 1e-10) assertEquals(a, b) }