forked from kscience/kmath
Generalized Elizarov's algorithm for ExtendedField
This commit is contained in:
parent
fcc6269ee8
commit
cee39e0666
@ -0,0 +1,233 @@
|
||||
package scientifik.kmath.misc
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.sum
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
|
||||
/*
|
||||
* Implementation of backward-mode automatic differentiation.
|
||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||
*/
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||
* with respect to this variable.
|
||||
*/
|
||||
open class Variable<T : Any>(val value: T)
|
||||
|
||||
class DerivationResult<T : Any>(
|
||||
value: T,
|
||||
val deriv: Map<Variable<T>, T>,
|
||||
val context: ExtendedField<T>
|
||||
) : Variable<T>(value) {
|
||||
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero
|
||||
|
||||
/**
|
||||
* compute divergence
|
||||
*/
|
||||
fun div() = context.run { sum(deriv.values) }
|
||||
|
||||
/**
|
||||
* Compute a gradient for variables in given order
|
||||
*/
|
||||
fun grad(vararg variables: Variable<T>): Point<T> = if (variables.isEmpty()) {
|
||||
error("Variable order is not provided for gradient construction")
|
||||
} else {
|
||||
variables.map(::deriv).asBuffer()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||
*
|
||||
* The partial derivatives are placed in argument `d` variable
|
||||
*
|
||||
* Example:
|
||||
* ```
|
||||
* val x = Variable(2) // define variable(s) and their values
|
||||
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||
* assertEquals(17.0, y.x) // the value of result (y)
|
||||
* assertEquals(9.0, x.d) // dy/dx
|
||||
* ```
|
||||
*/
|
||||
fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>): DerivationResult<T> =
|
||||
AutoDiffContext<T>(this).run {
|
||||
val result = body()
|
||||
result.d = context.one// computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
DerivationResult(result.value, derivatives, this@deriv)
|
||||
}
|
||||
|
||||
|
||||
abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
||||
|
||||
abstract val context: ExtendedField<T>
|
||||
|
||||
/**
|
||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||
*
|
||||
* For example, implementation of `sin` function is:
|
||||
*
|
||||
* ```
|
||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
abstract fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R
|
||||
|
||||
/**
|
||||
* A variable accessing inner state of derivatives.
|
||||
* Use this function in inner builders to avoid creating additional derivative bindings
|
||||
*/
|
||||
abstract var Variable<T>.d: T
|
||||
|
||||
abstract fun variable(value: T): Variable<T>
|
||||
|
||||
inline fun variable(block: ExtendedField<T>.() -> T) = variable(context.block())
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
operator fun Number.plus(that: Variable<T>): Variable<T> =
|
||||
derive(variable { this@plus.toDouble() * one + that.value }) { z ->
|
||||
that.d += z.d
|
||||
}
|
||||
|
||||
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||
|
||||
operator fun Number.minus(that: Variable<T>): Variable<T> =
|
||||
derive(variable { this@minus.toDouble() * one - that.value }) { z ->
|
||||
that.d -= z.d
|
||||
}
|
||||
|
||||
operator fun Variable<T>.minus(that: Number): Variable<T> =
|
||||
derive(variable { this@minus.value - one * that.toDouble() }) { z ->
|
||||
this@minus.d += z.d
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Automatic Differentiation context class.
|
||||
*/
|
||||
private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) : AutoDiffField<T>() {
|
||||
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack = arrayOfNulls<Any?>(8)
|
||||
private var sp = 0
|
||||
|
||||
internal val derivatives = HashMap<Variable<T>, T>()
|
||||
|
||||
|
||||
/**
|
||||
* A variable coupled with its derivative. For internal use only
|
||||
*/
|
||||
inner class VariableWithDeriv(x: T, var d: T = context.zero) : Variable<T>(x)
|
||||
|
||||
|
||||
override fun variable(value: T): Variable<T> =
|
||||
VariableWithDeriv(value)
|
||||
|
||||
override var Variable<T>.d: T
|
||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
||||
set(value) {
|
||||
if (this is VariableWithDeriv) {
|
||||
d = value
|
||||
} else {
|
||||
derivatives[this] = value
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
stack[sp++] = value
|
||||
return value
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as ExtendedField<T>.(Any?) -> Unit
|
||||
context.block(value)
|
||||
}
|
||||
}
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
|
||||
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> =
|
||||
derive(variable { a.value + b.value }) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> =
|
||||
derive(variable { a.value * b.value }) { z ->
|
||||
a.d += z.d * b.value
|
||||
b.d += z.d * a.value
|
||||
}
|
||||
|
||||
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> =
|
||||
derive(variable { a.value / b.value }) { z ->
|
||||
a.d += z.d / b.value
|
||||
b.d -= z.d * a.value / (b.value * b.value)
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable<T>, k: Number): Variable<T> =
|
||||
derive(variable { k.toDouble() * a.value }) { z ->
|
||||
a.d += z.d * k.toDouble()
|
||||
}
|
||||
|
||||
override val zero: Variable<T> get() = Variable(context.zero)
|
||||
override val one: Variable<T> get() = Variable(context.one)
|
||||
}
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
fun <T : Any> AutoDiffField<T>.sqr(x: Variable<T>): Variable<T> = derive(variable { x.value * x.value }) { z ->
|
||||
x.d += z.d * 2 * x.value
|
||||
}
|
||||
|
||||
// x ^ 1/2
|
||||
fun <T : Any> AutoDiffField<T>.sqrt(x: Variable<T>): Variable<T> = derive(variable { sqrt(x.value) }) { z ->
|
||||
x.d += z.d * 0.5 / z.value
|
||||
}
|
||||
|
||||
// x ^ y (const)
|
||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Double): Variable<T> =
|
||||
derive(variable { power(x.value, y) }) { z ->
|
||||
x.d += z.d * y * power(x.value, y - 1)
|
||||
}
|
||||
|
||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
fun <T : Any> AutoDiffField<T>.exp(x: Variable<T>): Variable<T> = derive(variable { exp(x.value) }) { z ->
|
||||
x.d += z.d * z.value
|
||||
}
|
||||
|
||||
// ln(x)
|
||||
fun <T : Any> AutoDiffField<T>.ln(x: Variable<T>): Variable<T> = derive(
|
||||
variable { ln(x.value) }
|
||||
) { z ->
|
||||
x.d += z.d / x.value
|
||||
}
|
||||
|
||||
// x ^ y (any)
|
||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Variable<T>): Variable<T> = exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
fun <T : Any> AutoDiffField<T>.sin(x: Variable<T>): Variable<T> = derive(variable { sin(x.value) }) { z ->
|
||||
x.d += z.d * cos(x.value)
|
||||
}
|
||||
|
||||
// cos(x)
|
||||
fun <T : Any> AutoDiffField<T>.cos(x: Variable<T>): Variable<T> = derive(variable { cos(x.value) }) { z ->
|
||||
x.d -= z.d * sin(x.value)
|
||||
}
|
@ -1,218 +0,0 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/*
|
||||
* Implementation of backward-mode automatic differentiation.
|
||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||
*/
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||
* with respect to this variable.
|
||||
*/
|
||||
open class Variable(val value: Double) {
|
||||
constructor(x: Number) : this(x.toDouble())
|
||||
}
|
||||
|
||||
class DerivationResult(value: Double, val deriv: Map<Variable, Double>) : Variable(value) {
|
||||
fun deriv(variable: Variable) = deriv[variable] ?: 0.0
|
||||
|
||||
/**
|
||||
* compute divergence
|
||||
*/
|
||||
fun div() = deriv.values.sum()
|
||||
|
||||
/**
|
||||
* Compute a gradient for variables in given order
|
||||
*/
|
||||
fun grad(vararg variables: Variable): Point<Double> = if (variables.isEmpty()) {
|
||||
error("Variable order is not provided for gradient construction")
|
||||
} else {
|
||||
variables.map(::deriv).toDoubleArray().asBuffer()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||
*
|
||||
* The partial derivatives are placed in argument `d` variable
|
||||
*
|
||||
* Example:
|
||||
* ```
|
||||
* val x = Variable(2) // define variable(s) and their values
|
||||
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||
* assertEquals(17.0, y.x) // the value of result (y)
|
||||
* assertEquals(9.0, x.d) // dy/dx
|
||||
* ```
|
||||
*/
|
||||
fun deriv(body: AutoDiffField.() -> Variable): DerivationResult =
|
||||
AutoDiffContext().run {
|
||||
val result = body()
|
||||
result.d = 1.0 // computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
DerivationResult(result.value, derivatives)
|
||||
}
|
||||
|
||||
|
||||
abstract class AutoDiffField : Field<Variable> {
|
||||
/**
|
||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||
*
|
||||
* For example, implementation of `sin` function is:
|
||||
*
|
||||
* ```
|
||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
abstract fun <R> derive(value: R, block: (R) -> Unit): R
|
||||
|
||||
/**
|
||||
* A variable accessing inner state of derivatives.
|
||||
* Use this function in inner builders to avoid creating additional derivative bindings
|
||||
*/
|
||||
abstract var Variable.d: Double
|
||||
|
||||
abstract fun variable(value: Double): Variable
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
operator fun Number.plus(that: Variable): Variable = derive(variable(this.toDouble() + that.value)) { z ->
|
||||
that.d += z.d
|
||||
}
|
||||
|
||||
operator fun Variable.plus(b: Number): Variable = b.plus(this)
|
||||
|
||||
operator fun Number.minus(that: Variable): Variable = derive(variable(this.toDouble() - that.value)) { z ->
|
||||
that.d -= z.d
|
||||
}
|
||||
|
||||
operator fun Variable.minus(that: Number): Variable = derive(variable(this.value - that.toDouble())) { z ->
|
||||
this.d += z.d
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Automatic Differentiation context class.
|
||||
*/
|
||||
private class AutoDiffContext : AutoDiffField() {
|
||||
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack = arrayOfNulls<Any?>(8)
|
||||
private var sp = 0
|
||||
|
||||
internal val derivatives = HashMap<Variable, Double>()
|
||||
|
||||
|
||||
/**
|
||||
* A variable coupled with its derivative. For internal use only
|
||||
*/
|
||||
class VariableWithDeriv(x: Double, var d: Double = 0.0) : Variable(x)
|
||||
|
||||
|
||||
override fun variable(value: Double): Variable = VariableWithDeriv(value)
|
||||
|
||||
override var Variable.d: Double
|
||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: 0.0
|
||||
set(value) {
|
||||
if (this is VariableWithDeriv) {
|
||||
d = value
|
||||
} else {
|
||||
derivatives[this] = value
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <R> derive(value: R, block: (R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
stack[sp++] = value
|
||||
return value
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as (Any?) -> Unit
|
||||
block(value)
|
||||
}
|
||||
}
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
|
||||
override fun add(a: Variable, b: Variable): Variable =
|
||||
derive(variable(a.value + b.value)) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable, b: Variable): Variable =
|
||||
derive(variable(a.value * b.value)) { z ->
|
||||
a.d += z.d * b.value
|
||||
b.d += z.d * a.value
|
||||
}
|
||||
|
||||
override fun divide(a: Variable, b: Variable): Variable =
|
||||
derive(Variable(a.value / b.value)) { z ->
|
||||
a.d += z.d / b.value
|
||||
b.d -= z.d * a.value / (b.value * b.value)
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable, k: Number): Variable =
|
||||
derive(variable(k.toDouble() * a.value)) { z ->
|
||||
a.d += z.d * k.toDouble()
|
||||
}
|
||||
|
||||
override val zero: Variable get() = Variable(0.0)
|
||||
override val one: Variable get() = Variable(1.0)
|
||||
}
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
fun AutoDiffField.sqr(x: Variable): Variable = derive(variable(x.value * x.value)) { z ->
|
||||
x.d += z.d * 2 * x.value
|
||||
}
|
||||
|
||||
// x ^ 1/2
|
||||
fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable(sqrt(x.value))) { z ->
|
||||
x.d += z.d * 0.5 / z.value
|
||||
}
|
||||
|
||||
// x ^ y (const)
|
||||
fun AutoDiffField.pow(x: Variable, y: Double): Variable = derive(variable(x.value.pow(y))) { z ->
|
||||
x.d += z.d * y * x.value.pow(y - 1)
|
||||
}
|
||||
|
||||
fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
fun AutoDiffField.exp(x: Variable): Variable = derive(variable(kotlin.math.exp(x.value))) { z ->
|
||||
x.d += z.d * z.value
|
||||
}
|
||||
|
||||
// ln(x)
|
||||
fun AutoDiffField.ln(x: Variable): Variable = derive(Variable(kotlin.math.ln(x.value))) { z ->
|
||||
x.d += z.d / x.value
|
||||
}
|
||||
|
||||
// x ^ y (any)
|
||||
fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
fun AutoDiffField.sin(x: Variable): Variable = derive(variable(kotlin.math.sin(x.value))) { z ->
|
||||
x.d += z.d * kotlin.math.cos(x.value)
|
||||
}
|
||||
|
||||
// cos(x)
|
||||
fun AutoDiffField.cos(x: Variable): Variable = derive(variable(kotlin.math.cos(x.value))) { z ->
|
||||
x.d -= z.d * kotlin.math.sin(x.value)
|
||||
}
|
@ -33,7 +33,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
* A field for double without boxing. Does not produce appropriate field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
|
||||
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
override val zero: Double = 0.0
|
||||
override inline fun add(a: Double, b: Double) = a + b
|
||||
override inline fun multiply(a: Double, b: Double) = a * b
|
||||
@ -64,7 +64,7 @@ object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double,
|
||||
}
|
||||
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
object FloatField : Field<Float>, ExtendedFieldOperations<Float>, Norm<Float, Float> {
|
||||
object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
override val zero: Float = 0f
|
||||
override inline fun add(a: Float, b: Float) = a + b
|
||||
override inline fun multiply(a: Float, b: Float) = a * b
|
||||
|
@ -123,7 +123,7 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
||||
override fun iterator(): Iterator<T> = list.iterator()
|
||||
}
|
||||
|
||||
fun <T> List<T>.asBuffer() = ListBuffer(this)
|
||||
fun <T> List<T>.asBuffer() = ListBuffer<T>(this)
|
||||
|
||||
@Suppress("FunctionName")
|
||||
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
|
||||
|
@ -1,5 +1,6 @@
|
||||
package scientifik.kmath.operations
|
||||
package scientifik.kmath.misc
|
||||
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.structures.asBuffer
|
||||
import kotlin.math.PI
|
||||
import kotlin.test.Test
|
||||
@ -7,6 +8,11 @@ import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class AutoDiffTest {
|
||||
|
||||
fun Variable(int: Int) = Variable(int.toDouble())
|
||||
|
||||
fun deriv(body: AutoDiffField<Double>.() -> Variable<Double>) = RealField.deriv(body)
|
||||
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
val x = Variable(3) // diff w.r.t this x at 3
|
Loading…
Reference in New Issue
Block a user