From ae07652d9ec60ba632d985281c799d9493653e36 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 21 Oct 2020 11:38:28 +0300 Subject: [PATCH] Symbol identity is always a string --- .../DerivativeStructureExpression.kt | 4 +- .../kscience/kmath/expressions/Expression.kt | 2 +- .../kmath/expressions/SimpleAutoDiff.kt | 140 +++++++++--------- .../kmath/expressions/SimpleAutoDiffTest.kt | 8 +- 4 files changed, 73 insertions(+), 81 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 9a27e40cd..2ec69255e 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -25,13 +25,13 @@ public class DerivativeStructureField( */ public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { - override val identity: Any = symbol.identity + override val identity: String = symbol.identity } /** * Identity-based symbol bindings map */ - private val variables: Map = bindings.entries.associate { (key, value) -> + private val variables: Map = bindings.entries.associate { (key, value) -> key.identity to DerivativeStructureSymbol(key, value) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index d64eb5a55..bd83261f7 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -12,7 +12,7 @@ public interface Symbol { * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol. * By default uses object identity */ - public val identity: Any get() = this + public val identity: String } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index 5e8fe3e99..a718154d3 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -12,20 +12,7 @@ import kotlin.contracts.contract */ -/** - * A [Symbol] with bound value - */ -public interface BoundSymbol : Symbol { - public val value: T -} - -/** - * Bind a [Symbol] to a [value] and produce [BoundSymbol] - */ -public fun Symbol.bind(value: T): BoundSymbol = object : BoundSymbol { - override val identity = this@bind.identity - override val value: T = value -} +public open class AutoDiffValue(public val value: T) /** * Represents result of [withAutoDiff] call. @@ -36,10 +23,10 @@ public fun Symbol.bind(value: T): BoundSymbol = object : BoundSymbol { * @property context The field over [T]. */ public class DerivationResult( - override val value: T, - private val derivativeValues: Map, + public val value: T, + private val derivativeValues: Map, public val context: Field, -) : BoundSymbol { +) { /** * Returns derivative of [variable] or returns [Ring.zero] in [context]. */ @@ -76,8 +63,8 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point> F.withAutoDiff( - bindings: Collection>, - body: AutoDiffField.() -> BoundSymbol, + bindings: Map, + body: AutoDiffField.() -> AutoDiffValue, ): DerivationResult { contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } @@ -86,14 +73,14 @@ public fun > F.withAutoDiff( public fun > F.withAutoDiff( vararg bindings: Pair, - body: AutoDiffField.() -> BoundSymbol, -): DerivationResult = withAutoDiff(bindings.map { it.first.bind(it.second) }, body) + body: AutoDiffField.() -> AutoDiffValue, +): DerivationResult = withAutoDiff(bindings.toMap(), body) /** * Represents field in context of which functions can be derived. */ public abstract class AutoDiffField> - : Field>, ExpressionAlgebra> { + : Field>, ExpressionAlgebra> { public abstract val context: F @@ -101,7 +88,7 @@ public abstract class AutoDiffField> * A variable accessing inner state of derivatives. * Use this value in inner builders to avoid creating additional derivative bindings. */ - public abstract var BoundSymbol.d: T + public abstract var AutoDiffValue.d: T /** * Performs update of derivative after the rest of the formula in the back-pass. @@ -116,21 +103,21 @@ public abstract class AutoDiffField> */ public abstract fun derive(value: R, block: F.(R) -> Unit): R - public inline fun const(block: F.() -> T): BoundSymbol = const(context.block()) + public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block()) // Overloads for Double constants - override operator fun Number.plus(b: BoundSymbol): BoundSymbol = + override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue = derive(const { this@plus.toDouble() * one + b.value }) { z -> b.d += z.d } - override operator fun BoundSymbol.plus(b: Number): BoundSymbol = b.plus(this) + override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this) - override operator fun Number.minus(b: BoundSymbol): BoundSymbol = + override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue = derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } - override operator fun BoundSymbol.minus(b: Number): BoundSymbol = + override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue = derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } } @@ -139,14 +126,14 @@ public abstract class AutoDiffField> */ private class AutoDiffContext>( override val context: F, - bindings: Collection>, + bindings: Map, ) : AutoDiffField() { // this stack contains pairs of blocks and values to apply them to private var stack: Array = arrayOfNulls(8) private var sp: Int = 0 - private val derivatives: MutableMap = hashMapOf() - override val zero: BoundSymbol get() = const(context.zero) - override val one: BoundSymbol get() = const(context.one) + private val derivatives: MutableMap, T> = hashMapOf() + override val zero: AutoDiffValue get() = const(context.zero) + override val one: AutoDiffValue get() = const(context.one) /** * Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result @@ -155,17 +142,23 @@ private class AutoDiffContext>( * @param T the non-nullable type of value. * @property value The value of this variable. */ - private class AutoDiffVariableWithDeriv(override val value: T, var d: T) : BoundSymbol + private class AutoDiffVariableWithDeriv( + override val identity: String, + value: T, + var d: T, + ) : AutoDiffValue(value), Symbol - private val bindings: Map> = bindings.associateBy { it.identity } + private val bindings: Map> = bindings.entries.associate { + it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero) + } - override fun bindOrNull(symbol: Symbol): BoundSymbol? = bindings[symbol.identity] + override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv? = bindings[symbol.identity] - override fun const(value: T): BoundSymbol = AutoDiffVariableWithDeriv(value, context.zero) + override fun const(value: T): AutoDiffValue = AutoDiffValue(value) - override var BoundSymbol.d: T - get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero - set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value + override var AutoDiffValue.d: T + get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero + set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value @Suppress("UNCHECKED_CAST") override fun derive(value: R, block: F.(R) -> Unit): R { @@ -187,34 +180,34 @@ private class AutoDiffContext>( // Basic math (+, -, *, /) - override fun add(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value + b.value }) { z -> a.d += z.d b.d += z.d } - override fun multiply(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value * b.value }) { z -> a.d += z.d * b.value b.d += z.d * a.value } - override fun divide(a: BoundSymbol, b: BoundSymbol): BoundSymbol = + override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue = derive(const { a.value / b.value }) { z -> a.d += z.d / b.value b.d -= z.d * a.value / (b.value * b.value) } - override fun multiply(a: BoundSymbol, k: Number): BoundSymbol = + override fun multiply(a: AutoDiffValue, k: Number): AutoDiffValue = derive(const { k.toDouble() * a.value }) { z -> a.d += z.d * k.toDouble() } - inline fun derivate(function: AutoDiffField.() -> BoundSymbol): DerivationResult { + inline fun derivate(function: AutoDiffField.() -> AutoDiffValue): DerivationResult { val result = function() result.d = context.one // computing derivative w.r.t result runBackwardPass() - return DerivationResult(result.value, derivatives, context) + return DerivationResult(result.value, bindings.mapValues { it.value.d }, context) } } @@ -223,11 +216,11 @@ private class AutoDiffContext>( */ public class SimpleAutoDiffExpression>( public val field: F, - public val function: AutoDiffField.() -> BoundSymbol, + public val function: AutoDiffField.() -> AutoDiffValue, ) : DifferentiableExpression { public override operator fun invoke(arguments: Map): T { - val bindings = arguments.entries.map { it.key.bind(it.value) } - return AutoDiffContext(field, bindings).function().value + //val bindings = arguments.entries.map { it.key.bind(it.value) } + return AutoDiffContext(field, arguments).function().value } /** @@ -237,8 +230,8 @@ public class SimpleAutoDiffExpression>( val dSymbol = orders.entries.singleOrNull { it.value == 1 } ?: error("SimpleAutoDiff supports only first order derivatives") return Expression { arguments -> - val bindings = arguments.entries.map { it.key.bind(it.value) } - val derivationResult = AutoDiffContext(field, bindings).derivate(function) + //val bindings = arguments.entries.map { it.key.bind(it.value) } + val derivationResult = AutoDiffContext(field, arguments).derivate(function) derivationResult.derivative(dSymbol.key) } } @@ -248,82 +241,81 @@ public class SimpleAutoDiffExpression>( // Extensions for differentiation of various basic mathematical functions // x ^ 2 -public fun > AutoDiffField.sqr(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sqr(x: AutoDiffValue): AutoDiffValue = derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } // x ^ 1/2 -public fun > AutoDiffField.sqrt(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sqrt(x: AutoDiffValue): AutoDiffValue = derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } // x ^ y (const) public fun > AutoDiffField.pow( - x: BoundSymbol, + x: AutoDiffValue, y: Double, -): BoundSymbol = +): AutoDiffValue = derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } public fun > AutoDiffField.pow( - x: BoundSymbol, + x: AutoDiffValue, y: Int, -): BoundSymbol = - pow(x, y.toDouble()) +): AutoDiffValue = pow(x, y.toDouble()) // exp(x) -public fun > AutoDiffField.exp(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.exp(x: AutoDiffValue): AutoDiffValue = derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } // ln(x) -public fun > AutoDiffField.ln(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.ln(x: AutoDiffValue): AutoDiffValue = derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } // x ^ y (any) public fun > AutoDiffField.pow( - x: BoundSymbol, - y: BoundSymbol, -): BoundSymbol = + x: AutoDiffValue, + y: AutoDiffValue, +): AutoDiffValue = exp(y * ln(x)) // sin(x) -public fun > AutoDiffField.sin(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sin(x: AutoDiffValue): AutoDiffValue = derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } // cos(x) -public fun > AutoDiffField.cos(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.cos(x: AutoDiffValue): AutoDiffValue = derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } -public fun > AutoDiffField.tan(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.tan(x: AutoDiffValue): AutoDiffValue = derive(const { tan(x.value) }) { z -> val c = cos(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asin(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.asin(x: AutoDiffValue): AutoDiffValue = derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.acos(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.acos(x: AutoDiffValue): AutoDiffValue = derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } -public fun > AutoDiffField.atan(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.atan(x: AutoDiffValue): AutoDiffValue = derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } -public fun > AutoDiffField.sinh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.sinh(x: AutoDiffValue): AutoDiffValue = derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } -public fun > AutoDiffField.cosh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.cosh(x: AutoDiffValue): AutoDiffValue = derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } -public fun > AutoDiffField.tanh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.tanh(x: AutoDiffValue): AutoDiffValue = derive(const { tan(x.value) }) { z -> val c = cosh(x.value) x.d += z.d / (c * c) } -public fun > AutoDiffField.asinh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.asinh(x: AutoDiffValue): AutoDiffValue = derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } -public fun > AutoDiffField.acosh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.acosh(x: AutoDiffValue): AutoDiffValue = derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } -public fun > AutoDiffField.atanh(x: BoundSymbol): BoundSymbol = +public fun > AutoDiffField.atanh(x: AutoDiffValue): AutoDiffValue = derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt index ca5b626fd..ef4a6a06a 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -12,23 +12,23 @@ import kotlin.test.assertTrue class SimpleAutoDiffTest { fun d( vararg bindings: Pair, - body: AutoDiffField.() -> BoundSymbol, + body: AutoDiffField.() -> AutoDiffValue, ): DerivationResult = RealField.withAutoDiff(bindings = bindings, body) fun dx( xBinding: Pair, - body: AutoDiffField.(x: BoundSymbol) -> BoundSymbol, + body: AutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) } fun dxy( xBinding: Pair, yBinding: Pair, - body: AutoDiffField.(x: BoundSymbol, y: BoundSymbol) -> BoundSymbol, + body: AutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, ): DerivationResult = RealField.withAutoDiff(xBinding, yBinding) { body(bind(xBinding.first), bind(yBinding.first)) } - fun diff(block: AutoDiffField.() -> BoundSymbol): SimpleAutoDiffExpression { + fun diff(block: AutoDiffField.() -> AutoDiffValue): SimpleAutoDiffExpression { return SimpleAutoDiffExpression(RealField, block) }