forked from kscience/kmath
New Expression API
This commit is contained in:
parent
e44423192d
commit
707ad21f77
14
README.md
14
README.md
@ -53,9 +53,7 @@ can be used for a wide variety of purposes from high performance calculations to
|
||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||
to submit a feature request if you want something to be done first.
|
||||
|
||||
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
|
||||
|
||||
|
||||
## Planned features
|
||||
|
||||
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
||||
@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-ejml](kmath-ejml)
|
||||
>
|
||||
>
|
||||
> **Maturity**: EXPERIMENTAL
|
||||
<hr/>
|
||||
|
||||
* ### [kmath-for-real](kmath-for-real)
|
||||
>
|
||||
>
|
||||
@ -178,8 +182,8 @@ repositories{
|
||||
}
|
||||
|
||||
dependencies{
|
||||
api("kscience.kmath:kmath-core:0.2.0-dev-1")
|
||||
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-1") for jvm-specific version
|
||||
api("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -24,4 +24,8 @@ subprojects {
|
||||
|
||||
readme {
|
||||
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||
}
|
||||
|
||||
apiValidation{
|
||||
validationDisabled = true
|
||||
}
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -107,4 +107,4 @@ with the same artifact names.
|
||||
|
||||
## Contributing
|
||||
|
||||
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
|
||||
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
@ -14,8 +14,8 @@ import kotlin.contracts.contract
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
@ -27,7 +27,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
||||
error("Numeric nodes are not supported by $this")
|
||||
}
|
||||
|
||||
override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -37,7 +37,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||
mstAlgebra: E,
|
||||
block: E.() -> MST
|
||||
block: E.() -> MST,
|
||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||
|
||||
/**
|
||||
@ -116,7 +116,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||
block: MstExtendedField.() -> MST
|
||||
block: MstExtendedField.() -> MST,
|
||||
): MstExpression<T> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInExtendedField(block)
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
package kscience.kmath.asm.internal
|
||||
|
||||
import kscience.kmath.expressions.StringSymbol
|
||||
|
||||
/**
|
||||
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
|
||||
*
|
||||
@ -9,4 +11,4 @@ package kscience.kmath.asm.internal
|
||||
*/
|
||||
@JvmOverloads
|
||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
||||
this[key] ?: default ?: error("Parameter not found: $key")
|
||||
this[StringSymbol(key.toString())] ?: default ?: error("Parameter not found: $key")
|
||||
|
@ -1,6 +1,5 @@
|
||||
package kscience.kmath.asm
|
||||
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.ast.mstInField
|
||||
import kscience.kmath.ast.mstInRing
|
||||
import kscience.kmath.ast.mstInSpace
|
||||
@ -11,6 +10,7 @@ import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class TestAsmAlgebras {
|
||||
|
||||
@Test
|
||||
fun space() {
|
||||
val res1 = ByteRing.mstInSpace {
|
||||
|
@ -1,48 +1,57 @@
|
||||
package kscience.kmath.commons.expressions
|
||||
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.expressions.ExpressionAlgebra
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.operations.ExtendedField
|
||||
import kscience.kmath.operations.Field
|
||||
import kscience.kmath.operations.invoke
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
|
||||
/**
|
||||
* A field over commons-math [DerivativeStructure].
|
||||
*
|
||||
* @property order The derivation order.
|
||||
* @property parameters The map of free parameters.
|
||||
* @property bindings The map of bindings values. All bindings are considered free parameters
|
||||
*/
|
||||
public class DerivativeStructureField(
|
||||
public val order: Int,
|
||||
public val parameters: Map<String, Double>,
|
||||
) : ExtendedField<DerivativeStructure> {
|
||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
|
||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
|
||||
private val bindings: Map<Symbol, Double>
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
|
||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
|
||||
|
||||
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
|
||||
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
||||
/**
|
||||
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||
*/
|
||||
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||
override val identity: Any = symbol.identity
|
||||
}
|
||||
|
||||
public val variable: ReadOnlyProperty<Any?, DerivativeStructure> = ReadOnlyProperty { _, property ->
|
||||
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
|
||||
/**
|
||||
* Identity-based symbol bindings map
|
||||
*/
|
||||
private val variables: Map<Any?, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
|
||||
key.identity to DerivativeStructureSymbol(key, value)
|
||||
}
|
||||
|
||||
public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
|
||||
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value)
|
||||
|
||||
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
|
||||
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||
|
||||
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
||||
return deriv(mapOf(parName to order))
|
||||
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||
|
||||
public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||
|
||||
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
||||
return derivative(mapOf(parameter to order))
|
||||
}
|
||||
|
||||
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
|
||||
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
|
||||
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
public fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
|
||||
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
|
||||
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||
|
||||
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||
@ -85,48 +94,16 @@ public class DerivativeStructureField(
|
||||
/**
|
||||
* A constructs that creates a derivative structure with required order on-demand
|
||||
*/
|
||||
public class DiffExpression(
|
||||
public class DerivativeStructureExpression(
|
||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||
) : Expression<Double> {
|
||||
public override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
|
||||
0,
|
||||
arguments
|
||||
).function().value
|
||||
) : DifferentiableExpression<Double> {
|
||||
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||
DerivativeStructureField(0, arguments).function().value
|
||||
|
||||
/**
|
||||
* Get the derivative expression with given orders
|
||||
* TODO make result [DiffExpression]
|
||||
*/
|
||||
public fun derivative(orders: Map<String, Int>): Expression<Double> = Expression { arguments ->
|
||||
(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
|
||||
public override fun derivative(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
||||
}
|
||||
|
||||
//TODO add gradient and maybe other vector operators
|
||||
}
|
||||
|
||||
public fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
|
||||
public fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
|
||||
|
||||
/**
|
||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
||||
*/
|
||||
public object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
||||
public override val zero: DiffExpression = DiffExpression { 0.0.const() }
|
||||
public override val one: DiffExpression = DiffExpression { 1.0.const() }
|
||||
|
||||
public override fun variable(name: String, default: Double?): DiffExpression =
|
||||
DiffExpression { variable(name, default?.const()) }
|
||||
|
||||
public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
|
||||
|
||||
public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
|
||||
DiffExpression { a.function(this) + b.function(this) }
|
||||
|
||||
public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
|
||||
|
||||
public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
|
||||
DiffExpression { a.function(this) * b.function(this) }
|
||||
|
||||
public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
|
||||
DiffExpression { a.function(this) / b.function(this) }
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
package kscience.kmath.commons.expressions
|
||||
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.expressions.*
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.test.Test
|
||||
@ -8,33 +8,37 @@ import kotlin.test.assertEquals
|
||||
|
||||
internal inline fun <R> diff(
|
||||
order: Int,
|
||||
vararg parameters: Pair<String, Double>,
|
||||
block: DerivativeStructureField.() -> R
|
||||
vararg parameters: Pair<Symbol, Double>,
|
||||
block: DerivativeStructureField.() -> R,
|
||||
): R {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||
}
|
||||
|
||||
internal class AutoDiffTest {
|
||||
private val x by symbol
|
||||
private val y by symbol
|
||||
|
||||
@Test
|
||||
fun derivativeStructureFieldTest() {
|
||||
val res: Double = diff(3, "x" to 1.0, "y" to 1.0) {
|
||||
val x by variable
|
||||
val y = variable("y")
|
||||
val res: Double = diff(3, x to 1.0, y to 1.0) {
|
||||
val x = bind(x)//by binding()
|
||||
val y = symbol("y")
|
||||
val z = x * (-sin(x * y) + y)
|
||||
z.deriv("x")
|
||||
z.derivative(x)
|
||||
}
|
||||
println(res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun autoDifTest() {
|
||||
val f = DiffExpression {
|
||||
val x by variable
|
||||
val y by variable
|
||||
val f = DerivativeStructureExpression {
|
||||
val x by binding()
|
||||
val y by binding()
|
||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||
}
|
||||
|
||||
assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
|
||||
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
|
||||
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||
}
|
||||
}
|
@ -12,7 +12,7 @@ The core features of KMath:
|
||||
|
||||
> #### Artifact:
|
||||
>
|
||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`.
|
||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
|
||||
>
|
||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
||||
>
|
||||
@ -22,25 +22,28 @@ The core features of KMath:
|
||||
>
|
||||
> ```gradle
|
||||
> repositories {
|
||||
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1'
|
||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
|
||||
> }
|
||||
> ```
|
||||
> **Gradle Kotlin DSL:**
|
||||
>
|
||||
> ```kotlin
|
||||
> repositories {
|
||||
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
> }
|
||||
>
|
||||
> dependencies {
|
||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-1")
|
||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||
> }
|
||||
> ```
|
||||
|
@ -41,6 +41,6 @@ readme {
|
||||
feature(
|
||||
id = "autodif",
|
||||
description = "Automatic differentiation",
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt"
|
||||
ref = "src/commonMain/kotlin/kscience/kmath/misc/SimpleAutoDiff.kt"
|
||||
)
|
||||
}
|
@ -1,6 +1,26 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kotlin.jvm.JvmName
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
|
||||
/**
|
||||
* A marker interface for a symbol. A symbol mus have an identity
|
||||
*/
|
||||
public interface Symbol {
|
||||
/**
|
||||
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
|
||||
* By default uses object identity
|
||||
*/
|
||||
public val identity: Any get() = this
|
||||
}
|
||||
|
||||
/**
|
||||
* A [Symbol] with a [String] identity
|
||||
*/
|
||||
public inline class StringSymbol(override val identity: String) : Symbol {
|
||||
override fun toString(): String = identity
|
||||
}
|
||||
|
||||
/**
|
||||
* An elementary function that could be invoked on a map of arguments
|
||||
@ -12,30 +32,81 @@ public fun interface Expression<T> {
|
||||
* @param arguments the map of arguments.
|
||||
* @return the value.
|
||||
*/
|
||||
public operator fun invoke(arguments: Map<String, T>): T
|
||||
public operator fun invoke(arguments: Map<Symbol, T>): T
|
||||
|
||||
public companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* Invlode an expression without parameters
|
||||
*/
|
||||
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||
//This method exists to avoid resolution ambiguity of vararg methods
|
||||
|
||||
/**
|
||||
* Calls this expression from arguments.
|
||||
*
|
||||
* @param pairs the pair of arguments' names to values.
|
||||
* @return the value.
|
||||
*/
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
||||
@JvmName("callBySymbol")
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
||||
|
||||
@JvmName("callByString")
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||
|
||||
/**
|
||||
* And object that could be differentiated
|
||||
*/
|
||||
public interface Differentiable<T> {
|
||||
public fun derivative(orders: Map<Symbol, Int>): T
|
||||
}
|
||||
|
||||
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||
derivative(mapOf(*orders))
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||
|
||||
/**
|
||||
* A context for expression construction
|
||||
*
|
||||
* @param T type of the constants for the expression
|
||||
* @param E type of the actual expression state
|
||||
*/
|
||||
public interface ExpressionAlgebra<T, E> : Algebra<E> {
|
||||
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
||||
|
||||
/**
|
||||
* Introduce a variable into expression context
|
||||
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
|
||||
*/
|
||||
public fun variable(name: String, default: T? = null): E
|
||||
public fun bindOrNull(symbol: Symbol): E?
|
||||
|
||||
/**
|
||||
* Bind a string to a context using [StringSymbol]
|
||||
*/
|
||||
override fun symbol(value: String): E = bind(StringSymbol(value))
|
||||
|
||||
/**
|
||||
* A constant expression which does not depend on arguments
|
||||
*/
|
||||
public fun const(value: T): E
|
||||
}
|
||||
|
||||
/**
|
||||
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
||||
*/
|
||||
public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
|
||||
bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
|
||||
|
||||
public val symbol: ReadOnlyProperty<Any?, Symbol> = ReadOnlyProperty { _, property ->
|
||||
StringSymbol(property.name)
|
||||
}
|
||||
|
||||
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> =
|
||||
ReadOnlyProperty { _, property ->
|
||||
bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
|
||||
}
|
@ -2,39 +2,6 @@ package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.operations.*
|
||||
|
||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
||||
Expression<T> {
|
||||
override operator fun invoke(arguments: Map<String, T>): T =
|
||||
context.unaryOperation(name, expr.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class FunctionalBinaryOperation<T>(
|
||||
val context: Algebra<T>,
|
||||
val name: String,
|
||||
val first: Expression<T>,
|
||||
val second: Expression<T>
|
||||
) : Expression<T> {
|
||||
override operator fun invoke(arguments: Map<String, T>): T =
|
||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||
}
|
||||
|
||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||
override operator fun invoke(arguments: Map<String, T>): T =
|
||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||
}
|
||||
|
||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||
override operator fun invoke(arguments: Map<String, T>): T = value
|
||||
}
|
||||
|
||||
internal class FunctionalConstProductExpression<T>(
|
||||
val context: Space<T>,
|
||||
private val expr: Expression<T>,
|
||||
val const: Number
|
||||
) : Expression<T> {
|
||||
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||
}
|
||||
|
||||
/**
|
||||
* A context class for [Expression] construction.
|
||||
*
|
||||
@ -45,24 +12,32 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
|
||||
/**
|
||||
* Builds an Expression of constant expression which does not depend on arguments.
|
||||
*/
|
||||
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
||||
public override fun const(value: T): Expression<T> = Expression { value }
|
||||
|
||||
/**
|
||||
* Builds an Expression to access a variable.
|
||||
*/
|
||||
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
||||
public override fun bindOrNull(symbol: Symbol): Expression<T>? = Expression { arguments ->
|
||||
arguments[symbol] ?: error("Argument not found: $symbol")
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||
*/
|
||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||
public override fun binaryOperation(
|
||||
operation: String,
|
||||
left: Expression<T>,
|
||||
right: Expression<T>,
|
||||
): Expression<T> = Expression { arguments ->
|
||||
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||
*/
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||
FunctionalUnaryOperation(algebra, operation, arg)
|
||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
|
||||
algebra.unaryOperation(operation, arg.invoke(arguments))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -81,8 +56,9 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
/**
|
||||
* Builds an Expression of multiplication of expression by number.
|
||||
*/
|
||||
public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
||||
FunctionalConstProductExpression(algebra, a, k)
|
||||
public override fun multiply(a: Expression<T>, k: Number): Expression<T> = Expression { arguments ->
|
||||
algebra.multiply(a.invoke(arguments), k)
|
||||
}
|
||||
|
||||
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||
@ -118,8 +94,8 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
FunctionalExpressionRing<T, A>(algebra),
|
||||
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
||||
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
|
||||
where A : Field<T>, A : NumericAlgebra<T> {
|
||||
/**
|
||||
* Builds an Expression of division an expression by another one.
|
||||
*/
|
||||
|
@ -0,0 +1,329 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
/*
|
||||
* Implementation of backward-mode automatic differentiation.
|
||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||
*/
|
||||
|
||||
|
||||
/**
|
||||
* A [Symbol] with bound value
|
||||
*/
|
||||
public interface BoundSymbol<out T> : Symbol {
|
||||
public val value: T
|
||||
}
|
||||
|
||||
/**
|
||||
* Bind a [Symbol] to a [value] and produce [BoundSymbol]
|
||||
*/
|
||||
public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
|
||||
override val identity = this@bind.identity
|
||||
override val value: T = value
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents result of [withAutoDiff] call.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
* @param value the value of result.
|
||||
* @property withAutoDiff The mapping of differentiated variables to their derivatives.
|
||||
* @property context The field over [T].
|
||||
*/
|
||||
public class DerivationResult<T : Any>(
|
||||
override val value: T,
|
||||
private val derivativeValues: Map<Any, T>,
|
||||
public val context: Field<T>,
|
||||
) : BoundSymbol<T> {
|
||||
/**
|
||||
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
||||
*/
|
||||
public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
|
||||
|
||||
/**
|
||||
* Computes the divergence.
|
||||
*/
|
||||
public fun div(): T = context { sum(derivativeValues.values) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the gradient for variables in given order.
|
||||
*/
|
||||
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
|
||||
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||
return variables.map(::derivative).asBuffer()
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||
*
|
||||
* The partial derivatives are placed in argument `d` variable
|
||||
*
|
||||
* Example:
|
||||
* ```
|
||||
* val x by symbol // define variable(s) and their values
|
||||
* val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||
* assertEquals(17.0, y.x) // the value of result (y)
|
||||
* assertEquals(9.0, x.d) // dy/dx
|
||||
* ```
|
||||
*
|
||||
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||
* @return the result of differentiation.
|
||||
*/
|
||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||
bindings: Collection<BoundSymbol<T>>,
|
||||
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||
): DerivationResult<T> {
|
||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||
|
||||
return AutoDiffContext(this, bindings).derivate(body)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||
vararg bindings: Pair<Symbol, T>,
|
||||
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||
): DerivationResult<T> = withAutoDiff(bindings.map { it.first.bind(it.second) }, body)
|
||||
|
||||
/**
|
||||
* Represents field in context of which functions can be derived.
|
||||
*/
|
||||
public abstract class AutoDiffField<T : Any, F : Field<T>>
|
||||
: Field<BoundSymbol<T>>, ExpressionAlgebra<T, BoundSymbol<T>> {
|
||||
|
||||
public abstract val context: F
|
||||
|
||||
/**
|
||||
* A variable accessing inner state of derivatives.
|
||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||
*/
|
||||
public abstract var BoundSymbol<T>.d: T
|
||||
|
||||
/**
|
||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||
*
|
||||
* For example, implementation of `sin` function is:
|
||||
*
|
||||
* ```
|
||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||
|
||||
public inline fun const(block: F.() -> T): BoundSymbol<T> = const(context.block())
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
override operator fun Number.plus(b: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override operator fun BoundSymbol<T>.plus(b: Number): BoundSymbol<T> = b.plus(this)
|
||||
|
||||
override operator fun Number.minus(b: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||
|
||||
override operator fun BoundSymbol<T>.minus(b: Number): BoundSymbol<T> =
|
||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||
}
|
||||
|
||||
/**
|
||||
* Automatic Differentiation context class.
|
||||
*/
|
||||
private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||
override val context: F,
|
||||
bindings: Collection<BoundSymbol<T>>,
|
||||
) : AutoDiffField<T, F>() {
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||
private var sp: Int = 0
|
||||
private val derivatives: MutableMap<Any, T> = hashMapOf()
|
||||
override val zero: BoundSymbol<T> get() = const(context.zero)
|
||||
override val one: BoundSymbol<T> get() = const(context.one)
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
|
||||
* with respect to this variable.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
* @property value The value of this variable.
|
||||
*/
|
||||
private class AutoDiffVariableWithDeriv<T : Any>(override val value: T, var d: T) : BoundSymbol<T>
|
||||
|
||||
private val bindings: Map<Any, BoundSymbol<T>> = bindings.associateBy { it.identity }
|
||||
|
||||
override fun bindOrNull(symbol: Symbol): BoundSymbol<T>? = bindings[symbol.identity]
|
||||
|
||||
override fun const(value: T): BoundSymbol<T> = AutoDiffVariableWithDeriv(value, context.zero)
|
||||
|
||||
override var BoundSymbol<T>.d: T
|
||||
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero
|
||||
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
stack[sp++] = value
|
||||
return value
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as F.(Any?) -> Unit
|
||||
context.block(value)
|
||||
}
|
||||
}
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
override fun add(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { a.value + b.value }) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override fun multiply(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { a.value * b.value }) { z ->
|
||||
a.d += z.d * b.value
|
||||
b.d += z.d * a.value
|
||||
}
|
||||
|
||||
override fun divide(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { a.value / b.value }) { z ->
|
||||
a.d += z.d / b.value
|
||||
b.d -= z.d * a.value / (b.value * b.value)
|
||||
}
|
||||
|
||||
override fun multiply(a: BoundSymbol<T>, k: Number): BoundSymbol<T> =
|
||||
derive(const { k.toDouble() * a.value }) { z ->
|
||||
a.d += z.d * k.toDouble()
|
||||
}
|
||||
|
||||
inline fun derivate(function: AutoDiffField<T, F>.() -> BoundSymbol<T>): DerivationResult<T> {
|
||||
val result = function()
|
||||
result.d = context.one // computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
return DerivationResult(result.value, derivatives, context)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A constructs that creates a derivative structure with required order on-demand
|
||||
*/
|
||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||
public val field: F,
|
||||
public val function: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||
) : DifferentiableExpression<T> {
|
||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||
val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
return AutoDiffContext(field, bindings).function().value
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the derivative expression with given orders
|
||||
*/
|
||||
public override fun derivative(orders: Map<Symbol, Int>): Expression<T> {
|
||||
val dSymbol = orders.entries.singleOrNull { it.value == 1 }
|
||||
?: error("SimpleAutoDiff supports only first order derivatives")
|
||||
return Expression { arguments ->
|
||||
val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
val derivationResult = AutoDiffContext(field, bindings).derivate(function)
|
||||
derivationResult.derivative(dSymbol.key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||
|
||||
// x ^ 1/2
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||
|
||||
// x ^ y (const)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||
x: BoundSymbol<T>,
|
||||
y: Double,
|
||||
): BoundSymbol<T> =
|
||||
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||
x: BoundSymbol<T>,
|
||||
y: Int,
|
||||
): BoundSymbol<T> =
|
||||
pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||
|
||||
// ln(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||
|
||||
// x ^ y (any)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||
x: BoundSymbol<T>,
|
||||
y: BoundSymbol<T>,
|
||||
): BoundSymbol<T> =
|
||||
exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||
|
||||
// cos(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { tan(x.value) }) { z ->
|
||||
val c = cos(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { tan(x.value) }) { z ->
|
||||
val c = cosh(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||
|
@ -1,266 +0,0 @@
|
||||
package kscience.kmath.misc
|
||||
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
/*
|
||||
* Implementation of backward-mode automatic differentiation.
|
||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||
*/
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
||||
* with respect to this variable.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
* @property value The value of this variable.
|
||||
*/
|
||||
public open class Variable<T : Any>(public val value: T)
|
||||
|
||||
/**
|
||||
* Represents result of [deriv] call.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
* @param value the value of result.
|
||||
* @property deriv The mapping of differentiated variables to their derivatives.
|
||||
* @property context The field over [T].
|
||||
*/
|
||||
public class DerivationResult<T : Any>(
|
||||
value: T,
|
||||
public val deriv: Map<Variable<T>, T>,
|
||||
public val context: Field<T>
|
||||
) : Variable<T>(value) {
|
||||
/**
|
||||
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
||||
*/
|
||||
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
||||
|
||||
/**
|
||||
* Computes the divergence.
|
||||
*/
|
||||
public fun div(): T = context { sum(deriv.values) }
|
||||
|
||||
/**
|
||||
* Computes the gradient for variables in given order.
|
||||
*/
|
||||
public fun grad(vararg variables: Variable<T>): Point<T> {
|
||||
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||
return variables.map(::deriv).asBuffer()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||
*
|
||||
* The partial derivatives are placed in argument `d` variable
|
||||
*
|
||||
* Example:
|
||||
* ```
|
||||
* val x = Variable(2) // define variable(s) and their values
|
||||
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||
* assertEquals(17.0, y.x) // the value of result (y)
|
||||
* assertEquals(9.0, x.d) // dy/dx
|
||||
* ```
|
||||
*
|
||||
* @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to.
|
||||
* @return the result of differentiation.
|
||||
*/
|
||||
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
|
||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||
|
||||
return (AutoDiffContext(this)) {
|
||||
val result = body()
|
||||
result.d = context.one // computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
DerivationResult(result.value, derivatives, this@deriv)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents field in context of which functions can be derived.
|
||||
*/
|
||||
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
||||
public abstract val context: F
|
||||
|
||||
/**
|
||||
* A variable accessing inner state of derivatives.
|
||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||
*/
|
||||
public abstract var Variable<T>.d: T
|
||||
|
||||
/**
|
||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||
*
|
||||
* For example, implementation of `sin` function is:
|
||||
*
|
||||
* ```
|
||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
public abstract fun variable(value: T): Variable<T>
|
||||
|
||||
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
override operator fun Number.plus(b: Variable<T>): Variable<T> =
|
||||
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
||||
|
||||
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
||||
derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||
|
||||
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
||||
derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||
}
|
||||
|
||||
/**
|
||||
* Automatic Differentiation context class.
|
||||
*/
|
||||
@PublishedApi
|
||||
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||
private var sp: Int = 0
|
||||
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
|
||||
override val zero: Variable<T> get() = Variable(context.zero)
|
||||
override val one: Variable<T> get() = Variable(context.one)
|
||||
|
||||
/**
|
||||
* A variable coupled with its derivative. For internal use only
|
||||
*/
|
||||
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
||||
|
||||
override fun variable(value: T): Variable<T> =
|
||||
VariableWithDeriv(value, context.zero)
|
||||
|
||||
override var Variable<T>.d: T
|
||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
||||
set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||
// save block to stack for backward pass
|
||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||
stack[sp++] = block
|
||||
stack[sp++] = value
|
||||
return value
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
val value = stack[--sp]
|
||||
val block = stack[--sp] as F.(Any?) -> Unit
|
||||
context.block(value)
|
||||
}
|
||||
}
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
|
||||
a.d += z.d * b.value
|
||||
b.d += z.d * a.value
|
||||
}
|
||||
|
||||
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
|
||||
a.d += z.d / b.value
|
||||
b.d -= z.d * a.value / (b.value * b.value)
|
||||
}
|
||||
|
||||
override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
|
||||
a.d += z.d * k.toDouble()
|
||||
}
|
||||
}
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
||||
// x ^ 2
|
||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
||||
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||
|
||||
// x ^ 1/2
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||
|
||||
// x ^ y (const)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
||||
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
|
||||
pow(x, y.toDouble())
|
||||
|
||||
// exp(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
||||
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||
|
||||
// ln(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
|
||||
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||
|
||||
// x ^ y (any)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
||||
exp(y * ln(x))
|
||||
|
||||
// sin(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||
|
||||
// cos(x)
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
||||
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
|
||||
derive(variable { tan(x.value) }) { z ->
|
||||
val c = cos(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
|
||||
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
|
||||
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
|
||||
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { tan(x.value) }) { z ->
|
||||
val c = cosh(x.value)
|
||||
x.d += z.d / (c * c)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||
|
||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
|
||||
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||
|
@ -6,19 +6,21 @@ import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFails
|
||||
|
||||
class ExpressionFieldTest {
|
||||
val x by symbol
|
||||
@Test
|
||||
fun testExpression() {
|
||||
val context = FunctionalExpressionField(RealField)
|
||||
|
||||
val expression = context {
|
||||
val x = variable("x", 2.0)
|
||||
val x by binding()
|
||||
x * x + 2 * x + one
|
||||
}
|
||||
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
assertEquals(expression(), 9.0)
|
||||
assertEquals(expression(x to 1.0), 4.0)
|
||||
assertFails { expression()}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -26,33 +28,33 @@ class ExpressionFieldTest {
|
||||
val context = FunctionalExpressionField(ComplexField)
|
||||
|
||||
val expression = context {
|
||||
val x = variable("x", Complex(2.0, 0.0))
|
||||
val x = bind(x)
|
||||
x * x + 2 * x + one
|
||||
}
|
||||
|
||||
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
||||
assertEquals(expression(), Complex(9.0, 0.0))
|
||||
assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
||||
//assertEquals(expression(), Complex(9.0, 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun separateContext() {
|
||||
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
||||
val x = variable("x")
|
||||
val x by binding()
|
||||
return x * x + 2 * x + one
|
||||
}
|
||||
|
||||
val expression = FunctionalExpressionField(RealField).expression()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
assertEquals(expression(x to 1.0), 4.0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun valueExpression() {
|
||||
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
||||
val x = variable("x")
|
||||
val x by binding()
|
||||
x * x + 2 * x + one
|
||||
}
|
||||
|
||||
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||
assertEquals(expression("x" to 1.0), 4.0)
|
||||
assertEquals(expression(x to 1.0), 4.0)
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,277 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class SimpleAutoDiffTest {
|
||||
fun d(
|
||||
vararg bindings: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>,
|
||||
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
|
||||
|
||||
fun dx(
|
||||
xBinding: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>) -> BoundSymbol<Double>,
|
||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||
|
||||
fun dxy(
|
||||
xBinding: Pair<Symbol, Double>,
|
||||
yBinding: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>, y: BoundSymbol<Double>) -> BoundSymbol<Double>,
|
||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
|
||||
body(bind(xBinding.first), bind(yBinding.first))
|
||||
}
|
||||
|
||||
fun diff(block: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||
return SimpleAutoDiffExpression(RealField, block)
|
||||
}
|
||||
|
||||
val x by symbol
|
||||
val y by symbol
|
||||
val z by symbol
|
||||
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
val y = d(x to 3.0) {
|
||||
// diff w.r.t this x at 3
|
||||
val x = bind(x)
|
||||
x + x
|
||||
}
|
||||
assertEquals(6.0, y.value) // y = x + x = 6
|
||||
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPlus() {
|
||||
// two variables
|
||||
val z = d(x to 2.0, y to 3.0) {
|
||||
val x = bind(x)
|
||||
val y = bind(y)
|
||||
x + y
|
||||
}
|
||||
assertEquals(5.0, z.value) // z = x + y = 5
|
||||
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||
assertEquals(1.0, z.derivative(y)) // dz/dy = 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
// two variables
|
||||
val z = d(x to 7.0, y to 3.0) {
|
||||
val x = bind(x)
|
||||
val y = bind(y)
|
||||
|
||||
x - y
|
||||
}
|
||||
assertEquals(4.0, z.value) // z = x - y = 4
|
||||
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||
assertEquals(-1.0, z.derivative(y)) // dz/dy = -1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMulX2() {
|
||||
val y = dx(x to 3.0) { x ->
|
||||
// diff w.r.t this x at 3
|
||||
x * x
|
||||
}
|
||||
assertEquals(9.0, y.value) // y = x * x = 9
|
||||
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqr() {
|
||||
val y = dx(x to 3.0) { x -> sqr(x) }
|
||||
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqrSqr() {
|
||||
val y = dx(x to 2.0) { x -> sqr(sqr(x)) }
|
||||
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||
assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testX3() {
|
||||
val y = dx(x to 2.0) { x ->
|
||||
// diff w.r.t this x at 2
|
||||
x * x * x
|
||||
}
|
||||
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDiv() {
|
||||
val z = dxy(x to 5.0, y to 2.0) { x, y ->
|
||||
x / y
|
||||
}
|
||||
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||
assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5
|
||||
assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPow3() {
|
||||
val y = dx(x to 2.0) { x ->
|
||||
// diff w.r.t this x at 2
|
||||
pow(x, 3)
|
||||
}
|
||||
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPowFull() {
|
||||
val z = dxy(x to 2.0, y to 3.0) { x, y ->
|
||||
pow(x, y)
|
||||
}
|
||||
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||
assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||
assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFromPaper() {
|
||||
val y = dx(x to 3.0) { x -> 2 * x + x * x * x }
|
||||
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||
assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testInnerVariable() {
|
||||
val y = dx(x to 1.0) { x ->
|
||||
const(1.0) * x
|
||||
}
|
||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||
assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLongChain() {
|
||||
val n = 10_000
|
||||
val y = dx(x to 1.0) { x ->
|
||||
var res = const(1.0)
|
||||
for (i in 1..n) res *= x
|
||||
res
|
||||
}
|
||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||
assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testExample() {
|
||||
val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 }
|
||||
assertEquals(17.0, y.value) // the value of result (y)
|
||||
assertEquals(9.0, y.derivative(x)) // dy/dx
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqrt() {
|
||||
val y = dx(x to 16.0) { x -> sqrt(x) }
|
||||
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||
assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSin() {
|
||||
val y = dx(x to PI / 6.0) { x -> sin(x) }
|
||||
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||
assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCos() {
|
||||
val y = dx(x to PI / 6) { x -> cos(x) }
|
||||
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
||||
assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTan() {
|
||||
val y = dx(x to PI / 6) { x -> tan(x) }
|
||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
||||
assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsin() {
|
||||
val y = dx(x to PI / 6) { x -> asin(x) }
|
||||
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
||||
assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcos() {
|
||||
val y = dx(x to PI / 6) { x -> acos(x) }
|
||||
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtan() {
|
||||
val y = dx(x to PI / 6) { x -> atan(x) }
|
||||
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
||||
assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSinh() {
|
||||
val y = dx(x to 0.0) { x -> sinh(x) }
|
||||
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
||||
assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCosh() {
|
||||
val y = dx(x to 0.0) { x -> cosh(x) }
|
||||
assertApprox(1.0, y.value) //y = cosh(0)
|
||||
assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTanh() {
|
||||
val y = dx(x to PI / 6) { x -> tanh(x) }
|
||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsinh() {
|
||||
val y = dx(x to PI / 6) { x -> asinh(x) }
|
||||
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
||||
assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcosh() {
|
||||
val y = dx(x to PI / 6) { x -> acosh(x) }
|
||||
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtanh() {
|
||||
val y = dx(x to PI / 6) { x -> atanh(x) }
|
||||
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
||||
assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivGrad() {
|
||||
val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y }
|
||||
assertEquals(6.0, res.div())
|
||||
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||
}
|
||||
|
||||
private fun assertApprox(a: Double, b: Double) {
|
||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||
}
|
||||
}
|
@ -1,261 +0,0 @@
|
||||
package kscience.kmath.misc
|
||||
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class AutoDiffTest {
|
||||
inline fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
|
||||
RealField.deriv(body)
|
||||
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
||||
val y = deriv { x + x }
|
||||
assertEquals(6.0, y.value) // y = x + x = 6
|
||||
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPlus() {
|
||||
// two variables
|
||||
val x = Variable(2.0)
|
||||
val y = Variable(3.0)
|
||||
val z = deriv { x + y }
|
||||
assertEquals(5.0, z.value) // z = x + y = 5
|
||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMinus() {
|
||||
// two variables
|
||||
val x = Variable(7.0)
|
||||
val y = Variable(3.0)
|
||||
val z = deriv { x - y }
|
||||
assertEquals(4.0, z.value) // z = x - y = 4
|
||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
||||
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMulX2() {
|
||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
||||
val y = deriv { x * x }
|
||||
assertEquals(9.0, y.value) // y = x * x = 9
|
||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqr() {
|
||||
val x = Variable(3.0)
|
||||
val y = deriv { sqr(x) }
|
||||
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqrSqr() {
|
||||
val x = Variable(2.0)
|
||||
val y = deriv { sqr(sqr(x)) }
|
||||
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testX3() {
|
||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
||||
val y = deriv { x * x * x }
|
||||
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDiv() {
|
||||
val x = Variable(5.0)
|
||||
val y = Variable(2.0)
|
||||
val z = deriv { x / y }
|
||||
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
||||
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPow3() {
|
||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
||||
val y = deriv { pow(x, 3) }
|
||||
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPowFull() {
|
||||
val x = Variable(2.0)
|
||||
val y = Variable(3.0)
|
||||
val z = deriv { pow(x, y) }
|
||||
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFromPaper() {
|
||||
val x = Variable(3.0)
|
||||
val y = deriv { 2 * x + x * x * x }
|
||||
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testInnerVariable() {
|
||||
val x = Variable(1.0)
|
||||
val y = deriv {
|
||||
Variable(1.0) * x
|
||||
}
|
||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLongChain() {
|
||||
val n = 10_000
|
||||
val x = Variable(1.0)
|
||||
val y = deriv {
|
||||
var res = Variable(1.0)
|
||||
for (i in 1..n) res *= x
|
||||
res
|
||||
}
|
||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testExample() {
|
||||
val x = Variable(2.0)
|
||||
val y = deriv { sqr(x) + 5 * x + 3 }
|
||||
assertEquals(17.0, y.value) // the value of result (y)
|
||||
assertEquals(9.0, y.deriv(x)) // dy/dx
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSqrt() {
|
||||
val x = Variable(16.0)
|
||||
val y = deriv { sqrt(x) }
|
||||
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSin() {
|
||||
val x = Variable(PI / 6.0)
|
||||
val y = deriv { sin(x) }
|
||||
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||
assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCos() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { cos(x) }
|
||||
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTan() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { tan(x) }
|
||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
||||
assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsin() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { asin(x) }
|
||||
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
||||
assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcos() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { acos(x) }
|
||||
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtan() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { atan(x) }
|
||||
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
||||
assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSinh() {
|
||||
val x = Variable(0.0)
|
||||
val y = deriv { sinh(x) }
|
||||
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
||||
assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCosh() {
|
||||
val x = Variable(0.0)
|
||||
val y = deriv { cosh(x) }
|
||||
assertApprox(1.0, y.value) //y = cosh(0)
|
||||
assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTanh() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { tanh(x) }
|
||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsinh() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { asinh(x) }
|
||||
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
||||
assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcosh() {
|
||||
val x = Variable(PI / 6)
|
||||
val y = deriv { acosh(x) }
|
||||
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtanh() {
|
||||
val x = Variable(PI / 6.0)
|
||||
val y = deriv { atanh(x) }
|
||||
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
||||
assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivGrad() {
|
||||
val x = Variable(1.0)
|
||||
val y = Variable(2.0)
|
||||
val res = deriv { x * x + y * y }
|
||||
assertEquals(6.0, res.div())
|
||||
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||
}
|
||||
|
||||
private fun assertApprox(a: Double, b: Double) {
|
||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user