Fine grained field types for autodiff
This commit is contained in:
parent
cee39e0666
commit
f706122266
@ -2,7 +2,6 @@ package scientifik.kmath.expressions
|
||||
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||
import scientifik.kmath.operations.Field
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
import kotlin.reflect.KProperty
|
||||
@ -10,8 +9,10 @@ import kotlin.reflect.KProperty
|
||||
/**
|
||||
* A field wrapping commons-math derivative structures
|
||||
*/
|
||||
class DerivativeStructureField(val order: Int, val parameters: Map<String, Double>) :
|
||||
ExtendedField<DerivativeStructure> {
|
||||
class DerivativeStructureField(
|
||||
val order: Int,
|
||||
val parameters: Map<String, Double>
|
||||
) : ExtendedField<DerivativeStructure> {
|
||||
|
||||
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
|
||||
|
||||
|
@ -20,7 +20,7 @@ open class Variable<T : Any>(val value: T)
|
||||
class DerivationResult<T : Any>(
|
||||
value: T,
|
||||
val deriv: Map<Variable<T>, T>,
|
||||
val context: ExtendedField<T>
|
||||
val context: Field<T>
|
||||
) : Variable<T>(value) {
|
||||
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero
|
||||
|
||||
@ -52,8 +52,8 @@ class DerivationResult<T : Any>(
|
||||
* assertEquals(9.0, x.d) // dy/dx
|
||||
* ```
|
||||
*/
|
||||
fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>): DerivationResult<T> =
|
||||
AutoDiffContext<T>(this).run {
|
||||
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
|
||||
AutoDiffContext<T, F>(this).run {
|
||||
val result = body()
|
||||
result.d = context.one// computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
@ -61,9 +61,9 @@ fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>):
|
||||
}
|
||||
|
||||
|
||||
abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
||||
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||
|
||||
abstract val context: ExtendedField<T>
|
||||
abstract val context: F
|
||||
|
||||
/**
|
||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||
@ -76,7 +76,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
abstract fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R
|
||||
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||
|
||||
/**
|
||||
* A variable accessing inner state of derivatives.
|
||||
@ -86,7 +86,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
||||
|
||||
abstract fun variable(value: T): Variable<T>
|
||||
|
||||
inline fun variable(block: ExtendedField<T>.() -> T) = variable(context.block())
|
||||
inline fun variable(block: F.() -> T) = variable(context.block())
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
@ -111,7 +111,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
||||
/**
|
||||
* Automatic Differentiation context class.
|
||||
*/
|
||||
private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) : AutoDiffField<T>() {
|
||||
private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
||||
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack = arrayOfNulls<Any?>(8)
|
||||
@ -123,11 +123,11 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
||||
/**
|
||||
* A variable coupled with its derivative. For internal use only
|
||||
*/
|
||||
inner class VariableWithDeriv(x: T, var d: T = context.zero) : Variable<T>(x)
|
||||
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
||||
|
||||
|
||||
override fun variable(value: T): Variable<T> =
|
||||
VariableWithDeriv(value)
|
||||
VariableWithDeriv(value, context.zero)
|
||||
|
||||
override var Variable<T>.d: T
|
||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
||||
@ -140,7 +140,7 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R {
|
||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
@ -152,7 +152,7 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
||||
fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as ExtendedField<T>.(Any?) -> Unit
|
||||
val block = stack[--sp] as F.(Any?) -> Unit
|
||||
context.block(value)
|
||||
}
|
||||
}
|
||||
@ -190,44 +190,50 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
fun <T : Any> AutoDiffField<T>.sqr(x: Variable<T>): Variable<T> = derive(variable { x.value * x.value }) { z ->
|
||||
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
||||
derive(variable { x.value * x.value }) { z ->
|
||||
x.d += z.d * 2 * x.value
|
||||
}
|
||||
|
||||
// x ^ 1/2
|
||||
fun <T : Any> AutoDiffField<T>.sqrt(x: Variable<T>): Variable<T> = derive(variable { sqrt(x.value) }) { z ->
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sqrt(x.value) }) { z ->
|
||||
x.d += z.d * 0.5 / z.value
|
||||
}
|
||||
|
||||
// x ^ y (const)
|
||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Double): Variable<T> =
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
||||
derive(variable { power(x.value, y) }) { z ->
|
||||
x.d += z.d * y * power(x.value, y - 1)
|
||||
}
|
||||
|
||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
fun <T : Any> AutoDiffField<T>.exp(x: Variable<T>): Variable<T> = derive(variable { exp(x.value) }) { z ->
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
||||
derive(variable { exp(x.value) }) { z ->
|
||||
x.d += z.d * z.value
|
||||
}
|
||||
|
||||
// ln(x)
|
||||
fun <T : Any> AutoDiffField<T>.ln(x: Variable<T>): Variable<T> = derive(
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> = derive(
|
||||
variable { ln(x.value) }
|
||||
) { z ->
|
||||
x.d += z.d / x.value
|
||||
}
|
||||
|
||||
// x ^ y (any)
|
||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Variable<T>): Variable<T> = exp(y * ln(x))
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
||||
exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
fun <T : Any> AutoDiffField<T>.sin(x: Variable<T>): Variable<T> = derive(variable { sin(x.value) }) { z ->
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sin(x.value) }) { z ->
|
||||
x.d += z.d * cos(x.value)
|
||||
}
|
||||
|
||||
// cos(x)
|
||||
fun <T : Any> AutoDiffField<T>.cos(x: Variable<T>): Variable<T> = derive(variable { cos(x.value) }) { z ->
|
||||
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
||||
derive(variable { cos(x.value) }) { z ->
|
||||
x.d -= z.d * sin(x.value)
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||
import scientifik.kmath.operations.Field
|
||||
import kotlin.math.*
|
||||
|
||||
|
||||
@ -98,7 +98,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
}
|
||||
}
|
||||
|
||||
class RealBufferField(val size: Int) : Field<Buffer<Double>>, ExtendedFieldOperations<Buffer<Double>> {
|
||||
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
||||
|
||||
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
||||
|
||||
|
@ -11,7 +11,7 @@ class AutoDiffTest {
|
||||
|
||||
fun Variable(int: Int) = Variable(int.toDouble())
|
||||
|
||||
fun deriv(body: AutoDiffField<Double>.() -> Variable<Double>) = RealField.deriv(body)
|
||||
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>) = RealField.deriv(body)
|
||||
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
|
Loading…
Reference in New Issue
Block a user