Fine grained field types for autodiff

This commit is contained in:
Alexander Nozik 2019-06-06 19:04:32 +03:00
parent cee39e0666
commit f706122266
4 changed files with 44 additions and 37 deletions

View File

@ -2,7 +2,6 @@ package scientifik.kmath.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.ExtendedFieldOperations
import scientifik.kmath.operations.Field
import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KProperty
@ -10,8 +9,10 @@ import kotlin.reflect.KProperty
/**
* A field wrapping commons-math derivative structures
*/
class DerivativeStructureField(val order: Int, val parameters: Map<String, Double>) :
ExtendedField<DerivativeStructure> {
class DerivativeStructureField(
val order: Int,
val parameters: Map<String, Double>
) : ExtendedField<DerivativeStructure> {
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }

View File

@ -20,7 +20,7 @@ open class Variable<T : Any>(val value: T)
class DerivationResult<T : Any>(
value: T,
val deriv: Map<Variable<T>, T>,
val context: ExtendedField<T>
val context: Field<T>
) : Variable<T>(value) {
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero
@ -52,8 +52,8 @@ class DerivationResult<T : Any>(
* assertEquals(9.0, x.d) // dy/dx
* ```
*/
fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>): DerivationResult<T> =
AutoDiffContext<T>(this).run {
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
AutoDiffContext<T, F>(this).run {
val result = body()
result.d = context.one// computing derivative w.r.t result
runBackwardPass()
@ -61,9 +61,9 @@ fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>):
}
abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
abstract val context: ExtendedField<T>
abstract val context: F
/**
* Performs update of derivative after the rest of the formula in the back-pass.
@ -76,7 +76,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
* }
* ```
*/
abstract fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
/**
* A variable accessing inner state of derivatives.
@ -86,7 +86,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
abstract fun variable(value: T): Variable<T>
inline fun variable(block: ExtendedField<T>.() -> T) = variable(context.block())
inline fun variable(block: F.() -> T) = variable(context.block())
// Overloads for Double constants
@ -111,7 +111,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
/**
* Automatic Differentiation context class.
*/
private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) : AutoDiffField<T>() {
private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to
private var stack = arrayOfNulls<Any?>(8)
@ -123,11 +123,11 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
/**
* A variable coupled with its derivative. For internal use only
*/
inner class VariableWithDeriv(x: T, var d: T = context.zero) : Variable<T>(x)
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
override fun variable(value: T): Variable<T> =
VariableWithDeriv(value)
VariableWithDeriv(value, context.zero)
override var Variable<T>.d: T
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
@ -140,7 +140,7 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
}
@Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R {
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
@ -152,7 +152,7 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as ExtendedField<T>.(Any?) -> Unit
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
@ -190,44 +190,50 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
// Extensions for differentiation of various basic mathematical functions
// x ^ 2
fun <T : Any> AutoDiffField<T>.sqr(x: Variable<T>): Variable<T> = derive(variable { x.value * x.value }) { z ->
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
derive(variable { x.value * x.value }) { z ->
x.d += z.d * 2 * x.value
}
// x ^ 1/2
fun <T : Any> AutoDiffField<T>.sqrt(x: Variable<T>): Variable<T> = derive(variable { sqrt(x.value) }) { z ->
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
derive(variable { sqrt(x.value) }) { z ->
x.d += z.d * 0.5 / z.value
}
// x ^ y (const)
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Double): Variable<T> =
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
derive(variable { power(x.value, y) }) { z ->
x.d += z.d * y * power(x.value, y - 1)
}
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
// exp(x)
fun <T : Any> AutoDiffField<T>.exp(x: Variable<T>): Variable<T> = derive(variable { exp(x.value) }) { z ->
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
derive(variable { exp(x.value) }) { z ->
x.d += z.d * z.value
}
// ln(x)
fun <T : Any> AutoDiffField<T>.ln(x: Variable<T>): Variable<T> = derive(
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> = derive(
variable { ln(x.value) }
) { z ->
x.d += z.d / x.value
}
// x ^ y (any)
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Variable<T>): Variable<T> = exp(y * ln(x))
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
exp(y * ln(x))
// sin(x)
fun <T : Any> AutoDiffField<T>.sin(x: Variable<T>): Variable<T> = derive(variable { sin(x.value) }) { z ->
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z ->
x.d += z.d * cos(x.value)
}
// cos(x)
fun <T : Any> AutoDiffField<T>.cos(x: Variable<T>): Variable<T> = derive(variable { cos(x.value) }) { z ->
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z ->
x.d -= z.d * sin(x.value)
}

View File

@ -1,7 +1,7 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.ExtendedFieldOperations
import scientifik.kmath.operations.Field
import kotlin.math.*
@ -98,7 +98,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
}
}
class RealBufferField(val size: Int) : Field<Buffer<Double>>, ExtendedFieldOperations<Buffer<Double>> {
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }

View File

@ -11,7 +11,7 @@ class AutoDiffTest {
fun Variable(int: Int) = Variable(int.toDouble())
fun deriv(body: AutoDiffField<Double>.() -> Variable<Double>) = RealField.deriv(body)
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>) = RealField.deriv(body)
@Test
fun testPlusX2() {