diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt new file mode 100644 index 000000000..91256d114 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt @@ -0,0 +1,233 @@ +package scientifik.kmath.misc + +import scientifik.kmath.linear.Point +import scientifik.kmath.operations.ExtendedField +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.sum +import scientifik.kmath.structures.asBuffer + +/* + * Implementation of backward-mode automatic differentiation. + * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d + */ + +/** + * Differentiable variable with value and derivative of differentiation ([deriv]) result + * with respect to this variable. + */ +open class Variable(val value: T) + +class DerivationResult( + value: T, + val deriv: Map, T>, + val context: ExtendedField +) : Variable(value) { + fun deriv(variable: Variable) = deriv[variable] ?: context.zero + + /** + * compute divergence + */ + fun div() = context.run { sum(deriv.values) } + + /** + * Compute a gradient for variables in given order + */ + fun grad(vararg variables: Variable): Point = if (variables.isEmpty()) { + error("Variable order is not provided for gradient construction") + } else { + variables.map(::deriv).asBuffer() + } +} + +/** + * Runs differentiation and establishes [AutoDiffField] context inside the block of code. + * + * The partial derivatives are placed in argument `d` variable + * + * Example: + * ``` + * val x = Variable(2) // define variable(s) and their values + * val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context + * assertEquals(17.0, y.x) // the value of result (y) + * assertEquals(9.0, x.d) // dy/dx + * ``` + */ +fun ExtendedField.deriv(body: AutoDiffField.() -> Variable): DerivationResult = + AutoDiffContext(this).run { + val result = body() + result.d = context.one// computing derivative w.r.t result + runBackwardPass() + DerivationResult(result.value, derivatives, this@deriv) + } + + +abstract class AutoDiffField : Field> { + + abstract val context: ExtendedField + + /** + * Performs update of derivative after the rest of the formula in the back-pass. + * + * For example, implementation of `sin` function is: + * + * ``` + * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result + * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function + * } + * ``` + */ + abstract fun derive(value: R, block: ExtendedField.(R) -> Unit): R + + /** + * A variable accessing inner state of derivatives. + * Use this function in inner builders to avoid creating additional derivative bindings + */ + abstract var Variable.d: T + + abstract fun variable(value: T): Variable + + inline fun variable(block: ExtendedField.() -> T) = variable(context.block()) + + // Overloads for Double constants + + operator fun Number.plus(that: Variable): Variable = + derive(variable { this@plus.toDouble() * one + that.value }) { z -> + that.d += z.d + } + + operator fun Variable.plus(b: Number): Variable = b.plus(this) + + operator fun Number.minus(that: Variable): Variable = + derive(variable { this@minus.toDouble() * one - that.value }) { z -> + that.d -= z.d + } + + operator fun Variable.minus(that: Number): Variable = + derive(variable { this@minus.value - one * that.toDouble() }) { z -> + this@minus.d += z.d + } +} + +/** + * Automatic Differentiation context class. + */ +private class AutoDiffContext(override val context: ExtendedField) : AutoDiffField() { + + // this stack contains pairs of blocks and values to apply them to + private var stack = arrayOfNulls(8) + private var sp = 0 + + internal val derivatives = HashMap, T>() + + + /** + * A variable coupled with its derivative. For internal use only + */ + inner class VariableWithDeriv(x: T, var d: T = context.zero) : Variable(x) + + + override fun variable(value: T): Variable = + VariableWithDeriv(value) + + override var Variable.d: T + get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero + set(value) { + if (this is VariableWithDeriv) { + d = value + } else { + derivatives[this] = value + } + } + + @Suppress("UNCHECKED_CAST") + override fun derive(value: R, block: ExtendedField.(R) -> Unit): R { + // save block to stack for backward pass + if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) + stack[sp++] = block + stack[sp++] = value + return value + } + + @Suppress("UNCHECKED_CAST") + fun runBackwardPass() { + while (sp > 0) { + val value = stack[--sp] + val block = stack[--sp] as ExtendedField.(Any?) -> Unit + context.block(value) + } + } + + // Basic math (+, -, *, /) + + + override fun add(a: Variable, b: Variable): Variable = + derive(variable { a.value + b.value }) { z -> + a.d += z.d + b.d += z.d + } + + override fun multiply(a: Variable, b: Variable): Variable = + derive(variable { a.value * b.value }) { z -> + a.d += z.d * b.value + b.d += z.d * a.value + } + + override fun divide(a: Variable, b: Variable): Variable = + derive(variable { a.value / b.value }) { z -> + a.d += z.d / b.value + b.d -= z.d * a.value / (b.value * b.value) + } + + override fun multiply(a: Variable, k: Number): Variable = + derive(variable { k.toDouble() * a.value }) { z -> + a.d += z.d * k.toDouble() + } + + override val zero: Variable get() = Variable(context.zero) + override val one: Variable get() = Variable(context.one) +} + +// Extensions for differentiation of various basic mathematical functions + +// x ^ 2 +fun AutoDiffField.sqr(x: Variable): Variable = derive(variable { x.value * x.value }) { z -> + x.d += z.d * 2 * x.value +} + +// x ^ 1/2 +fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable { sqrt(x.value) }) { z -> + x.d += z.d * 0.5 / z.value +} + +// x ^ y (const) +fun AutoDiffField.pow(x: Variable, y: Double): Variable = + derive(variable { power(x.value, y) }) { z -> + x.d += z.d * y * power(x.value, y - 1) + } + +fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) + +// exp(x) +fun AutoDiffField.exp(x: Variable): Variable = derive(variable { exp(x.value) }) { z -> + x.d += z.d * z.value +} + +// ln(x) +fun AutoDiffField.ln(x: Variable): Variable = derive( + variable { ln(x.value) } +) { z -> + x.d += z.d / x.value +} + +// x ^ y (any) +fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x)) + +// sin(x) +fun AutoDiffField.sin(x: Variable): Variable = derive(variable { sin(x.value) }) { z -> + x.d += z.d * cos(x.value) +} + +// cos(x) +fun AutoDiffField.cos(x: Variable): Variable = derive(variable { cos(x.value) }) { z -> + x.d -= z.d * sin(x.value) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt deleted file mode 100644 index 9714630c1..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/AutoDiff.kt +++ /dev/null @@ -1,218 +0,0 @@ -package scientifik.kmath.operations - -import scientifik.kmath.linear.Point -import scientifik.kmath.structures.asBuffer -import kotlin.math.pow -import kotlin.math.sqrt - -/* - * Implementation of backward-mode automatic differentiation. - * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d - */ - -/** - * Differentiable variable with value and derivative of differentiation ([deriv]) result - * with respect to this variable. - */ -open class Variable(val value: Double) { - constructor(x: Number) : this(x.toDouble()) -} - -class DerivationResult(value: Double, val deriv: Map) : Variable(value) { - fun deriv(variable: Variable) = deriv[variable] ?: 0.0 - - /** - * compute divergence - */ - fun div() = deriv.values.sum() - - /** - * Compute a gradient for variables in given order - */ - fun grad(vararg variables: Variable): Point = if (variables.isEmpty()) { - error("Variable order is not provided for gradient construction") - } else { - variables.map(::deriv).toDoubleArray().asBuffer() - } -} - -/** - * Runs differentiation and establishes [AutoDiffField] context inside the block of code. - * - * The partial derivatives are placed in argument `d` variable - * - * Example: - * ``` - * val x = Variable(2) // define variable(s) and their values - * val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context - * assertEquals(17.0, y.x) // the value of result (y) - * assertEquals(9.0, x.d) // dy/dx - * ``` - */ -fun deriv(body: AutoDiffField.() -> Variable): DerivationResult = - AutoDiffContext().run { - val result = body() - result.d = 1.0 // computing derivative w.r.t result - runBackwardPass() - DerivationResult(result.value, derivatives) - } - - -abstract class AutoDiffField : Field { - /** - * Performs update of derivative after the rest of the formula in the back-pass. - * - * For example, implementation of `sin` function is: - * - * ``` - * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result - * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function - * } - * ``` - */ - abstract fun derive(value: R, block: (R) -> Unit): R - - /** - * A variable accessing inner state of derivatives. - * Use this function in inner builders to avoid creating additional derivative bindings - */ - abstract var Variable.d: Double - - abstract fun variable(value: Double): Variable - - // Overloads for Double constants - - operator fun Number.plus(that: Variable): Variable = derive(variable(this.toDouble() + that.value)) { z -> - that.d += z.d - } - - operator fun Variable.plus(b: Number): Variable = b.plus(this) - - operator fun Number.minus(that: Variable): Variable = derive(variable(this.toDouble() - that.value)) { z -> - that.d -= z.d - } - - operator fun Variable.minus(that: Number): Variable = derive(variable(this.value - that.toDouble())) { z -> - this.d += z.d - } -} - -/** - * Automatic Differentiation context class. - */ -private class AutoDiffContext : AutoDiffField() { - - // this stack contains pairs of blocks and values to apply them to - private var stack = arrayOfNulls(8) - private var sp = 0 - - internal val derivatives = HashMap() - - - /** - * A variable coupled with its derivative. For internal use only - */ - class VariableWithDeriv(x: Double, var d: Double = 0.0) : Variable(x) - - - override fun variable(value: Double): Variable = VariableWithDeriv(value) - - override var Variable.d: Double - get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: 0.0 - set(value) { - if (this is VariableWithDeriv) { - d = value - } else { - derivatives[this] = value - } - } - - @Suppress("UNCHECKED_CAST") - override fun derive(value: R, block: (R) -> Unit): R { - // save block to stack for backward pass - if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) - stack[sp++] = block - stack[sp++] = value - return value - } - - @Suppress("UNCHECKED_CAST") - fun runBackwardPass() { - while (sp > 0) { - val value = stack[--sp] - val block = stack[--sp] as (Any?) -> Unit - block(value) - } - } - - // Basic math (+, -, *, /) - - - override fun add(a: Variable, b: Variable): Variable = - derive(variable(a.value + b.value)) { z -> - a.d += z.d - b.d += z.d - } - - override fun multiply(a: Variable, b: Variable): Variable = - derive(variable(a.value * b.value)) { z -> - a.d += z.d * b.value - b.d += z.d * a.value - } - - override fun divide(a: Variable, b: Variable): Variable = - derive(Variable(a.value / b.value)) { z -> - a.d += z.d / b.value - b.d -= z.d * a.value / (b.value * b.value) - } - - override fun multiply(a: Variable, k: Number): Variable = - derive(variable(k.toDouble() * a.value)) { z -> - a.d += z.d * k.toDouble() - } - - override val zero: Variable get() = Variable(0.0) - override val one: Variable get() = Variable(1.0) -} - -// Extensions for differentiation of various basic mathematical functions - -// x ^ 2 -fun AutoDiffField.sqr(x: Variable): Variable = derive(variable(x.value * x.value)) { z -> - x.d += z.d * 2 * x.value -} - -// x ^ 1/2 -fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable(sqrt(x.value))) { z -> - x.d += z.d * 0.5 / z.value -} - -// x ^ y (const) -fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(variable(x.value.pow(y))) { z -> - x.d += z.d * y * x.value.pow(y - 1) -} - -fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) - -// exp(x) -fun AutoDiffField.exp(x: Variable): Variable = derive(variable(kotlin.math.exp(x.value))) { z -> - x.d += z.d * z.value -} - -// ln(x) -fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.value))) { z -> - x.d += z.d / x.value -} - -// x ^ y (any) -fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x)) - -// sin(x) -fun AutoDiffField.sin(x: Variable): Variable = derive(variable(kotlin.math.sin(x.value))) { z -> - x.d += z.d * kotlin.math.cos(x.value) -} - -// cos(x) -fun AutoDiffField.cos(x: Variable): Variable = derive(variable(kotlin.math.cos(x.value))) { z -> - x.d -= z.d * kotlin.math.sin(x.value) -} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 83fe3a4d9..a8c7f99fa 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -33,7 +33,7 @@ inline class Real(val value: Double) : FieldElement { * A field for double without boxing. Does not produce appropriate field element */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -object RealField : Field, ExtendedFieldOperations, Norm { +object RealField : ExtendedField, Norm { override val zero: Double = 0.0 override inline fun add(a: Double, b: Double) = a + b override inline fun multiply(a: Double, b: Double) = a * b @@ -64,7 +64,7 @@ object RealField : Field, ExtendedFieldOperations, Norm, ExtendedFieldOperations, Norm { +object FloatField : ExtendedField, Norm { override val zero: Float = 0f override inline fun add(a: Float, b: Float) = a + b override inline fun multiply(a: Float, b: Float) = a * b diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index 6aaa2da8f..a38f09c8d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -123,7 +123,7 @@ inline class ListBuffer(val list: List) : Buffer { override fun iterator(): Iterator = list.iterator() } -fun List.asBuffer() = ListBuffer(this) +fun List.asBuffer() = ListBuffer(this) @Suppress("FunctionName") inline fun ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer() diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt similarity index 95% rename from kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt rename to kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt index e4fad888e..9a2597737 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/AutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt @@ -1,5 +1,6 @@ -package scientifik.kmath.operations +package scientifik.kmath.misc +import scientifik.kmath.operations.RealField import scientifik.kmath.structures.asBuffer import kotlin.math.PI import kotlin.test.Test @@ -7,6 +8,11 @@ import kotlin.test.assertEquals import kotlin.test.assertTrue class AutoDiffTest { + + fun Variable(int: Int) = Variable(int.toDouble()) + + fun deriv(body: AutoDiffField.() -> Variable) = RealField.deriv(body) + @Test fun testPlusX2() { val x = Variable(3) // diff w.r.t this x at 3