New Expression API

This commit is contained in:
Alexander Nozik 2020-10-19 22:51:33 +03:00
parent e44423192d
commit 707ad21f77
17 changed files with 794 additions and 672 deletions

View File

@ -53,9 +53,7 @@ can be used for a wide variety of purposes from high performance calculations to
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
to submit a feature request if you want something to be done first.
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
## Planned features
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
> **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-ejml](kmath-ejml)
>
>
> **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-for-real](kmath-for-real)
>
>
@ -178,8 +182,8 @@ repositories{
}
dependencies{
api("kscience.kmath:kmath-core:0.2.0-dev-1")
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-1") for jvm-specific version
api("kscience.kmath:kmath-core:0.2.0-dev-2")
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
}
```

View File

@ -24,4 +24,8 @@ subprojects {
readme {
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
}
apiValidation{
validationDisabled = true
}

View File

@ -107,4 +107,4 @@ with the same artifact names.
## Contributing
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).

View File

@ -14,8 +14,8 @@ import kotlin.contracts.contract
* @author Alexander Nozik
*/
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
@ -27,7 +27,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
error("Numeric nodes are not supported by $this")
}
override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
}
/**
@ -37,7 +37,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
*/
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E,
block: E.() -> MST
block: E.() -> MST,
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
/**
@ -116,7 +116,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
* @author Iaroslav Postovalov
*/
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST
block: MstExtendedField.() -> MST,
): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block)

View File

@ -2,6 +2,8 @@
package kscience.kmath.asm.internal
import kscience.kmath.expressions.StringSymbol
/**
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
*
@ -9,4 +11,4 @@ package kscience.kmath.asm.internal
*/
@JvmOverloads
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
this[key] ?: default ?: error("Parameter not found: $key")
this[StringSymbol(key.toString())] ?: default ?: error("Parameter not found: $key")

View File

@ -1,6 +1,5 @@
package kscience.kmath.asm
import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInField
import kscience.kmath.ast.mstInRing
import kscience.kmath.ast.mstInSpace
@ -11,6 +10,7 @@ import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmAlgebras {
@Test
fun space() {
val res1 = ByteRing.mstInSpace {

View File

@ -1,48 +1,57 @@
package kscience.kmath.commons.expressions
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.ExpressionAlgebra
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.ExtendedField
import kscience.kmath.operations.Field
import kscience.kmath.operations.invoke
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import kotlin.properties.ReadOnlyProperty
/**
* A field over commons-math [DerivativeStructure].
*
* @property order The derivation order.
* @property parameters The map of free parameters.
* @property bindings The map of bindings values. All bindings are considered free parameters
*/
public class DerivativeStructureField(
public val order: Int,
public val parameters: Map<String, Double>,
) : ExtendedField<DerivativeStructure> {
public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
private val bindings: Map<Symbol, Double>
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
/**
* A class that implements both [DerivativeStructure] and a [Symbol]
*/
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
override val identity: Any = symbol.identity
}
public val variable: ReadOnlyProperty<Any?, DerivativeStructure> = ReadOnlyProperty { _, property ->
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
/**
* Identity-based symbol bindings map
*/
private val variables: Map<Any?, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
key.identity to DerivativeStructureSymbol(key, value)
}
public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
variables[name] ?: default ?: error("A variable with name $name does not exist")
override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value)
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
return deriv(mapOf(parName to order))
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
public fun Number.const(): DerivativeStructure = const(toDouble())
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
return derivative(mapOf(parameter to order))
}
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
}
public fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
@ -85,48 +94,16 @@ public class DerivativeStructureField(
/**
* A constructs that creates a derivative structure with required order on-demand
*/
public class DiffExpression(
public class DerivativeStructureExpression(
public val function: DerivativeStructureField.() -> DerivativeStructure,
) : Expression<Double> {
public override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
0,
arguments
).function().value
) : DifferentiableExpression<Double> {
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
DerivativeStructureField(0, arguments).function().value
/**
* Get the derivative expression with given orders
* TODO make result [DiffExpression]
*/
public fun derivative(orders: Map<String, Int>): Expression<Double> = Expression { arguments ->
(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
public override fun derivative(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
}
//TODO add gradient and maybe other vector operators
}
public fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
public fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
/**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/
public object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
public override val zero: DiffExpression = DiffExpression { 0.0.const() }
public override val one: DiffExpression = DiffExpression { 1.0.const() }
public override fun variable(name: String, default: Double?): DiffExpression =
DiffExpression { variable(name, default?.const()) }
public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) + b.function(this) }
public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) * b.function(this) }
public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) / b.function(this) }
}

View File

@ -1,6 +1,6 @@
package kscience.kmath.commons.expressions
import kscience.kmath.expressions.invoke
import kscience.kmath.expressions.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test
@ -8,33 +8,37 @@ import kotlin.test.assertEquals
internal inline fun <R> diff(
order: Int,
vararg parameters: Pair<String, Double>,
block: DerivativeStructureField.() -> R
vararg parameters: Pair<Symbol, Double>,
block: DerivativeStructureField.() -> R,
): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
}
internal class AutoDiffTest {
private val x by symbol
private val y by symbol
@Test
fun derivativeStructureFieldTest() {
val res: Double = diff(3, "x" to 1.0, "y" to 1.0) {
val x by variable
val y = variable("y")
val res: Double = diff(3, x to 1.0, y to 1.0) {
val x = bind(x)//by binding()
val y = symbol("y")
val z = x * (-sin(x * y) + y)
z.deriv("x")
z.derivative(x)
}
println(res)
}
@Test
fun autoDifTest() {
val f = DiffExpression {
val x by variable
val y by variable
val f = DerivativeStructureExpression {
val x by binding()
val y by binding()
x.pow(2) + 2 * x * y + y.pow(2) + 1
}
assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
assertEquals(10.0, f(x to 1.0, y to 2.0))
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
}
}

View File

@ -12,7 +12,7 @@ The core features of KMath:
> #### Artifact:
>
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`.
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
>
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
>
@ -22,25 +22,28 @@ The core features of KMath:
>
> ```gradle
> repositories {
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
> }
>
> dependencies {
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1'
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
> }
> ```
> **Gradle Kotlin DSL:**
>
> ```kotlin
> repositories {
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
> maven("https://dl.bintray.com/mipt-npm/kscience")
> maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven")
> }
>
> dependencies {
> implementation("kscience.kmath:kmath-core:0.2.0-dev-1")
> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
> }
> ```

View File

@ -41,6 +41,6 @@ readme {
feature(
id = "autodif",
description = "Automatic differentiation",
ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt"
ref = "src/commonMain/kotlin/kscience/kmath/misc/SimpleAutoDiff.kt"
)
}

View File

@ -1,6 +1,26 @@
package kscience.kmath.expressions
import kscience.kmath.operations.Algebra
import kotlin.jvm.JvmName
import kotlin.properties.ReadOnlyProperty
/**
* A marker interface for a symbol. A symbol mus have an identity
*/
public interface Symbol {
/**
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
* By default uses object identity
*/
public val identity: Any get() = this
}
/**
* A [Symbol] with a [String] identity
*/
public inline class StringSymbol(override val identity: String) : Symbol {
override fun toString(): String = identity
}
/**
* An elementary function that could be invoked on a map of arguments
@ -12,30 +32,81 @@ public fun interface Expression<T> {
* @param arguments the map of arguments.
* @return the value.
*/
public operator fun invoke(arguments: Map<String, T>): T
public operator fun invoke(arguments: Map<Symbol, T>): T
public companion object
}
/**
* Invlode an expression without parameters
*/
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
//This method exists to avoid resolution ambiguity of vararg methods
/**
* Calls this expression from arguments.
*
* @param pairs the pair of arguments' names to values.
* @return the value.
*/
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
@JvmName("callBySymbol")
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
@JvmName("callByString")
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
/**
* And object that could be differentiated
*/
public interface Differentiable<T> {
public fun derivative(orders: Map<Symbol, Int>): T
}
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
derivative(mapOf(*orders))
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
/**
* A context for expression construction
*
* @param T type of the constants for the expression
* @param E type of the actual expression state
*/
public interface ExpressionAlgebra<T, E> : Algebra<E> {
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
/**
* Introduce a variable into expression context
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
*/
public fun variable(name: String, default: T? = null): E
public fun bindOrNull(symbol: Symbol): E?
/**
* Bind a string to a context using [StringSymbol]
*/
override fun symbol(value: String): E = bind(StringSymbol(value))
/**
* A constant expression which does not depend on arguments
*/
public fun const(value: T): E
}
/**
* Bind a given [Symbol] to this context variable and produce context-specific object.
*/
public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
public val symbol: ReadOnlyProperty<Any?, Symbol> = ReadOnlyProperty { _, property ->
StringSymbol(property.name)
}
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> =
ReadOnlyProperty { _, property ->
bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
}

View File

@ -2,39 +2,6 @@ package kscience.kmath.expressions
import kscience.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T =
context.unaryOperation(name, expr.invoke(arguments))
}
internal class FunctionalBinaryOperation<T>(
val context: Algebra<T>,
val name: String,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
}
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = value
}
internal class FunctionalConstProductExpression<T>(
val context: Space<T>,
private val expr: Expression<T>,
val const: Number
) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
/**
* A context class for [Expression] construction.
*
@ -45,24 +12,32 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
/**
* Builds an Expression of constant expression which does not depend on arguments.
*/
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
public override fun const(value: T): Expression<T> = Expression { value }
/**
* Builds an Expression to access a variable.
*/
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
public override fun bindOrNull(symbol: Symbol): Expression<T>? = Expression { arguments ->
arguments[symbol] ?: error("Argument not found: $symbol")
}
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, operation, left, right)
public override fun binaryOperation(
operation: String,
left: Expression<T>,
right: Expression<T>,
): Expression<T> = Expression { arguments ->
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
}
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
FunctionalUnaryOperation(algebra, operation, arg)
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
algebra.unaryOperation(operation, arg.invoke(arguments))
}
}
/**
@ -81,8 +56,9 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
/**
* Builds an Expression of multiplication of expression by number.
*/
public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
FunctionalConstProductExpression(algebra, a, k)
public override fun multiply(a: Expression<T>, k: Number): Expression<T> = Expression { arguments ->
algebra.multiply(a.invoke(arguments), k)
}
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
@ -118,8 +94,8 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
}
public open class FunctionalExpressionField<T, A>(algebra: A) :
FunctionalExpressionRing<T, A>(algebra),
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an Expression of division an expression by another one.
*/

View File

@ -0,0 +1,329 @@
package kscience.kmath.expressions
import kscience.kmath.linear.Point
import kscience.kmath.operations.*
import kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/*
* Implementation of backward-mode automatic differentiation.
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
*/
/**
* A [Symbol] with bound value
*/
public interface BoundSymbol<out T> : Symbol {
public val value: T
}
/**
* Bind a [Symbol] to a [value] and produce [BoundSymbol]
*/
public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
override val identity = this@bind.identity
override val value: T = value
}
/**
* Represents result of [withAutoDiff] call.
*
* @param T the non-nullable type of value.
* @param value the value of result.
* @property withAutoDiff The mapping of differentiated variables to their derivatives.
* @property context The field over [T].
*/
public class DerivationResult<T : Any>(
override val value: T,
private val derivativeValues: Map<Any, T>,
public val context: Field<T>,
) : BoundSymbol<T> {
/**
* Returns derivative of [variable] or returns [Ring.zero] in [context].
*/
public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
/**
* Computes the divergence.
*/
public fun div(): T = context { sum(derivativeValues.values) }
}
/**
* Computes the gradient for variables in given order.
*/
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
return variables.map(::derivative).asBuffer()
}
/**
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
*
* The partial derivatives are placed in argument `d` variable
*
* Example:
* ```
* val x by symbol // define variable(s) and their values
* val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
* assertEquals(17.0, y.x) // the value of result (y)
* assertEquals(9.0, x.d) // dy/dx
* ```
*
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
* @return the result of differentiation.
*/
public fun <T : Any, F : Field<T>> F.withAutoDiff(
bindings: Collection<BoundSymbol<T>>,
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return AutoDiffContext(this, bindings).derivate(body)
}
public fun <T : Any, F : Field<T>> F.withAutoDiff(
vararg bindings: Pair<Symbol, T>,
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
): DerivationResult<T> = withAutoDiff(bindings.map { it.first.bind(it.second) }, body)
/**
* Represents field in context of which functions can be derived.
*/
public abstract class AutoDiffField<T : Any, F : Field<T>>
: Field<BoundSymbol<T>>, ExpressionAlgebra<T, BoundSymbol<T>> {
public abstract val context: F
/**
* A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings.
*/
public abstract var BoundSymbol<T>.d: T
/**
* Performs update of derivative after the rest of the formula in the back-pass.
*
* For example, implementation of `sin` function is:
*
* ```
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
* }
* ```
*/
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
public inline fun const(block: F.() -> T): BoundSymbol<T> = const(context.block())
// Overloads for Double constants
override operator fun Number.plus(b: BoundSymbol<T>): BoundSymbol<T> =
derive(const { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d
}
override operator fun BoundSymbol<T>.plus(b: Number): BoundSymbol<T> = b.plus(this)
override operator fun Number.minus(b: BoundSymbol<T>): BoundSymbol<T> =
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun BoundSymbol<T>.minus(b: Number): BoundSymbol<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
}
/**
* Automatic Differentiation context class.
*/
private class AutoDiffContext<T : Any, F : Field<T>>(
override val context: F,
bindings: Collection<BoundSymbol<T>>,
) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
private val derivatives: MutableMap<Any, T> = hashMapOf()
override val zero: BoundSymbol<T> get() = const(context.zero)
override val one: BoundSymbol<T> get() = const(context.one)
/**
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
private class AutoDiffVariableWithDeriv<T : Any>(override val value: T, var d: T) : BoundSymbol<T>
private val bindings: Map<Any, BoundSymbol<T>> = bindings.associateBy { it.identity }
override fun bindOrNull(symbol: Symbol): BoundSymbol<T>? = bindings[symbol.identity]
override fun const(value: T): BoundSymbol<T> = AutoDiffVariableWithDeriv(value, context.zero)
override var BoundSymbol<T>.d: T
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value
@Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
@Suppress("UNCHECKED_CAST")
fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
// Basic math (+, -, *, /)
override fun add(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
derive(const { a.value + b.value }) { z ->
a.d += z.d
b.d += z.d
}
override fun multiply(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
derive(const { a.value * b.value }) { z ->
a.d += z.d * b.value
b.d += z.d * a.value
}
override fun divide(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
derive(const { a.value / b.value }) { z ->
a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value)
}
override fun multiply(a: BoundSymbol<T>, k: Number): BoundSymbol<T> =
derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble()
}
inline fun derivate(function: AutoDiffField<T, F>.() -> BoundSymbol<T>): DerivationResult<T> {
val result = function()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
return DerivationResult(result.value, derivatives, context)
}
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F,
public val function: AutoDiffField<T, F>.() -> BoundSymbol<T>,
) : DifferentiableExpression<T> {
public override operator fun invoke(arguments: Map<Symbol, T>): T {
val bindings = arguments.entries.map { it.key.bind(it.value) }
return AutoDiffContext(field, bindings).function().value
}
/**
* Get the derivative expression with given orders
*/
public override fun derivative(orders: Map<Symbol, Int>): Expression<T> {
val dSymbol = orders.entries.singleOrNull { it.value == 1 }
?: error("SimpleAutoDiff supports only first order derivatives")
return Expression { arguments ->
val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = AutoDiffContext(field, bindings).derivate(function)
derivationResult.derivative(dSymbol.key)
}
}
}
// Extensions for differentiation of various basic mathematical functions
// x ^ 2
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>,
y: Double,
): BoundSymbol<T> =
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>,
y: Int,
): BoundSymbol<T> =
pow(x, y.toDouble())
// exp(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
x: BoundSymbol<T>,
y: BoundSymbol<T>,
): BoundSymbol<T> =
exp(y * ln(x))
// sin(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { tan(x.value) }) { z ->
val c = cos(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { tan(x.value) }) { z ->
val c = cosh(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: BoundSymbol<T>): BoundSymbol<T> =
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }

View File

@ -1,266 +0,0 @@
package kscience.kmath.misc
import kscience.kmath.linear.Point
import kscience.kmath.operations.*
import kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/*
* Implementation of backward-mode automatic differentiation.
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
*/
/**
* Differentiable variable with value and derivative of differentiation ([deriv]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
public open class Variable<T : Any>(public val value: T)
/**
* Represents result of [deriv] call.
*
* @param T the non-nullable type of value.
* @param value the value of result.
* @property deriv The mapping of differentiated variables to their derivatives.
* @property context The field over [T].
*/
public class DerivationResult<T : Any>(
value: T,
public val deriv: Map<Variable<T>, T>,
public val context: Field<T>
) : Variable<T>(value) {
/**
* Returns derivative of [variable] or returns [Ring.zero] in [context].
*/
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
/**
* Computes the divergence.
*/
public fun div(): T = context { sum(deriv.values) }
/**
* Computes the gradient for variables in given order.
*/
public fun grad(vararg variables: Variable<T>): Point<T> {
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
return variables.map(::deriv).asBuffer()
}
}
/**
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
*
* The partial derivatives are placed in argument `d` variable
*
* Example:
* ```
* val x = Variable(2) // define variable(s) and their values
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
* assertEquals(17.0, y.x) // the value of result (y)
* assertEquals(9.0, x.d) // dy/dx
* ```
*
* @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to.
* @return the result of differentiation.
*/
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return (AutoDiffContext(this)) {
val result = body()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
DerivationResult(result.value, derivatives, this@deriv)
}
}
/**
* Represents field in context of which functions can be derived.
*/
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
public abstract val context: F
/**
* A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings.
*/
public abstract var Variable<T>.d: T
/**
* Performs update of derivative after the rest of the formula in the back-pass.
*
* For example, implementation of `sin` function is:
*
* ```
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
* }
* ```
*/
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
/**
*
*/
public abstract fun variable(value: T): Variable<T>
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
// Overloads for Double constants
override operator fun Number.plus(b: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d
}
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
}
/**
* Automatic Differentiation context class.
*/
@PublishedApi
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
override val zero: Variable<T> get() = Variable(context.zero)
override val one: Variable<T> get() = Variable(context.one)
/**
* A variable coupled with its derivative. For internal use only
*/
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
override fun variable(value: T): Variable<T> =
VariableWithDeriv(value, context.zero)
override var Variable<T>.d: T
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
@Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
@Suppress("UNCHECKED_CAST")
fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
// Basic math (+, -, *, /)
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
a.d += z.d
b.d += z.d
}
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
a.d += z.d * b.value
b.d += z.d * a.value
}
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value)
}
override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble()
}
}
// Extensions for differentiation of various basic mathematical functions
// x ^ 2
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
pow(x, y.toDouble())
// exp(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
exp(y * ln(x))
// sin(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
derive(variable { tan(x.value) }) { z ->
val c = cos(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
derive(variable { tan(x.value) }) { z ->
val c = cosh(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }

View File

@ -6,19 +6,21 @@ import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFails
class ExpressionFieldTest {
val x by symbol
@Test
fun testExpression() {
val context = FunctionalExpressionField(RealField)
val expression = context {
val x = variable("x", 2.0)
val x by binding()
x * x + 2 * x + one
}
assertEquals(expression("x" to 1.0), 4.0)
assertEquals(expression(), 9.0)
assertEquals(expression(x to 1.0), 4.0)
assertFails { expression()}
}
@Test
@ -26,33 +28,33 @@ class ExpressionFieldTest {
val context = FunctionalExpressionField(ComplexField)
val expression = context {
val x = variable("x", Complex(2.0, 0.0))
val x = bind(x)
x * x + 2 * x + one
}
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
assertEquals(expression(), Complex(9.0, 0.0))
assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0))
//assertEquals(expression(), Complex(9.0, 0.0))
}
@Test
fun separateContext() {
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
val x = variable("x")
val x by binding()
return x * x + 2 * x + one
}
val expression = FunctionalExpressionField(RealField).expression()
assertEquals(expression("x" to 1.0), 4.0)
assertEquals(expression(x to 1.0), 4.0)
}
@Test
fun valueExpression() {
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
val x = variable("x")
val x by binding()
x * x + 2 * x + one
}
val expression = FunctionalExpressionField(RealField).expressionBuilder()
assertEquals(expression("x" to 1.0), 4.0)
assertEquals(expression(x to 1.0), 4.0)
}
}

View File

@ -0,0 +1,277 @@
package kscience.kmath.expressions
import kscience.kmath.operations.RealField
import kscience.kmath.structures.asBuffer
import kotlin.math.PI
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class SimpleAutoDiffTest {
fun d(
vararg bindings: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
fun dx(
xBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>) -> BoundSymbol<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
fun dxy(
xBinding: Pair<Symbol, Double>,
yBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>, y: BoundSymbol<Double>) -> BoundSymbol<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
body(bind(xBinding.first), bind(yBinding.first))
}
fun diff(block: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>): SimpleAutoDiffExpression<Double, RealField> {
return SimpleAutoDiffExpression(RealField, block)
}
val x by symbol
val y by symbol
val z by symbol
@Test
fun testPlusX2() {
val y = d(x to 3.0) {
// diff w.r.t this x at 3
val x = bind(x)
x + x
}
assertEquals(6.0, y.value) // y = x + x = 6
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
}
@Test
fun testPlus() {
// two variables
val z = d(x to 2.0, y to 3.0) {
val x = bind(x)
val y = bind(y)
x + y
}
assertEquals(5.0, z.value) // z = x + y = 5
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
assertEquals(1.0, z.derivative(y)) // dz/dy = 1
}
@Test
fun testMinus() {
// two variables
val z = d(x to 7.0, y to 3.0) {
val x = bind(x)
val y = bind(y)
x - y
}
assertEquals(4.0, z.value) // z = x - y = 4
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
assertEquals(-1.0, z.derivative(y)) // dz/dy = -1
}
@Test
fun testMulX2() {
val y = dx(x to 3.0) { x ->
// diff w.r.t this x at 3
x * x
}
assertEquals(9.0, y.value) // y = x * x = 9
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqr() {
val y = dx(x to 3.0) { x -> sqr(x) }
assertEquals(9.0, y.value) // y = x ^ 2 = 9
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqrSqr() {
val y = dx(x to 2.0) { x -> sqr(sqr(x)) }
assertEquals(16.0, y.value) // y = x ^ 4 = 16
assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32
}
@Test
fun testX3() {
val y = dx(x to 2.0) { x ->
// diff w.r.t this x at 2
x * x * x
}
assertEquals(8.0, y.value) // y = x * x * x = 8
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12
}
@Test
fun testDiv() {
val z = dxy(x to 5.0, y to 2.0) { x, y ->
x / y
}
assertEquals(2.5, z.value) // z = x / y = 2.5
assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5
assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25
}
@Test
fun testPow3() {
val y = dx(x to 2.0) { x ->
// diff w.r.t this x at 2
pow(x, 3)
}
assertEquals(8.0, y.value) // y = x ^ 3 = 8
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12
}
@Test
fun testPowFull() {
val z = dxy(x to 2.0, y to 3.0) { x, y ->
pow(x, y)
}
assertApprox(8.0, z.value) // z = x ^ y = 8
assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12
assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x)
}
@Test
fun testFromPaper() {
val y = dx(x to 3.0) { x -> 2 * x + x * x * x }
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29
}
@Test
fun testInnerVariable() {
val y = dx(x to 1.0) { x ->
const(1.0) * x
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testLongChain() {
val n = 10_000
val y = dx(x to 1.0) { x ->
var res = const(1.0)
for (i in 1..n) res *= x
res
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testExample() {
val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 }
assertEquals(17.0, y.value) // the value of result (y)
assertEquals(9.0, y.derivative(x)) // dy/dx
}
@Test
fun testSqrt() {
val y = dx(x to 16.0) { x -> sqrt(x) }
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
}
@Test
fun testSin() {
val y = dx(x to PI / 6.0) { x -> sin(x) }
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
}
@Test
fun testCos() {
val y = dx(x to PI / 6) { x -> cos(x) }
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5
}
@Test
fun testTan() {
val y = dx(x to PI / 6) { x -> tan(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3
}
@Test
fun testAsin() {
val y = dx(x to PI / 6) { x -> asin(x) }
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2)
}
@Test
fun testAcos() {
val y = dx(x to PI / 6) { x -> acos(x) }
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtan() {
val y = dx(x to PI / 6) { x -> atan(x) }
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2)
}
@Test
fun testSinh() {
val y = dx(x to 0.0) { x -> sinh(x) }
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0)
}
@Test
fun testCosh() {
val y = dx(x to 0.0) { x -> cosh(x) }
assertApprox(1.0, y.value) //y = cosh(0)
assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0)
}
@Test
fun testTanh() {
val y = dx(x to PI / 6) { x -> tanh(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
}
@Test
fun testAsinh() {
val y = dx(x to PI / 6) { x -> asinh(x) }
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36)
}
@Test
fun testAcosh() {
val y = dx(x to PI / 6) { x -> acosh(x) }
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtanh() {
val y = dx(x to PI / 6) { x -> atanh(x) }
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36)
}
@Test
fun testDivGrad() {
val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y }
assertEquals(6.0, res.div())
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
}
private fun assertApprox(a: Double, b: Double) {
if ((a - b) > 1e-10) assertEquals(a, b)
}
}

View File

@ -1,261 +0,0 @@
package kscience.kmath.misc
import kscience.kmath.operations.RealField
import kscience.kmath.structures.asBuffer
import kotlin.math.PI
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class AutoDiffTest {
inline fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
RealField.deriv(body)
@Test
fun testPlusX2() {
val x = Variable(3.0) // diff w.r.t this x at 3
val y = deriv { x + x }
assertEquals(6.0, y.value) // y = x + x = 6
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
}
@Test
fun testPlus() {
// two variables
val x = Variable(2.0)
val y = Variable(3.0)
val z = deriv { x + y }
assertEquals(5.0, z.value) // z = x + y = 5
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
}
@Test
fun testMinus() {
// two variables
val x = Variable(7.0)
val y = Variable(3.0)
val z = deriv { x - y }
assertEquals(4.0, z.value) // z = x - y = 4
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
}
@Test
fun testMulX2() {
val x = Variable(3.0) // diff w.r.t this x at 3
val y = deriv { x * x }
assertEquals(9.0, y.value) // y = x * x = 9
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqr() {
val x = Variable(3.0)
val y = deriv { sqr(x) }
assertEquals(9.0, y.value) // y = x ^ 2 = 9
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqrSqr() {
val x = Variable(2.0)
val y = deriv { sqr(sqr(x)) }
assertEquals(16.0, y.value) // y = x ^ 4 = 16
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
}
@Test
fun testX3() {
val x = Variable(2.0) // diff w.r.t this x at 2
val y = deriv { x * x * x }
assertEquals(8.0, y.value) // y = x * x * x = 8
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
}
@Test
fun testDiv() {
val x = Variable(5.0)
val y = Variable(2.0)
val z = deriv { x / y }
assertEquals(2.5, z.value) // z = x / y = 2.5
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
}
@Test
fun testPow3() {
val x = Variable(2.0) // diff w.r.t this x at 2
val y = deriv { pow(x, 3) }
assertEquals(8.0, y.value) // y = x ^ 3 = 8
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
}
@Test
fun testPowFull() {
val x = Variable(2.0)
val y = Variable(3.0)
val z = deriv { pow(x, y) }
assertApprox(8.0, z.value) // z = x ^ y = 8
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
}
@Test
fun testFromPaper() {
val x = Variable(3.0)
val y = deriv { 2 * x + x * x * x }
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
}
@Test
fun testInnerVariable() {
val x = Variable(1.0)
val y = deriv {
Variable(1.0) * x
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testLongChain() {
val n = 10_000
val x = Variable(1.0)
val y = deriv {
var res = Variable(1.0)
for (i in 1..n) res *= x
res
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testExample() {
val x = Variable(2.0)
val y = deriv { sqr(x) + 5 * x + 3 }
assertEquals(17.0, y.value) // the value of result (y)
assertEquals(9.0, y.deriv(x)) // dy/dx
}
@Test
fun testSqrt() {
val x = Variable(16.0)
val y = deriv { sqrt(x) }
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
}
@Test
fun testSin() {
val x = Variable(PI / 6.0)
val y = deriv { sin(x) }
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
}
@Test
fun testCos() {
val x = Variable(PI / 6)
val y = deriv { cos(x) }
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5
}
@Test
fun testTan() {
val x = Variable(PI / 6)
val y = deriv { tan(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3
}
@Test
fun testAsin() {
val x = Variable(PI / 6)
val y = deriv { asin(x) }
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2)
}
@Test
fun testAcos() {
val x = Variable(PI / 6)
val y = deriv { acos(x) }
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtan() {
val x = Variable(PI / 6)
val y = deriv { atan(x) }
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2)
}
@Test
fun testSinh() {
val x = Variable(0.0)
val y = deriv { sinh(x) }
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0)
}
@Test
fun testCosh() {
val x = Variable(0.0)
val y = deriv { cosh(x) }
assertApprox(1.0, y.value) //y = cosh(0)
assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0)
}
@Test
fun testTanh() {
val x = Variable(PI / 6)
val y = deriv { tanh(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2
}
@Test
fun testAsinh() {
val x = Variable(PI / 6)
val y = deriv { asinh(x) }
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36)
}
@Test
fun testAcosh() {
val x = Variable(PI / 6)
val y = deriv { acosh(x) }
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtanh() {
val x = Variable(PI / 6.0)
val y = deriv { atanh(x) }
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36)
}
@Test
fun testDivGrad() {
val x = Variable(1.0)
val y = Variable(2.0)
val res = deriv { x * x + y * y }
assertEquals(6.0, res.div())
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
}
private fun assertApprox(a: Double, b: Double) {
if ((a - b) > 1e-10) assertEquals(a, b)
}
}