forked from kscience/kmath
Update DifferentiableExpression by providing second type argument representing the result of differentiation
This commit is contained in:
parent
ef7066b8c9
commit
d14e437659
@ -3,7 +3,6 @@ package kscience.kmath.ast
|
|||||||
import kscience.kmath.asm.compile
|
import kscience.kmath.asm.compile
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
import kscience.kmath.expressions.symbol
|
import kscience.kmath.expressions.symbol
|
||||||
import kscience.kmath.kotlingrad.DifferentiableMstExpression
|
|
||||||
import kscience.kmath.kotlingrad.derivative
|
import kscience.kmath.kotlingrad.derivative
|
||||||
import kscience.kmath.kotlingrad.differentiable
|
import kscience.kmath.kotlingrad.differentiable
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
|
@ -95,10 +95,10 @@ public class DerivativeStructureField(
|
|||||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||||
|
|
||||||
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
public companion object :
|
||||||
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
|
||||||
return DerivativeStructureExpression(function)
|
public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> =
|
||||||
}
|
DerivativeStructureExpression(function)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ public class DerivativeStructureField(
|
|||||||
*/
|
*/
|
||||||
public class DerivativeStructureExpression(
|
public class DerivativeStructureExpression(
|
||||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||||
) : DifferentiableExpression<Double> {
|
) : DifferentiableExpression<Double, Expression<Double>> {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
DerivativeStructureField(0, arguments).function().value
|
DerivativeStructureField(0, arguments).function().value
|
||||||
|
|
||||||
|
@ -1,29 +1,37 @@
|
|||||||
package kscience.kmath.expressions
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An expression that provides derivatives
|
* Represents expression which structure can be differentiated.
|
||||||
|
*
|
||||||
|
* @param T the type this expression takes as argument and returns.
|
||||||
|
* @param R the type of expression this expression can be differentiated to.
|
||||||
*/
|
*/
|
||||||
public interface DifferentiableExpression<T> : Expression<T> {
|
public interface DifferentiableExpression<T, R : Expression<T>> : Expression<T> {
|
||||||
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
|
/**
|
||||||
|
* Differentiates this expression by ordered collection of [symbols].
|
||||||
|
*/
|
||||||
|
public fun derivativeOrNull(symbols: List<Symbol>): R?
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
|
||||||
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
|
||||||
derivative(symbols.toList())
|
derivative(symbols.toList())
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R =
|
||||||
derivative(StringSymbol(name))
|
derivative(StringSymbol(name))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A [DifferentiableExpression] that defines only first derivatives
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
*/
|
*/
|
||||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T,R> {
|
||||||
|
/**
|
||||||
|
* Returns first derivative of this expression by given [symbol].
|
||||||
|
*/
|
||||||
|
public abstract fun derivativeOrNull(symbol: Symbol): R?
|
||||||
|
|
||||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
public final override fun derivativeOrNull(symbols: List<Symbol>): R? {
|
||||||
|
|
||||||
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
|
|
||||||
val dSymbol = symbols.firstOrNull() ?: return null
|
val dSymbol = symbols.firstOrNull() ?: return null
|
||||||
return derivativeOrNull(dSymbol)
|
return derivativeOrNull(dSymbol)
|
||||||
}
|
}
|
||||||
@ -32,6 +40,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
|||||||
/**
|
/**
|
||||||
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||||
*/
|
*/
|
||||||
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, R : Expression<T>> {
|
||||||
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
public fun process(function: A.() -> I): DifferentiableExpression<T, R>
|
||||||
}
|
}
|
@ -22,7 +22,9 @@ public inline class StringSymbol(override val identity: String) : Symbol {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments.
|
||||||
|
*
|
||||||
|
* @param T the type this expression takes as argument and returns.
|
||||||
*/
|
*/
|
||||||
public fun interface Expression<T> {
|
public fun interface Expression<T> {
|
||||||
/**
|
/**
|
||||||
|
@ -68,7 +68,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
|||||||
): DerivationResult<T> {
|
): DerivationResult<T> {
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
return SimpleAutoDiffField(this, bindings).derivate(body)
|
return SimpleAutoDiffField(this, bindings).differentiate(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
@ -83,12 +83,21 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
public val context: F,
|
public val context: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||||
|
public override val zero: AutoDiffValue<T>
|
||||||
|
get() = const(context.zero)
|
||||||
|
|
||||||
|
public override val one: AutoDiffValue<T>
|
||||||
|
get() = const(context.one)
|
||||||
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
private var sp: Int = 0
|
private var sp: Int = 0
|
||||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||||
|
|
||||||
|
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||||
|
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
@ -106,11 +115,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
override fun hashCode(): Int = identity.hashCode()
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
public override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||||
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
|
||||||
|
|
||||||
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||||
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||||
@ -119,7 +124,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
private fun runBackwardPass() {
|
private fun runBackwardPass() {
|
||||||
while (sp > 0) {
|
while (sp > 0) {
|
||||||
@ -129,9 +133,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
|
||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
|
||||||
|
|
||||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -165,7 +166,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||||
val result = function()
|
val result = function()
|
||||||
result.d = context.one // computing derivative w.r.t result
|
result.d = context.one // computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
@ -174,41 +175,41 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||||
b.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
||||||
|
|
||||||
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||||
|
|
||||||
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value + b.value }) { z ->
|
derive(const { a.value + b.value }) { z ->
|
||||||
a.d += z.d
|
a.d += z.d
|
||||||
b.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value * b.value }) { z ->
|
derive(const { a.value * b.value }) { z ->
|
||||||
a.d += z.d * b.value
|
a.d += z.d * b.value
|
||||||
b.d += z.d * a.value
|
b.d += z.d * a.value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value / b.value }) { z ->
|
derive(const { a.value / b.value }) { z ->
|
||||||
a.d += z.d / b.value
|
a.d += z.d / b.value
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
||||||
derive(const { k.toDouble() * a.value }) { z ->
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
a.d += z.d * k.toDouble()
|
a.d += z.d * k.toDouble()
|
||||||
}
|
}
|
||||||
@ -220,15 +221,15 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : FirstDerivativeExpression<T>() {
|
) : FirstDerivativeExpression<T, Expression<T>>() {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return SimpleAutoDiffField(field, arguments).function().value
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
|
||||||
derivationResult.derivative(symbol)
|
derivationResult.derivative(symbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -236,13 +237,10 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
|||||||
/**
|
/**
|
||||||
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
|
||||||
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
AutoDiffProcessor { function ->
|
||||||
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
SimpleAutoDiffExpression(field, function)
|
||||||
return SimpleAutoDiffExpression(field, function)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ import kscience.kmath.ast.MST
|
|||||||
import kscience.kmath.ast.MstAlgebra
|
import kscience.kmath.ast.MstAlgebra
|
||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.expressions.DifferentiableExpression
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
import kscience.kmath.expressions.StringSymbol
|
|
||||||
import kscience.kmath.expressions.Symbol
|
import kscience.kmath.expressions.Symbol
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
@ -20,7 +19,7 @@ import kscience.kmath.operations.NumericAlgebra
|
|||||||
* @property expr the underlying [MstExpression].
|
* @property expr the underlying [MstExpression].
|
||||||
*/
|
*/
|
||||||
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
|
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
|
||||||
DifferentiableExpression<T> where A : NumericAlgebra<T>, T : Number {
|
DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T>, T : Number {
|
||||||
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -47,15 +46,6 @@ public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpres
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(symbols: List<Symbol>): MstExpression<T, A> =
|
|
||||||
derivativeOrNull(symbols)
|
|
||||||
|
|
||||||
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(vararg symbols: Symbol): MstExpression<T, A> =
|
|
||||||
derivative(symbols.toList())
|
|
||||||
|
|
||||||
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(name: String): MstExpression<T, A> =
|
|
||||||
derivative(StringSymbol(name))
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
||||||
*/
|
*/
|
||||||
|
Loading…
Reference in New Issue
Block a user