Feature/diff api #154
12
README.md
12
README.md
@ -54,8 +54,6 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||||
to submit a feature request if you want something to be done first.
|
to submit a feature request if you want something to be done first.
|
||||||
|
|
||||||
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
|
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
||||||
@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
|
* ### [kmath-ejml](kmath-ejml)
|
||||||
|
>
|
||||||
|
>
|
||||||
|
> **Maturity**: EXPERIMENTAL
|
||||||
|
<hr/>
|
||||||
|
|
||||||
* ### [kmath-for-real](kmath-for-real)
|
* ### [kmath-for-real](kmath-for-real)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
@ -178,8 +182,8 @@ repositories{
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies{
|
dependencies{
|
||||||
api("kscience.kmath:kmath-core:0.2.0-dev-1")
|
api("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||||
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-1") for jvm-specific version
|
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -25,3 +25,7 @@ subprojects {
|
|||||||
readme {
|
readme {
|
||||||
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiValidation{
|
||||||
|
validationDisabled = true
|
||||||
|
}
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -107,4 +107,4 @@ with the same artifact names.
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
|
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
@ -14,8 +14,8 @@ import kotlin.contracts.contract
|
|||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
||||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
@ -27,7 +27,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
error("Numeric nodes are not supported by $this")
|
error("Numeric nodes are not supported by $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -37,7 +37,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
mstAlgebra: E,
|
mstAlgebra: E,
|
||||||
block: E.() -> MST
|
block: E.() -> MST,
|
||||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -116,7 +116,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||||
block: MstExtendedField.() -> MST
|
block: MstExtendedField.() -> MST,
|
||||||
): MstExpression<T> {
|
): MstExpression<T> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInExtendedField(block)
|
return algebra.mstInExtendedField(block)
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
package kscience.kmath.asm.internal
|
package kscience.kmath.asm.internal
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.StringSymbol
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
|
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
|
||||||
*
|
*
|
||||||
@ -9,4 +11,4 @@ package kscience.kmath.asm.internal
|
|||||||
*/
|
*/
|
||||||
@JvmOverloads
|
@JvmOverloads
|
||||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
||||||
this[key] ?: default ?: error("Parameter not found: $key")
|
this[StringSymbol(key.toString())] ?: default ?: error("Parameter not found: $key")
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package kscience.kmath.asm
|
package kscience.kmath.asm
|
||||||
|
|
||||||
import kscience.kmath.asm.compile
|
|
||||||
import kscience.kmath.ast.mstInField
|
import kscience.kmath.ast.mstInField
|
||||||
import kscience.kmath.ast.mstInRing
|
import kscience.kmath.ast.mstInRing
|
||||||
import kscience.kmath.ast.mstInSpace
|
import kscience.kmath.ast.mstInSpace
|
||||||
@ -11,6 +10,7 @@ import kotlin.test.Test
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestAsmAlgebras {
|
internal class TestAsmAlgebras {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun space() {
|
fun space() {
|
||||||
val res1 = ByteRing.mstInSpace {
|
val res1 = ByteRing.mstInSpace {
|
||||||
|
@ -1,48 +1,57 @@
|
|||||||
package kscience.kmath.commons.expressions
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.expressions.ExpressionAlgebra
|
import kscience.kmath.expressions.ExpressionAlgebra
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
import kscience.kmath.operations.ExtendedField
|
import kscience.kmath.operations.ExtendedField
|
||||||
import kscience.kmath.operations.Field
|
|
||||||
import kscience.kmath.operations.invoke
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import kotlin.properties.ReadOnlyProperty
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field over commons-math [DerivativeStructure].
|
* A field over commons-math [DerivativeStructure].
|
||||||
*
|
*
|
||||||
* @property order The derivation order.
|
* @property order The derivation order.
|
||||||
* @property parameters The map of free parameters.
|
* @property bindings The map of bindings values. All bindings are considered free parameters
|
||||||
*/
|
*/
|
||||||
public class DerivativeStructureField(
|
public class DerivativeStructureField(
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
public val parameters: Map<String, Double>,
|
private val bindings: Map<Symbol, Double>
|
||||||
) : ExtendedField<DerivativeStructure> {
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
|
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
|
||||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
|
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
|
||||||
|
|
||||||
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
|
/**
|
||||||
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||||
|
*/
|
||||||
|
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||||
|
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||||
|
override val identity: Any = symbol.identity
|
||||||
}
|
}
|
||||||
|
|
||||||
public val variable: ReadOnlyProperty<Any?, DerivativeStructure> = ReadOnlyProperty { _, property ->
|
/**
|
||||||
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
|
* Identity-based symbol bindings map
|
||||||
|
*/
|
||||||
|
private val variables: Map<Any?, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
|
||||||
|
key.identity to DerivativeStructureSymbol(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
|
override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value)
|
||||||
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
|
||||||
|
|
||||||
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
|
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||||
return deriv(mapOf(parName to order))
|
|
||||||
|
public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||||
|
|
||||||
|
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
||||||
|
return derivative(mapOf(parameter to order))
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
|
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
|
||||||
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
|
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
|
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
|
||||||
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||||
|
|
||||||
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||||
@ -85,48 +94,16 @@ public class DerivativeStructureField(
|
|||||||
/**
|
/**
|
||||||
* A constructs that creates a derivative structure with required order on-demand
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
*/
|
*/
|
||||||
public class DiffExpression(
|
public class DerivativeStructureExpression(
|
||||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||||
) : Expression<Double> {
|
) : DifferentiableExpression<Double> {
|
||||||
public override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
0,
|
DerivativeStructureField(0, arguments).function().value
|
||||||
arguments
|
|
||||||
).function().value
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
* TODO make result [DiffExpression]
|
|
||||||
*/
|
*/
|
||||||
public fun derivative(orders: Map<String, Int>): Expression<Double> = Expression { arguments ->
|
public override fun derivative(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||||
(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
|
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO add gradient and maybe other vector operators
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
|
|
||||||
public fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
|
||||||
*/
|
|
||||||
public object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
|
||||||
public override val zero: DiffExpression = DiffExpression { 0.0.const() }
|
|
||||||
public override val one: DiffExpression = DiffExpression { 1.0.const() }
|
|
||||||
|
|
||||||
public override fun variable(name: String, default: Double?): DiffExpression =
|
|
||||||
DiffExpression { variable(name, default?.const()) }
|
|
||||||
|
|
||||||
public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
|
|
||||||
|
|
||||||
public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) + b.function(this) }
|
|
||||||
|
|
||||||
public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
|
|
||||||
|
|
||||||
public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) * b.function(this) }
|
|
||||||
|
|
||||||
public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) / b.function(this) }
|
|
||||||
}
|
}
|
@ -1,6 +1,6 @@
|
|||||||
package kscience.kmath.commons.expressions
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.*
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -8,33 +8,37 @@ import kotlin.test.assertEquals
|
|||||||
|
|
||||||
internal inline fun <R> diff(
|
internal inline fun <R> diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<String, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField.() -> R
|
block: DerivativeStructureField.() -> R,
|
||||||
): R {
|
): R {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
|
private val x by symbol
|
||||||
|
private val y by symbol
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun derivativeStructureFieldTest() {
|
fun derivativeStructureFieldTest() {
|
||||||
val res: Double = diff(3, "x" to 1.0, "y" to 1.0) {
|
val res: Double = diff(3, x to 1.0, y to 1.0) {
|
||||||
val x by variable
|
val x = bind(x)//by binding()
|
||||||
val y = variable("y")
|
val y = symbol("y")
|
||||||
val z = x * (-sin(x * y) + y)
|
val z = x * (-sin(x * y) + y)
|
||||||
z.deriv("x")
|
z.derivative(x)
|
||||||
}
|
}
|
||||||
|
println(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun autoDifTest() {
|
fun autoDifTest() {
|
||||||
val f = DiffExpression {
|
val f = DerivativeStructureExpression {
|
||||||
val x by variable
|
val x by binding()
|
||||||
val y by variable
|
val y by binding()
|
||||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
|
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||||
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
|
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -12,7 +12,7 @@ The core features of KMath:
|
|||||||
|
|
||||||
> #### Artifact:
|
> #### Artifact:
|
||||||
>
|
>
|
||||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`.
|
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
|
||||||
>
|
>
|
||||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
||||||
>
|
>
|
||||||
@ -22,25 +22,28 @@ The core features of KMath:
|
|||||||
>
|
>
|
||||||
> ```gradle
|
> ```gradle
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||||
|
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1'
|
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
> **Gradle Kotlin DSL:**
|
> **Gradle Kotlin DSL:**
|
||||||
>
|
>
|
||||||
> ```kotlin
|
> ```kotlin
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-1")
|
> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
|
@ -41,6 +41,6 @@ readme {
|
|||||||
feature(
|
feature(
|
||||||
id = "autodif",
|
id = "autodif",
|
||||||
description = "Automatic differentiation",
|
description = "Automatic differentiation",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt"
|
ref = "src/commonMain/kotlin/kscience/kmath/misc/SimpleAutoDiff.kt"
|
||||||
)
|
)
|
||||||
}
|
}
|
@ -1,6 +1,26 @@
|
|||||||
package kscience.kmath.expressions
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
import kotlin.properties.ReadOnlyProperty
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A marker interface for a symbol. A symbol mus have an identity
|
||||||
|
*/
|
||||||
|
public interface Symbol {
|
||||||
|
/**
|
||||||
|
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
|
||||||
|
* By default uses object identity
|
||||||
|
*/
|
||||||
|
public val identity: Any get() = this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A [Symbol] with a [String] identity
|
||||||
|
*/
|
||||||
|
public inline class StringSymbol(override val identity: String) : Symbol {
|
||||||
|
override fun toString(): String = identity
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments
|
||||||
@ -12,30 +32,81 @@ public fun interface Expression<T> {
|
|||||||
* @param arguments the map of arguments.
|
* @param arguments the map of arguments.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
public operator fun invoke(arguments: Map<String, T>): T
|
public operator fun invoke(arguments: Map<Symbol, T>): T
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invlode an expression without parameters
|
||||||
|
*/
|
||||||
|
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||||
|
//This method exists to avoid resolution ambiguity of vararg methods
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression from arguments.
|
* Calls this expression from arguments.
|
||||||
*
|
*
|
||||||
* @param pairs the pair of arguments' names to values.
|
* @param pairs the pair of arguments' names to values.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
@JvmName("callBySymbol")
|
||||||
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
||||||
|
|
||||||
|
@JvmName("callByString")
|
||||||
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||||
|
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* And object that could be differentiated
|
||||||
|
*/
|
||||||
|
public interface Differentiable<T> {
|
||||||
|
public fun derivative(orders: Map<Symbol, Int>): T
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||||
|
derivative(mapOf(*orders))
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
|
*
|
||||||
|
* @param T type of the constants for the expression
|
||||||
|
* @param E type of the actual expression state
|
||||||
*/
|
*/
|
||||||
public interface ExpressionAlgebra<T, E> : Algebra<E> {
|
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Introduce a variable into expression context
|
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
|
||||||
*/
|
*/
|
||||||
public fun variable(name: String, default: T? = null): E
|
public fun bindOrNull(symbol: Symbol): E?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a string to a context using [StringSymbol]
|
||||||
|
*/
|
||||||
|
override fun symbol(value: String): E = bind(StringSymbol(value))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constant expression which does not depend on arguments
|
* A constant expression which does not depend on arguments
|
||||||
*/
|
*/
|
||||||
public fun const(value: T): E
|
public fun const(value: T): E
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
||||||
|
*/
|
||||||
|
public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
|
||||||
|
bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
|
||||||
|
|
||||||
|
public val symbol: ReadOnlyProperty<Any?, Symbol> = ReadOnlyProperty { _, property ->
|
||||||
|
StringSymbol(property.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> =
|
||||||
|
ReadOnlyProperty { _, property ->
|
||||||
|
bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
|
||||||
|
}
|
@ -2,39 +2,6 @@ package kscience.kmath.expressions
|
|||||||
|
|
||||||
import kscience.kmath.operations.*
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.unaryOperation(name, expr.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalBinaryOperation<T>(
|
|
||||||
val context: Algebra<T>,
|
|
||||||
val name: String,
|
|
||||||
val first: Expression<T>,
|
|
||||||
val second: Expression<T>
|
|
||||||
) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = value
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalConstProductExpression<T>(
|
|
||||||
val context: Space<T>,
|
|
||||||
private val expr: Expression<T>,
|
|
||||||
val const: Number
|
|
||||||
) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context class for [Expression] construction.
|
* A context class for [Expression] construction.
|
||||||
*
|
*
|
||||||
@ -45,24 +12,32 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression which does not depend on arguments.
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
public override fun const(value: T): Expression<T> = Expression { value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression to access a variable.
|
* Builds an Expression to access a variable.
|
||||||
*/
|
*/
|
||||||
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
public override fun bindOrNull(symbol: Symbol): Expression<T>? = Expression { arguments ->
|
||||||
|
arguments[symbol] ?: error("Argument not found: $symbol")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
*/
|
*/
|
||||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(
|
||||||
FunctionalBinaryOperation(algebra, operation, left, right)
|
operation: String,
|
||||||
|
left: Expression<T>,
|
||||||
|
right: Expression<T>,
|
||||||
|
): Expression<T> = Expression { arguments ->
|
||||||
|
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
*/
|
*/
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
|
||||||
FunctionalUnaryOperation(algebra, operation, arg)
|
algebra.unaryOperation(operation, arg.invoke(arguments))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -81,8 +56,9 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of multiplication of expression by number.
|
* Builds an Expression of multiplication of expression by number.
|
||||||
*/
|
*/
|
||||||
public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
public override fun multiply(a: Expression<T>, k: Number): Expression<T> = Expression { arguments ->
|
||||||
FunctionalConstProductExpression(algebra, a, k)
|
algebra.multiply(a.invoke(arguments), k)
|
||||||
|
}
|
||||||
|
|
||||||
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||||
@ -118,8 +94,8 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
|||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||||
FunctionalExpressionRing<T, A>(algebra),
|
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
|
||||||
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
where A : Field<T>, A : NumericAlgebra<T> {
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of division an expression by another one.
|
* Builds an Expression of division an expression by another one.
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,329 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implementation of backward-mode automatic differentiation.
|
||||||
|
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
|||||||
|
* A [Symbol] with bound value
|
||||||
|
*/
|
||||||
|
public interface BoundSymbol<out T> : Symbol {
|
||||||
|
public val value: T
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a [Symbol] to a [value] and produce [BoundSymbol]
|
||||||
|
*/
|
||||||
|
public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
|
||||||
|
override val identity = this@bind.identity
|
||||||
|
override val value: T = value
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents result of [withAutoDiff] call.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @param value the value of result.
|
||||||
|
* @property withAutoDiff The mapping of differentiated variables to their derivatives.
|
||||||
|
* @property context The field over [T].
|
||||||
|
*/
|
||||||
|
public class DerivationResult<T : Any>(
|
||||||
|
override val value: T,
|
||||||
|
private val derivativeValues: Map<Any, T>,
|
||||||
|
public val context: Field<T>,
|
||||||
|
) : BoundSymbol<T> {
|
||||||
|
/**
|
||||||
|
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
||||||
|
*/
|
||||||
|
public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the divergence.
|
||||||
|
*/
|
||||||
|
public fun div(): T = context { sum(derivativeValues.values) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the gradient for variables in given order.
|
||||||
|
*/
|
||||||
|
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
|
||||||
|
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||||
|
return variables.map(::derivative).asBuffer()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||||
|
*
|
||||||
|
* The partial derivatives are placed in argument `d` variable
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* ```
|
||||||
|
* val x by symbol // define variable(s) and their values
|
||||||
|
* val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||||
|
* assertEquals(17.0, y.x) // the value of result (y)
|
||||||
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
|
* ```
|
||||||
|
*
|
||||||
|
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||||
|
* @return the result of differentiation.
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||||
|
bindings: Collection<BoundSymbol<T>>,
|
||||||
|
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||||
|
): DerivationResult<T> {
|
||||||
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
|
return AutoDiffContext(this, bindings).derivate(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||||
|
vararg bindings: Pair<Symbol, T>,
|
||||||
|
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||||
|
): DerivationResult<T> = withAutoDiff(bindings.map { it.first.bind(it.second) }, body)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents field in context of which functions can be derived.
|
||||||
|
*/
|
||||||
|
public abstract class AutoDiffField<T : Any, F : Field<T>>
|
||||||
|
: Field<BoundSymbol<T>>, ExpressionAlgebra<T, BoundSymbol<T>> {
|
||||||
|
|
||||||
|
public abstract val context: F
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A variable accessing inner state of derivatives.
|
||||||
|
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||||
|
*/
|
||||||
|
public abstract var BoundSymbol<T>.d: T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
|
*
|
||||||
|
* For example, implementation of `sin` function is:
|
||||||
|
*
|
||||||
|
* ```
|
||||||
|
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||||
|
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||||
|
|
||||||
|
public inline fun const(block: F.() -> T): BoundSymbol<T> = const(context.block())
|
||||||
|
|
||||||
|
// Overloads for Double constants
|
||||||
|
|
||||||
|
override operator fun Number.plus(b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun BoundSymbol<T>.plus(b: Number): BoundSymbol<T> = b.plus(this)
|
||||||
|
|
||||||
|
override operator fun Number.minus(b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||||
|
|
||||||
|
override operator fun BoundSymbol<T>.minus(b: Number): BoundSymbol<T> =
|
||||||
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatic Differentiation context class.
|
||||||
|
*/
|
||||||
|
private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||||
|
override val context: F,
|
||||||
|
bindings: Collection<BoundSymbol<T>>,
|
||||||
|
) : AutoDiffField<T, F>() {
|
||||||
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
|
private var sp: Int = 0
|
||||||
|
private val derivatives: MutableMap<Any, T> = hashMapOf()
|
||||||
|
override val zero: BoundSymbol<T> get() = const(context.zero)
|
||||||
|
override val one: BoundSymbol<T> get() = const(context.one)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
|
||||||
|
* with respect to this variable.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @property value The value of this variable.
|
||||||
|
*/
|
||||||
|
private class AutoDiffVariableWithDeriv<T : Any>(override val value: T, var d: T) : BoundSymbol<T>
|
||||||
|
|
||||||
|
private val bindings: Map<Any, BoundSymbol<T>> = bindings.associateBy { it.identity }
|
||||||
|
|
||||||
|
override fun bindOrNull(symbol: Symbol): BoundSymbol<T>? = bindings[symbol.identity]
|
||||||
|
|
||||||
|
override fun const(value: T): BoundSymbol<T> = AutoDiffVariableWithDeriv(value, context.zero)
|
||||||
|
|
||||||
|
override var BoundSymbol<T>.d: T
|
||||||
|
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero
|
||||||
|
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
|
// save block to stack for backward pass
|
||||||
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
|
stack[sp++] = block
|
||||||
|
stack[sp++] = value
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
fun runBackwardPass() {
|
||||||
|
while (sp > 0) {
|
||||||
|
val value = stack[--sp]
|
||||||
|
val block = stack[--sp] as F.(Any?) -> Unit
|
||||||
|
context.block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
|
override fun add(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { a.value + b.value }) { z ->
|
||||||
|
a.d += z.d
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { a.value * b.value }) { z ->
|
||||||
|
a.d += z.d * b.value
|
||||||
|
b.d += z.d * a.value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun divide(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { a.value / b.value }) { z ->
|
||||||
|
a.d += z.d / b.value
|
||||||
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: BoundSymbol<T>, k: Number): BoundSymbol<T> =
|
||||||
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
|
a.d += z.d * k.toDouble()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun derivate(function: AutoDiffField<T, F>.() -> BoundSymbol<T>): DerivationResult<T> {
|
||||||
|
val result = function()
|
||||||
|
result.d = context.one // computing derivative w.r.t result
|
||||||
|
runBackwardPass()
|
||||||
|
return DerivationResult(result.value, derivatives, context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
|
*/
|
||||||
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
|
public val field: F,
|
||||||
|
public val function: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||||
|
) : DifferentiableExpression<T> {
|
||||||
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
|
val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
|
return AutoDiffContext(field, bindings).function().value
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the derivative expression with given orders
|
||||||
|
*/
|
||||||
|
public override fun derivative(orders: Map<Symbol, Int>): Expression<T> {
|
||||||
|
val dSymbol = orders.entries.singleOrNull { it.value == 1 }
|
||||||
|
?: error("SimpleAutoDiff supports only first order derivatives")
|
||||||
|
return Expression { arguments ->
|
||||||
|
val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
|
val derivationResult = AutoDiffContext(field, bindings).derivate(function)
|
||||||
|
derivationResult.derivative(dSymbol.key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
|
// x ^ 2
|
||||||
|
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||||
|
x: BoundSymbol<T>,
|
||||||
|
y: Double,
|
||||||
|
): BoundSymbol<T> =
|
||||||
|
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||||
|
x: BoundSymbol<T>,
|
||||||
|
y: Int,
|
||||||
|
): BoundSymbol<T> =
|
||||||
|
pow(x, y.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||||
|
x: BoundSymbol<T>,
|
||||||
|
y: BoundSymbol<T>,
|
||||||
|
): BoundSymbol<T> =
|
||||||
|
exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { tan(x.value) }) { z ->
|
||||||
|
val c = cos(x.value)
|
||||||
|
x.d += z.d / (c * c)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { tan(x.value) }) { z ->
|
||||||
|
val c = cosh(x.value)
|
||||||
|
x.d += z.d / (c * c)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||||
|
|
@ -1,266 +0,0 @@
|
|||||||
package kscience.kmath.misc
|
|
||||||
|
|
||||||
import kscience.kmath.linear.Point
|
|
||||||
import kscience.kmath.operations.*
|
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import kotlin.contracts.InvocationKind
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Implementation of backward-mode automatic differentiation.
|
|
||||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
|
||||||
* with respect to this variable.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @property value The value of this variable.
|
|
||||||
*/
|
|
||||||
public open class Variable<T : Any>(public val value: T)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents result of [deriv] call.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @param value the value of result.
|
|
||||||
* @property deriv The mapping of differentiated variables to their derivatives.
|
|
||||||
* @property context The field over [T].
|
|
||||||
*/
|
|
||||||
public class DerivationResult<T : Any>(
|
|
||||||
value: T,
|
|
||||||
public val deriv: Map<Variable<T>, T>,
|
|
||||||
public val context: Field<T>
|
|
||||||
) : Variable<T>(value) {
|
|
||||||
/**
|
|
||||||
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
|
||||||
*/
|
|
||||||
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the divergence.
|
|
||||||
*/
|
|
||||||
public fun div(): T = context { sum(deriv.values) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the gradient for variables in given order.
|
|
||||||
*/
|
|
||||||
public fun grad(vararg variables: Variable<T>): Point<T> {
|
|
||||||
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
|
||||||
return variables.map(::deriv).asBuffer()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
|
||||||
*
|
|
||||||
* The partial derivatives are placed in argument `d` variable
|
|
||||||
*
|
|
||||||
* Example:
|
|
||||||
* ```
|
|
||||||
* val x = Variable(2) // define variable(s) and their values
|
|
||||||
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
|
||||||
* assertEquals(17.0, y.x) // the value of result (y)
|
|
||||||
* assertEquals(9.0, x.d) // dy/dx
|
|
||||||
* ```
|
|
||||||
*
|
|
||||||
* @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to.
|
|
||||||
* @return the result of differentiation.
|
|
||||||
*/
|
|
||||||
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
|
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
|
|
||||||
return (AutoDiffContext(this)) {
|
|
||||||
val result = body()
|
|
||||||
result.d = context.one // computing derivative w.r.t result
|
|
||||||
runBackwardPass()
|
|
||||||
DerivationResult(result.value, derivatives, this@deriv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents field in context of which functions can be derived.
|
|
||||||
*/
|
|
||||||
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
|
||||||
public abstract val context: F
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A variable accessing inner state of derivatives.
|
|
||||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
|
||||||
*/
|
|
||||||
public abstract var Variable<T>.d: T
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
|
||||||
*
|
|
||||||
* For example, implementation of `sin` function is:
|
|
||||||
*
|
|
||||||
* ```
|
|
||||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
|
||||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
|
||||||
* }
|
|
||||||
* ```
|
|
||||||
*/
|
|
||||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public abstract fun variable(value: T): Variable<T>
|
|
||||||
|
|
||||||
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
|
||||||
|
|
||||||
// Overloads for Double constants
|
|
||||||
|
|
||||||
override operator fun Number.plus(b: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
|
|
||||||
b.d += z.d
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
|
||||||
|
|
||||||
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
|
||||||
|
|
||||||
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
|
||||||
derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic Differentiation context class.
|
|
||||||
*/
|
|
||||||
@PublishedApi
|
|
||||||
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
|
||||||
private var sp: Int = 0
|
|
||||||
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
|
|
||||||
override val zero: Variable<T> get() = Variable(context.zero)
|
|
||||||
override val one: Variable<T> get() = Variable(context.one)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A variable coupled with its derivative. For internal use only
|
|
||||||
*/
|
|
||||||
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
|
||||||
|
|
||||||
override fun variable(value: T): Variable<T> =
|
|
||||||
VariableWithDeriv(value, context.zero)
|
|
||||||
|
|
||||||
override var Variable<T>.d: T
|
|
||||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
|
||||||
set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
|
||||||
// save block to stack for backward pass
|
|
||||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
|
||||||
stack[sp++] = block
|
|
||||||
stack[sp++] = value
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun runBackwardPass() {
|
|
||||||
while (sp > 0) {
|
|
||||||
val value = stack[--sp]
|
|
||||||
val block = stack[--sp] as F.(Any?) -> Unit
|
|
||||||
context.block(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
|
||||||
|
|
||||||
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
|
|
||||||
a.d += z.d
|
|
||||||
b.d += z.d
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
|
|
||||||
a.d += z.d * b.value
|
|
||||||
b.d += z.d * a.value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
|
|
||||||
a.d += z.d / b.value
|
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
|
|
||||||
a.d += z.d * k.toDouble()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
|
||||||
|
|
||||||
// x ^ 2
|
|
||||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
|
||||||
|
|
||||||
// x ^ 1/2
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
|
||||||
|
|
||||||
// x ^ y (const)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
|
||||||
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
|
|
||||||
pow(x, y.toDouble())
|
|
||||||
|
|
||||||
// exp(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
|
|
||||||
|
|
||||||
// ln(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
|
|
||||||
|
|
||||||
// x ^ y (any)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
|
||||||
exp(y * ln(x))
|
|
||||||
|
|
||||||
// sin(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
|
||||||
|
|
||||||
// cos(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { tan(x.value) }) { z ->
|
|
||||||
val c = cos(x.value)
|
|
||||||
x.d += z.d / (c * c)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { tan(x.value) }) { z ->
|
|
||||||
val c = cosh(x.value)
|
|
||||||
x.d += z.d / (c * c)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
|
||||||
|
|
@ -6,19 +6,21 @@ import kscience.kmath.operations.RealField
|
|||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFails
|
||||||
|
|
||||||
class ExpressionFieldTest {
|
class ExpressionFieldTest {
|
||||||
|
val x by symbol
|
||||||
@Test
|
@Test
|
||||||
fun testExpression() {
|
fun testExpression() {
|
||||||
val context = FunctionalExpressionField(RealField)
|
val context = FunctionalExpressionField(RealField)
|
||||||
|
|
||||||
val expression = context {
|
val expression = context {
|
||||||
val x = variable("x", 2.0)
|
val x by binding()
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
assertEquals(expression(), 9.0)
|
assertFails { expression()}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -26,33 +28,33 @@ class ExpressionFieldTest {
|
|||||||
val context = FunctionalExpressionField(ComplexField)
|
val context = FunctionalExpressionField(ComplexField)
|
||||||
|
|
||||||
val expression = context {
|
val expression = context {
|
||||||
val x = variable("x", Complex(2.0, 0.0))
|
val x = bind(x)
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
||||||
assertEquals(expression(), Complex(9.0, 0.0))
|
//assertEquals(expression(), Complex(9.0, 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun separateContext() {
|
fun separateContext() {
|
||||||
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
||||||
val x = variable("x")
|
val x by binding()
|
||||||
return x * x + 2 * x + one
|
return x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = FunctionalExpressionField(RealField).expression()
|
val expression = FunctionalExpressionField(RealField).expression()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueExpression() {
|
fun valueExpression() {
|
||||||
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
||||||
val x = variable("x")
|
val x by binding()
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,277 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class SimpleAutoDiffTest {
|
||||||
|
fun d(
|
||||||
|
vararg bindings: Pair<Symbol, Double>,
|
||||||
|
body: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
|
||||||
|
|
||||||
|
fun dx(
|
||||||
|
xBinding: Pair<Symbol, Double>,
|
||||||
|
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>) -> BoundSymbol<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||||
|
|
||||||
|
fun dxy(
|
||||||
|
xBinding: Pair<Symbol, Double>,
|
||||||
|
yBinding: Pair<Symbol, Double>,
|
||||||
|
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>, y: BoundSymbol<Double>) -> BoundSymbol<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
|
||||||
|
body(bind(xBinding.first), bind(yBinding.first))
|
||||||
|
}
|
||||||
|
|
||||||
|
fun diff(block: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||||
|
return SimpleAutoDiffExpression(RealField, block)
|
||||||
|
}
|
||||||
|
|
||||||
|
val x by symbol
|
||||||
|
val y by symbol
|
||||||
|
val z by symbol
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlusX2() {
|
||||||
|
val y = d(x to 3.0) {
|
||||||
|
// diff w.r.t this x at 3
|
||||||
|
val x = bind(x)
|
||||||
|
x + x
|
||||||
|
}
|
||||||
|
assertEquals(6.0, y.value) // y = x + x = 6
|
||||||
|
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlus() {
|
||||||
|
// two variables
|
||||||
|
val z = d(x to 2.0, y to 3.0) {
|
||||||
|
val x = bind(x)
|
||||||
|
val y = bind(y)
|
||||||
|
x + y
|
||||||
|
}
|
||||||
|
assertEquals(5.0, z.value) // z = x + y = 5
|
||||||
|
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||||
|
assertEquals(1.0, z.derivative(y)) // dz/dy = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
// two variables
|
||||||
|
val z = d(x to 7.0, y to 3.0) {
|
||||||
|
val x = bind(x)
|
||||||
|
val y = bind(y)
|
||||||
|
|
||||||
|
x - y
|
||||||
|
}
|
||||||
|
assertEquals(4.0, z.value) // z = x - y = 4
|
||||||
|
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||||
|
assertEquals(-1.0, z.derivative(y)) // dz/dy = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMulX2() {
|
||||||
|
val y = dx(x to 3.0) { x ->
|
||||||
|
// diff w.r.t this x at 3
|
||||||
|
x * x
|
||||||
|
}
|
||||||
|
assertEquals(9.0, y.value) // y = x * x = 9
|
||||||
|
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqr() {
|
||||||
|
val y = dx(x to 3.0) { x -> sqr(x) }
|
||||||
|
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||||
|
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrSqr() {
|
||||||
|
val y = dx(x to 2.0) { x -> sqr(sqr(x)) }
|
||||||
|
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||||
|
assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testX3() {
|
||||||
|
val y = dx(x to 2.0) { x ->
|
||||||
|
// diff w.r.t this x at 2
|
||||||
|
x * x * x
|
||||||
|
}
|
||||||
|
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||||
|
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDiv() {
|
||||||
|
val z = dxy(x to 5.0, y to 2.0) { x, y ->
|
||||||
|
x / y
|
||||||
|
}
|
||||||
|
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||||
|
assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5
|
||||||
|
assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPow3() {
|
||||||
|
val y = dx(x to 2.0) { x ->
|
||||||
|
// diff w.r.t this x at 2
|
||||||
|
pow(x, 3)
|
||||||
|
}
|
||||||
|
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||||
|
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPowFull() {
|
||||||
|
val z = dxy(x to 2.0, y to 3.0) { x, y ->
|
||||||
|
pow(x, y)
|
||||||
|
}
|
||||||
|
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||||
|
assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||||
|
assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testFromPaper() {
|
||||||
|
val y = dx(x to 3.0) { x -> 2 * x + x * x * x }
|
||||||
|
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||||
|
assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testInnerVariable() {
|
||||||
|
val y = dx(x to 1.0) { x ->
|
||||||
|
const(1.0) * x
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLongChain() {
|
||||||
|
val n = 10_000
|
||||||
|
val y = dx(x to 1.0) { x ->
|
||||||
|
var res = const(1.0)
|
||||||
|
for (i in 1..n) res *= x
|
||||||
|
res
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testExample() {
|
||||||
|
val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 }
|
||||||
|
assertEquals(17.0, y.value) // the value of result (y)
|
||||||
|
assertEquals(9.0, y.derivative(x)) // dy/dx
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrt() {
|
||||||
|
val y = dx(x to 16.0) { x -> sqrt(x) }
|
||||||
|
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||||
|
assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSin() {
|
||||||
|
val y = dx(x to PI / 6.0) { x -> sin(x) }
|
||||||
|
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||||
|
assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCos() {
|
||||||
|
val y = dx(x to PI / 6) { x -> cos(x) }
|
||||||
|
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
||||||
|
assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTan() {
|
||||||
|
val y = dx(x to PI / 6) { x -> tan(x) }
|
||||||
|
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
||||||
|
assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAsin() {
|
||||||
|
val y = dx(x to PI / 6) { x -> asin(x) }
|
||||||
|
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
||||||
|
assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAcos() {
|
||||||
|
val y = dx(x to PI / 6) { x -> acos(x) }
|
||||||
|
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
||||||
|
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAtan() {
|
||||||
|
val y = dx(x to PI / 6) { x -> atan(x) }
|
||||||
|
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
||||||
|
assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSinh() {
|
||||||
|
val y = dx(x to 0.0) { x -> sinh(x) }
|
||||||
|
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
||||||
|
assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCosh() {
|
||||||
|
val y = dx(x to 0.0) { x -> cosh(x) }
|
||||||
|
assertApprox(1.0, y.value) //y = cosh(0)
|
||||||
|
assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTanh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> tanh(x) }
|
||||||
|
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
||||||
|
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAsinh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> asinh(x) }
|
||||||
|
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
||||||
|
assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAcosh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> acosh(x) }
|
||||||
|
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
||||||
|
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAtanh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> atanh(x) }
|
||||||
|
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
||||||
|
assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivGrad() {
|
||||||
|
val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y }
|
||||||
|
assertEquals(6.0, res.div())
|
||||||
|
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun assertApprox(a: Double, b: Double) {
|
||||||
|
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||||
|
}
|
||||||
|
}
|
@ -1,261 +0,0 @@
|
|||||||
package kscience.kmath.misc
|
|
||||||
|
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import kotlin.math.PI
|
|
||||||
import kotlin.math.pow
|
|
||||||
import kotlin.math.sqrt
|
|
||||||
import kotlin.test.Test
|
|
||||||
import kotlin.test.assertEquals
|
|
||||||
import kotlin.test.assertTrue
|
|
||||||
|
|
||||||
class AutoDiffTest {
|
|
||||||
inline fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
|
|
||||||
RealField.deriv(body)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPlusX2() {
|
|
||||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
|
||||||
val y = deriv { x + x }
|
|
||||||
assertEquals(6.0, y.value) // y = x + x = 6
|
|
||||||
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPlus() {
|
|
||||||
// two variables
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { x + y }
|
|
||||||
assertEquals(5.0, z.value) // z = x + y = 5
|
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
|
||||||
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testMinus() {
|
|
||||||
// two variables
|
|
||||||
val x = Variable(7.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { x - y }
|
|
||||||
assertEquals(4.0, z.value) // z = x - y = 4
|
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
|
||||||
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testMulX2() {
|
|
||||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
|
||||||
val y = deriv { x * x }
|
|
||||||
assertEquals(9.0, y.value) // y = x * x = 9
|
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqr() {
|
|
||||||
val x = Variable(3.0)
|
|
||||||
val y = deriv { sqr(x) }
|
|
||||||
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqrSqr() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = deriv { sqr(sqr(x)) }
|
|
||||||
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
|
||||||
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testX3() {
|
|
||||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
|
||||||
val y = deriv { x * x * x }
|
|
||||||
assertEquals(8.0, y.value) // y = x * x * x = 8
|
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testDiv() {
|
|
||||||
val x = Variable(5.0)
|
|
||||||
val y = Variable(2.0)
|
|
||||||
val z = deriv { x / y }
|
|
||||||
assertEquals(2.5, z.value) // z = x / y = 2.5
|
|
||||||
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
|
||||||
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPow3() {
|
|
||||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
|
||||||
val y = deriv { pow(x, 3) }
|
|
||||||
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPowFull() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { pow(x, y) }
|
|
||||||
assertApprox(8.0, z.value) // z = x ^ y = 8
|
|
||||||
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
|
||||||
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testFromPaper() {
|
|
||||||
val x = Variable(3.0)
|
|
||||||
val y = deriv { 2 * x + x * x * x }
|
|
||||||
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
|
||||||
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testInnerVariable() {
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = deriv {
|
|
||||||
Variable(1.0) * x
|
|
||||||
}
|
|
||||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
|
||||||
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testLongChain() {
|
|
||||||
val n = 10_000
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = deriv {
|
|
||||||
var res = Variable(1.0)
|
|
||||||
for (i in 1..n) res *= x
|
|
||||||
res
|
|
||||||
}
|
|
||||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
|
||||||
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testExample() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = deriv { sqr(x) + 5 * x + 3 }
|
|
||||||
assertEquals(17.0, y.value) // the value of result (y)
|
|
||||||
assertEquals(9.0, y.deriv(x)) // dy/dx
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqrt() {
|
|
||||||
val x = Variable(16.0)
|
|
||||||
val y = deriv { sqrt(x) }
|
|
||||||
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
|
||||||
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSin() {
|
|
||||||
val x = Variable(PI / 6.0)
|
|
||||||
val y = deriv { sin(x) }
|
|
||||||
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
|
||||||
assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testCos() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { cos(x) }
|
|
||||||
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
|
||||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testTan() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { tan(x) }
|
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
|
||||||
assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAsin() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { asin(x) }
|
|
||||||
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
|
||||||
assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAcos() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { acos(x) }
|
|
||||||
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
|
||||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAtan() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { atan(x) }
|
|
||||||
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
|
||||||
assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSinh() {
|
|
||||||
val x = Variable(0.0)
|
|
||||||
val y = deriv { sinh(x) }
|
|
||||||
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
|
||||||
assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testCosh() {
|
|
||||||
val x = Variable(0.0)
|
|
||||||
val y = deriv { cosh(x) }
|
|
||||||
assertApprox(1.0, y.value) //y = cosh(0)
|
|
||||||
assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testTanh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { tanh(x) }
|
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
|
||||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAsinh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { asinh(x) }
|
|
||||||
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
|
||||||
assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAcosh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { acosh(x) }
|
|
||||||
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
|
||||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAtanh() {
|
|
||||||
val x = Variable(PI / 6.0)
|
|
||||||
val y = deriv { atanh(x) }
|
|
||||||
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
|
||||||
assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testDivGrad() {
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = Variable(2.0)
|
|
||||||
val res = deriv { x * x + y * y }
|
|
||||||
assertEquals(6.0, res.div())
|
|
||||||
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun assertApprox(a: Double, b: Double) {
|
|
||||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user
I am sure that it may be replaced with sealed class to prevent one to extend AutoDiffValue with irrelevant object.
The idea is that it could be extended anytime. Here
AutoDiffValue
is just a marker interface. It is possible to even replace it with an inline class, but we need performance measurements to make that change.