From f706122266e579ffd93957b83f6995992c34a29e Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 6 Jun 2019 19:04:32 +0300 Subject: [PATCH] Fine grained field types for autodiff --- .../kmath/expressions/DiffExpression.kt | 7 +- .../kotlin/scientifik/kmath/misc/AutoDiff.kt | 68 ++++++++++--------- .../kmath/structures/RealBufferField.kt | 4 +- .../scientifik/kmath/misc/AutoDiffTest.kt | 2 +- 4 files changed, 44 insertions(+), 37 deletions(-) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt index 6c1759d54..5f6d0d8fe 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt @@ -2,7 +2,6 @@ package scientifik.kmath.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.ExtendedFieldOperations import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty @@ -10,8 +9,10 @@ import kotlin.reflect.KProperty /** * A field wrapping commons-math derivative structures */ -class DerivativeStructureField(val order: Int, val parameters: Map) : - ExtendedField { +class DerivativeStructureField( + val order: Int, + val parameters: Map +) : ExtendedField { override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt index 91256d114..ed77054cf 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/AutoDiff.kt @@ -20,7 +20,7 @@ open class Variable(val value: T) class DerivationResult( value: T, val deriv: Map, T>, - val context: ExtendedField + val context: Field ) : Variable(value) { fun deriv(variable: Variable) = deriv[variable] ?: context.zero @@ -52,8 +52,8 @@ class DerivationResult( * assertEquals(9.0, x.d) // dy/dx * ``` */ -fun ExtendedField.deriv(body: AutoDiffField.() -> Variable): DerivationResult = - AutoDiffContext(this).run { +fun > F.deriv(body: AutoDiffField.() -> Variable): DerivationResult = + AutoDiffContext(this).run { val result = body() result.d = context.one// computing derivative w.r.t result runBackwardPass() @@ -61,9 +61,9 @@ fun ExtendedField.deriv(body: AutoDiffField.() -> Variable): } -abstract class AutoDiffField : Field> { +abstract class AutoDiffField> : Field> { - abstract val context: ExtendedField + abstract val context: F /** * Performs update of derivative after the rest of the formula in the back-pass. @@ -76,7 +76,7 @@ abstract class AutoDiffField : Field> { * } * ``` */ - abstract fun derive(value: R, block: ExtendedField.(R) -> Unit): R + abstract fun derive(value: R, block: F.(R) -> Unit): R /** * A variable accessing inner state of derivatives. @@ -86,7 +86,7 @@ abstract class AutoDiffField : Field> { abstract fun variable(value: T): Variable - inline fun variable(block: ExtendedField.() -> T) = variable(context.block()) + inline fun variable(block: F.() -> T) = variable(context.block()) // Overloads for Double constants @@ -111,7 +111,7 @@ abstract class AutoDiffField : Field> { /** * Automatic Differentiation context class. */ -private class AutoDiffContext(override val context: ExtendedField) : AutoDiffField() { +private class AutoDiffContext>(override val context: F) : AutoDiffField() { // this stack contains pairs of blocks and values to apply them to private var stack = arrayOfNulls(8) @@ -123,11 +123,11 @@ private class AutoDiffContext(override val context: ExtendedField) : /** * A variable coupled with its derivative. For internal use only */ - inner class VariableWithDeriv(x: T, var d: T = context.zero) : Variable(x) + private class VariableWithDeriv(x: T, var d: T) : Variable(x) override fun variable(value: T): Variable = - VariableWithDeriv(value) + VariableWithDeriv(value, context.zero) override var Variable.d: T get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero @@ -140,7 +140,7 @@ private class AutoDiffContext(override val context: ExtendedField) : } @Suppress("UNCHECKED_CAST") - override fun derive(value: R, block: ExtendedField.(R) -> Unit): R { + override fun derive(value: R, block: F.(R) -> Unit): R { // save block to stack for backward pass if (sp >= stack.size) stack = stack.copyOf(stack.size * 2) stack[sp++] = block @@ -152,7 +152,7 @@ private class AutoDiffContext(override val context: ExtendedField) : fun runBackwardPass() { while (sp > 0) { val value = stack[--sp] - val block = stack[--sp] as ExtendedField.(Any?) -> Unit + val block = stack[--sp] as F.(Any?) -> Unit context.block(value) } } @@ -190,44 +190,50 @@ private class AutoDiffContext(override val context: ExtendedField) : // Extensions for differentiation of various basic mathematical functions // x ^ 2 -fun AutoDiffField.sqr(x: Variable): Variable = derive(variable { x.value * x.value }) { z -> - x.d += z.d * 2 * x.value -} +fun > AutoDiffField.sqr(x: Variable): Variable = + derive(variable { x.value * x.value }) { z -> + x.d += z.d * 2 * x.value + } // x ^ 1/2 -fun AutoDiffField.sqrt(x: Variable): Variable = derive(variable { sqrt(x.value) }) { z -> - x.d += z.d * 0.5 / z.value -} +fun > AutoDiffField.sqrt(x: Variable): Variable = + derive(variable { sqrt(x.value) }) { z -> + x.d += z.d * 0.5 / z.value + } // x ^ y (const) -fun AutoDiffField.pow(x: Variable, y: Double): Variable = +fun > AutoDiffField.pow(x: Variable, y: Double): Variable = derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } -fun AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) +fun > AutoDiffField.pow(x: Variable, y: Int): Variable = pow(x, y.toDouble()) // exp(x) -fun AutoDiffField.exp(x: Variable): Variable = derive(variable { exp(x.value) }) { z -> - x.d += z.d * z.value -} +fun > AutoDiffField.exp(x: Variable): Variable = + derive(variable { exp(x.value) }) { z -> + x.d += z.d * z.value + } // ln(x) -fun AutoDiffField.ln(x: Variable): Variable = derive( +fun > AutoDiffField.ln(x: Variable): Variable = derive( variable { ln(x.value) } ) { z -> x.d += z.d / x.value } // x ^ y (any) -fun AutoDiffField.pow(x: Variable, y: Variable): Variable = exp(y * ln(x)) +fun > AutoDiffField.pow(x: Variable, y: Variable): Variable = + exp(y * ln(x)) // sin(x) -fun AutoDiffField.sin(x: Variable): Variable = derive(variable { sin(x.value) }) { z -> - x.d += z.d * cos(x.value) -} +fun > AutoDiffField.sin(x: Variable): Variable = + derive(variable { sin(x.value) }) { z -> + x.d += z.d * cos(x.value) + } // cos(x) -fun AutoDiffField.cos(x: Variable): Variable = derive(variable { cos(x.value) }) { z -> - x.d -= z.d * sin(x.value) -} \ No newline at end of file +fun > AutoDiffField.cos(x: Variable): Variable = + derive(variable { cos(x.value) }) { z -> + x.d -= z.d * sin(x.value) + } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index b56f42717..88c8c29db 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -1,7 +1,7 @@ package scientifik.kmath.structures +import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedFieldOperations -import scientifik.kmath.operations.Field import kotlin.math.* @@ -98,7 +98,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } } -class RealBufferField(val size: Int) : Field>, ExtendedFieldOperations> { +class RealBufferField(val size: Int) : ExtendedField> { override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt index 9a2597737..9dc8f5ef7 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/misc/AutoDiffTest.kt @@ -11,7 +11,7 @@ class AutoDiffTest { fun Variable(int: Int) = Variable(int.toDouble()) - fun deriv(body: AutoDiffField.() -> Variable) = RealField.deriv(body) + fun deriv(body: AutoDiffField.() -> Variable) = RealField.deriv(body) @Test fun testPlusX2() {