Fine grained field types for autodiff
This commit is contained in:
parent
cee39e0666
commit
f706122266
@ -2,7 +2,6 @@ package scientifik.kmath.expressions
|
|||||||
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedField
|
||||||
import scientifik.kmath.operations.ExtendedFieldOperations
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
import kotlin.reflect.KProperty
|
import kotlin.reflect.KProperty
|
||||||
@ -10,8 +9,10 @@ import kotlin.reflect.KProperty
|
|||||||
/**
|
/**
|
||||||
* A field wrapping commons-math derivative structures
|
* A field wrapping commons-math derivative structures
|
||||||
*/
|
*/
|
||||||
class DerivativeStructureField(val order: Int, val parameters: Map<String, Double>) :
|
class DerivativeStructureField(
|
||||||
ExtendedField<DerivativeStructure> {
|
val order: Int,
|
||||||
|
val parameters: Map<String, Double>
|
||||||
|
) : ExtendedField<DerivativeStructure> {
|
||||||
|
|
||||||
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
|
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ open class Variable<T : Any>(val value: T)
|
|||||||
class DerivationResult<T : Any>(
|
class DerivationResult<T : Any>(
|
||||||
value: T,
|
value: T,
|
||||||
val deriv: Map<Variable<T>, T>,
|
val deriv: Map<Variable<T>, T>,
|
||||||
val context: ExtendedField<T>
|
val context: Field<T>
|
||||||
) : Variable<T>(value) {
|
) : Variable<T>(value) {
|
||||||
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero
|
fun deriv(variable: Variable<T>) = deriv[variable] ?: context.zero
|
||||||
|
|
||||||
@ -52,8 +52,8 @@ class DerivationResult<T : Any>(
|
|||||||
* assertEquals(9.0, x.d) // dy/dx
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>): DerivationResult<T> =
|
fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> =
|
||||||
AutoDiffContext<T>(this).run {
|
AutoDiffContext<T, F>(this).run {
|
||||||
val result = body()
|
val result = body()
|
||||||
result.d = context.one// computing derivative w.r.t result
|
result.d = context.one// computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
@ -61,9 +61,9 @@ fun <T : Any> ExtendedField<T>.deriv(body: AutoDiffField<T>.() -> Variable<T>):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||||
|
|
||||||
abstract val context: ExtendedField<T>
|
abstract val context: F
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
@ -76,7 +76,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
|||||||
* }
|
* }
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
abstract fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R
|
abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A variable accessing inner state of derivatives.
|
* A variable accessing inner state of derivatives.
|
||||||
@ -86,7 +86,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
|||||||
|
|
||||||
abstract fun variable(value: T): Variable<T>
|
abstract fun variable(value: T): Variable<T>
|
||||||
|
|
||||||
inline fun variable(block: ExtendedField<T>.() -> T) = variable(context.block())
|
inline fun variable(block: F.() -> T) = variable(context.block())
|
||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ abstract class AutoDiffField<T : Any> : Field<Variable<T>> {
|
|||||||
/**
|
/**
|
||||||
* Automatic Differentiation context class.
|
* Automatic Differentiation context class.
|
||||||
*/
|
*/
|
||||||
private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) : AutoDiffField<T>() {
|
private class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
||||||
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
private var stack = arrayOfNulls<Any?>(8)
|
private var stack = arrayOfNulls<Any?>(8)
|
||||||
@ -123,11 +123,11 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
|||||||
/**
|
/**
|
||||||
* A variable coupled with its derivative. For internal use only
|
* A variable coupled with its derivative. For internal use only
|
||||||
*/
|
*/
|
||||||
inner class VariableWithDeriv(x: T, var d: T = context.zero) : Variable<T>(x)
|
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
||||||
|
|
||||||
|
|
||||||
override fun variable(value: T): Variable<T> =
|
override fun variable(value: T): Variable<T> =
|
||||||
VariableWithDeriv(value)
|
VariableWithDeriv(value, context.zero)
|
||||||
|
|
||||||
override var Variable<T>.d: T
|
override var Variable<T>.d: T
|
||||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
||||||
@ -140,7 +140,7 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun <R> derive(value: R, block: ExtendedField<T>.(R) -> Unit): R {
|
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
// save block to stack for backward pass
|
// save block to stack for backward pass
|
||||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
stack[sp++] = block
|
stack[sp++] = block
|
||||||
@ -152,7 +152,7 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
|||||||
fun runBackwardPass() {
|
fun runBackwardPass() {
|
||||||
while (sp > 0) {
|
while (sp > 0) {
|
||||||
val value = stack[--sp]
|
val value = stack[--sp]
|
||||||
val block = stack[--sp] as ExtendedField<T>.(Any?) -> Unit
|
val block = stack[--sp] as F.(Any?) -> Unit
|
||||||
context.block(value)
|
context.block(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,44 +190,50 @@ private class AutoDiffContext<T : Any>(override val context: ExtendedField<T>) :
|
|||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
// x ^ 2
|
// x ^ 2
|
||||||
fun <T : Any> AutoDiffField<T>.sqr(x: Variable<T>): Variable<T> = derive(variable { x.value * x.value }) { z ->
|
fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
||||||
|
derive(variable { x.value * x.value }) { z ->
|
||||||
x.d += z.d * 2 * x.value
|
x.d += z.d * 2 * x.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// x ^ 1/2
|
// x ^ 1/2
|
||||||
fun <T : Any> AutoDiffField<T>.sqrt(x: Variable<T>): Variable<T> = derive(variable { sqrt(x.value) }) { z ->
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
||||||
|
derive(variable { sqrt(x.value) }) { z ->
|
||||||
x.d += z.d * 0.5 / z.value
|
x.d += z.d * 0.5 / z.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// x ^ y (const)
|
// x ^ y (const)
|
||||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Double): Variable<T> =
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
||||||
derive(variable { power(x.value, y) }) { z ->
|
derive(variable { power(x.value, y) }) { z ->
|
||||||
x.d += z.d * y * power(x.value, y - 1)
|
x.d += z.d * y * power(x.value, y - 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> = pow(x, y.toDouble())
|
||||||
|
|
||||||
// exp(x)
|
// exp(x)
|
||||||
fun <T : Any> AutoDiffField<T>.exp(x: Variable<T>): Variable<T> = derive(variable { exp(x.value) }) { z ->
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
||||||
|
derive(variable { exp(x.value) }) { z ->
|
||||||
x.d += z.d * z.value
|
x.d += z.d * z.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// ln(x)
|
// ln(x)
|
||||||
fun <T : Any> AutoDiffField<T>.ln(x: Variable<T>): Variable<T> = derive(
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> = derive(
|
||||||
variable { ln(x.value) }
|
variable { ln(x.value) }
|
||||||
) { z ->
|
) { z ->
|
||||||
x.d += z.d / x.value
|
x.d += z.d / x.value
|
||||||
}
|
}
|
||||||
|
|
||||||
// x ^ y (any)
|
// x ^ y (any)
|
||||||
fun <T : Any> AutoDiffField<T>.pow(x: Variable<T>, y: Variable<T>): Variable<T> = exp(y * ln(x))
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
||||||
|
exp(y * ln(x))
|
||||||
|
|
||||||
// sin(x)
|
// sin(x)
|
||||||
fun <T : Any> AutoDiffField<T>.sin(x: Variable<T>): Variable<T> = derive(variable { sin(x.value) }) { z ->
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
||||||
|
derive(variable { sin(x.value) }) { z ->
|
||||||
x.d += z.d * cos(x.value)
|
x.d += z.d * cos(x.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cos(x)
|
// cos(x)
|
||||||
fun <T : Any> AutoDiffField<T>.cos(x: Variable<T>): Variable<T> = derive(variable { cos(x.value) }) { z ->
|
fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
||||||
|
derive(variable { cos(x.value) }) { z ->
|
||||||
x.d -= z.d * sin(x.value)
|
x.d -= z.d * sin(x.value)
|
||||||
}
|
}
|
@ -1,7 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.ExtendedField
|
||||||
import scientifik.kmath.operations.ExtendedFieldOperations
|
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class RealBufferField(val size: Int) : Field<Buffer<Double>>, ExtendedFieldOperations<Buffer<Double>> {
|
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
||||||
|
|
||||||
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ class AutoDiffTest {
|
|||||||
|
|
||||||
fun Variable(int: Int) = Variable(int.toDouble())
|
fun Variable(int: Int) = Variable(int.toDouble())
|
||||||
|
|
||||||
fun deriv(body: AutoDiffField<Double>.() -> Variable<Double>) = RealField.deriv(body)
|
fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>) = RealField.deriv(body)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPlusX2() {
|
fun testPlusX2() {
|
||||||
|
Loading…
Reference in New Issue
Block a user