New Expression API
This commit is contained in:
parent
e44423192d
commit
707ad21f77
12
README.md
12
README.md
@ -54,8 +54,6 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||||
to submit a feature request if you want something to be done first.
|
to submit a feature request if you want something to be done first.
|
||||||
|
|
||||||
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
|
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
||||||
@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
|
* ### [kmath-ejml](kmath-ejml)
|
||||||
|
>
|
||||||
|
>
|
||||||
|
> **Maturity**: EXPERIMENTAL
|
||||||
|
<hr/>
|
||||||
|
|
||||||
* ### [kmath-for-real](kmath-for-real)
|
* ### [kmath-for-real](kmath-for-real)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
@ -178,8 +182,8 @@ repositories{
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies{
|
dependencies{
|
||||||
api("kscience.kmath:kmath-core:0.2.0-dev-1")
|
api("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||||
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-1") for jvm-specific version
|
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -25,3 +25,7 @@ subprojects {
|
|||||||
readme {
|
readme {
|
||||||
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiValidation{
|
||||||
|
validationDisabled = true
|
||||||
|
}
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -107,4 +107,4 @@ with the same artifact names.
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
|
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
@ -14,8 +14,8 @@ import kotlin.contracts.contract
|
|||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
||||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
@ -27,7 +27,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
error("Numeric nodes are not supported by $this")
|
error("Numeric nodes are not supported by $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -37,7 +37,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
mstAlgebra: E,
|
mstAlgebra: E,
|
||||||
block: E.() -> MST
|
block: E.() -> MST,
|
||||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -116,7 +116,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||||
block: MstExtendedField.() -> MST
|
block: MstExtendedField.() -> MST,
|
||||||
): MstExpression<T> {
|
): MstExpression<T> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInExtendedField(block)
|
return algebra.mstInExtendedField(block)
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
package kscience.kmath.asm.internal
|
package kscience.kmath.asm.internal
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.StringSymbol
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
|
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
|
||||||
*
|
*
|
||||||
@ -9,4 +11,4 @@ package kscience.kmath.asm.internal
|
|||||||
*/
|
*/
|
||||||
@JvmOverloads
|
@JvmOverloads
|
||||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
||||||
this[key] ?: default ?: error("Parameter not found: $key")
|
this[StringSymbol(key.toString())] ?: default ?: error("Parameter not found: $key")
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package kscience.kmath.asm
|
package kscience.kmath.asm
|
||||||
|
|
||||||
import kscience.kmath.asm.compile
|
|
||||||
import kscience.kmath.ast.mstInField
|
import kscience.kmath.ast.mstInField
|
||||||
import kscience.kmath.ast.mstInRing
|
import kscience.kmath.ast.mstInRing
|
||||||
import kscience.kmath.ast.mstInSpace
|
import kscience.kmath.ast.mstInSpace
|
||||||
@ -11,6 +10,7 @@ import kotlin.test.Test
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestAsmAlgebras {
|
internal class TestAsmAlgebras {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun space() {
|
fun space() {
|
||||||
val res1 = ByteRing.mstInSpace {
|
val res1 = ByteRing.mstInSpace {
|
||||||
|
@ -1,48 +1,57 @@
|
|||||||
package kscience.kmath.commons.expressions
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.expressions.ExpressionAlgebra
|
import kscience.kmath.expressions.ExpressionAlgebra
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
import kscience.kmath.operations.ExtendedField
|
import kscience.kmath.operations.ExtendedField
|
||||||
import kscience.kmath.operations.Field
|
|
||||||
import kscience.kmath.operations.invoke
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import kotlin.properties.ReadOnlyProperty
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field over commons-math [DerivativeStructure].
|
* A field over commons-math [DerivativeStructure].
|
||||||
*
|
*
|
||||||
* @property order The derivation order.
|
* @property order The derivation order.
|
||||||
* @property parameters The map of free parameters.
|
* @property bindings The map of bindings values. All bindings are considered free parameters
|
||||||
*/
|
*/
|
||||||
public class DerivativeStructureField(
|
public class DerivativeStructureField(
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
public val parameters: Map<String, Double>,
|
private val bindings: Map<Symbol, Double>
|
||||||
) : ExtendedField<DerivativeStructure> {
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
|
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
|
||||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
|
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
|
||||||
|
|
||||||
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
|
/**
|
||||||
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||||
|
*/
|
||||||
|
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||||
|
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||||
|
override val identity: Any = symbol.identity
|
||||||
}
|
}
|
||||||
|
|
||||||
public val variable: ReadOnlyProperty<Any?, DerivativeStructure> = ReadOnlyProperty { _, property ->
|
/**
|
||||||
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
|
* Identity-based symbol bindings map
|
||||||
|
*/
|
||||||
|
private val variables: Map<Any?, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
|
||||||
|
key.identity to DerivativeStructureSymbol(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
|
override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value)
|
||||||
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
|
||||||
|
|
||||||
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
|
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||||
return deriv(mapOf(parName to order))
|
|
||||||
|
public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||||
|
|
||||||
|
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
||||||
|
return derivative(mapOf(parameter to order))
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
|
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
|
||||||
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
|
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
|
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
|
||||||
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||||
|
|
||||||
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||||
@ -85,48 +94,16 @@ public class DerivativeStructureField(
|
|||||||
/**
|
/**
|
||||||
* A constructs that creates a derivative structure with required order on-demand
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
*/
|
*/
|
||||||
public class DiffExpression(
|
public class DerivativeStructureExpression(
|
||||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||||
) : Expression<Double> {
|
) : DifferentiableExpression<Double> {
|
||||||
public override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
0,
|
DerivativeStructureField(0, arguments).function().value
|
||||||
arguments
|
|
||||||
).function().value
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
* TODO make result [DiffExpression]
|
|
||||||
*/
|
*/
|
||||||
public fun derivative(orders: Map<String, Int>): Expression<Double> = Expression { arguments ->
|
public override fun derivative(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||||
(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
|
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO add gradient and maybe other vector operators
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
|
|
||||||
public fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
|
||||||
*/
|
|
||||||
public object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
|
||||||
public override val zero: DiffExpression = DiffExpression { 0.0.const() }
|
|
||||||
public override val one: DiffExpression = DiffExpression { 1.0.const() }
|
|
||||||
|
|
||||||
public override fun variable(name: String, default: Double?): DiffExpression =
|
|
||||||
DiffExpression { variable(name, default?.const()) }
|
|
||||||
|
|
||||||
public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
|
|
||||||
|
|
||||||
public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) + b.function(this) }
|
|
||||||
|
|
||||||
public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
|
|
||||||
|
|
||||||
public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) * b.function(this) }
|
|
||||||
|
|
||||||
public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) / b.function(this) }
|
|
||||||
}
|
}
|
@ -1,6 +1,6 @@
|
|||||||
package kscience.kmath.commons.expressions
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.*
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -8,33 +8,37 @@ import kotlin.test.assertEquals
|
|||||||
|
|
||||||
internal inline fun <R> diff(
|
internal inline fun <R> diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<String, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField.() -> R
|
block: DerivativeStructureField.() -> R,
|
||||||
): R {
|
): R {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
|
private val x by symbol
|
||||||
|
private val y by symbol
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun derivativeStructureFieldTest() {
|
fun derivativeStructureFieldTest() {
|
||||||
val res: Double = diff(3, "x" to 1.0, "y" to 1.0) {
|
val res: Double = diff(3, x to 1.0, y to 1.0) {
|
||||||
val x by variable
|
val x = bind(x)//by binding()
|
||||||
val y = variable("y")
|
val y = symbol("y")
|
||||||
val z = x * (-sin(x * y) + y)
|
val z = x * (-sin(x * y) + y)
|
||||||
z.deriv("x")
|
z.derivative(x)
|
||||||
}
|
}
|
||||||
|
println(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun autoDifTest() {
|
fun autoDifTest() {
|
||||||
val f = DiffExpression {
|
val f = DerivativeStructureExpression {
|
||||||
val x by variable
|
val x by binding()
|
||||||
val y by variable
|
val y by binding()
|
||||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
|
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||||
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
|
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -12,7 +12,7 @@ The core features of KMath:
|
|||||||
|
|
||||||
> #### Artifact:
|
> #### Artifact:
|
||||||
>
|
>
|
||||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`.
|
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
|
||||||
>
|
>
|
||||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
||||||
>
|
>
|
||||||
@ -22,25 +22,28 @@ The core features of KMath:
|
|||||||
>
|
>
|
||||||
> ```gradle
|
> ```gradle
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||||
|
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1'
|
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
> **Gradle Kotlin DSL:**
|
> **Gradle Kotlin DSL:**
|
||||||
>
|
>
|
||||||
> ```kotlin
|
> ```kotlin
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-1")
|
> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
|
@ -41,6 +41,6 @@ readme {
|
|||||||
feature(
|
feature(
|
||||||
id = "autodif",
|
id = "autodif",
|
||||||
description = "Automatic differentiation",
|
description = "Automatic differentiation",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt"
|
ref = "src/commonMain/kotlin/kscience/kmath/misc/SimpleAutoDiff.kt"
|
||||||
)
|
)
|
||||||
}
|
}
|
@ -1,6 +1,26 @@
|
|||||||
package kscience.kmath.expressions
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
import kotlin.properties.ReadOnlyProperty
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A marker interface for a symbol. A symbol mus have an identity
|
||||||
|
*/
|
||||||
|
public interface Symbol {
|
||||||
|
/**
|
||||||
|
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
|
||||||
|
* By default uses object identity
|
||||||
|
*/
|
||||||
|
public val identity: Any get() = this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A [Symbol] with a [String] identity
|
||||||
|
*/
|
||||||
|
public inline class StringSymbol(override val identity: String) : Symbol {
|
||||||
|
override fun toString(): String = identity
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments
|
||||||
@ -12,30 +32,81 @@ public fun interface Expression<T> {
|
|||||||
* @param arguments the map of arguments.
|
* @param arguments the map of arguments.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
public operator fun invoke(arguments: Map<String, T>): T
|
public operator fun invoke(arguments: Map<Symbol, T>): T
|
||||||
|
|
||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invlode an expression without parameters
|
||||||
|
*/
|
||||||
|
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||||
|
//This method exists to avoid resolution ambiguity of vararg methods
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression from arguments.
|
* Calls this expression from arguments.
|
||||||
*
|
*
|
||||||
* @param pairs the pair of arguments' names to values.
|
* @param pairs the pair of arguments' names to values.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
@JvmName("callBySymbol")
|
||||||
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
||||||
|
|
||||||
|
@JvmName("callByString")
|
||||||
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||||
|
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* And object that could be differentiated
|
||||||
|
*/
|
||||||
|
public interface Differentiable<T> {
|
||||||
|
public fun derivative(orders: Map<Symbol, Int>): T
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||||
|
derivative(mapOf(*orders))
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
|
*
|
||||||
|
* @param T type of the constants for the expression
|
||||||
|
* @param E type of the actual expression state
|
||||||
*/
|
*/
|
||||||
public interface ExpressionAlgebra<T, E> : Algebra<E> {
|
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Introduce a variable into expression context
|
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
|
||||||
*/
|
*/
|
||||||
public fun variable(name: String, default: T? = null): E
|
public fun bindOrNull(symbol: Symbol): E?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a string to a context using [StringSymbol]
|
||||||
|
*/
|
||||||
|
override fun symbol(value: String): E = bind(StringSymbol(value))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constant expression which does not depend on arguments
|
* A constant expression which does not depend on arguments
|
||||||
*/
|
*/
|
||||||
public fun const(value: T): E
|
public fun const(value: T): E
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
||||||
|
*/
|
||||||
|
public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
|
||||||
|
bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
|
||||||
|
|
||||||
|
public val symbol: ReadOnlyProperty<Any?, Symbol> = ReadOnlyProperty { _, property ->
|
||||||
|
StringSymbol(property.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> =
|
||||||
|
ReadOnlyProperty { _, property ->
|
||||||
|
bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
|
||||||
|
}
|
@ -2,39 +2,6 @@ package kscience.kmath.expressions
|
|||||||
|
|
||||||
import kscience.kmath.operations.*
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.unaryOperation(name, expr.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalBinaryOperation<T>(
|
|
||||||
val context: Algebra<T>,
|
|
||||||
val name: String,
|
|
||||||
val first: Expression<T>,
|
|
||||||
val second: Expression<T>
|
|
||||||
) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = value
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalConstProductExpression<T>(
|
|
||||||
val context: Space<T>,
|
|
||||||
private val expr: Expression<T>,
|
|
||||||
val const: Number
|
|
||||||
) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context class for [Expression] construction.
|
* A context class for [Expression] construction.
|
||||||
*
|
*
|
||||||
@ -45,24 +12,32 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression which does not depend on arguments.
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
public override fun const(value: T): Expression<T> = Expression { value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression to access a variable.
|
* Builds an Expression to access a variable.
|
||||||
*/
|
*/
|
||||||
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
public override fun bindOrNull(symbol: Symbol): Expression<T>? = Expression { arguments ->
|
||||||
|
arguments[symbol] ?: error("Argument not found: $symbol")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
*/
|
*/
|
||||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(
|
||||||
FunctionalBinaryOperation(algebra, operation, left, right)
|
operation: String,
|
||||||
|
left: Expression<T>,
|
||||||
|
right: Expression<T>,
|
||||||
|
): Expression<T> = Expression { arguments ->
|
||||||
|
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
*/
|
*/
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
|
||||||
FunctionalUnaryOperation(algebra, operation, arg)
|
algebra.unaryOperation(operation, arg.invoke(arguments))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -81,8 +56,9 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of multiplication of expression by number.
|
* Builds an Expression of multiplication of expression by number.
|
||||||
*/
|
*/
|
||||||
public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
public override fun multiply(a: Expression<T>, k: Number): Expression<T> = Expression { arguments ->
|
||||||
FunctionalConstProductExpression(algebra, a, k)
|
algebra.multiply(a.invoke(arguments), k)
|
||||||
|
}
|
||||||
|
|
||||||
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||||
@ -118,8 +94,8 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
|||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||||
FunctionalExpressionRing<T, A>(algebra),
|
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
|
||||||
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
where A : Field<T>, A : NumericAlgebra<T> {
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of division an expression by another one.
|
* Builds an Expression of division an expression by another one.
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,329 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implementation of backward-mode automatic differentiation.
|
||||||
|
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A [Symbol] with bound value
|
||||||
|
*/
|
||||||
|
public interface BoundSymbol<out T> : Symbol {
|
||||||
|
public val value: T
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a [Symbol] to a [value] and produce [BoundSymbol]
|
||||||
|
*/
|
||||||
|
public fun <T> Symbol.bind(value: T): BoundSymbol<T> = object : BoundSymbol<T> {
|
||||||
|
override val identity = this@bind.identity
|
||||||
|
override val value: T = value
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents result of [withAutoDiff] call.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @param value the value of result.
|
||||||
|
* @property withAutoDiff The mapping of differentiated variables to their derivatives.
|
||||||
|
* @property context The field over [T].
|
||||||
|
*/
|
||||||
|
public class DerivationResult<T : Any>(
|
||||||
|
override val value: T,
|
||||||
|
private val derivativeValues: Map<Any, T>,
|
||||||
|
public val context: Field<T>,
|
||||||
|
) : BoundSymbol<T> {
|
||||||
|
/**
|
||||||
|
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
||||||
|
*/
|
||||||
|
public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the divergence.
|
||||||
|
*/
|
||||||
|
public fun div(): T = context { sum(derivativeValues.values) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the gradient for variables in given order.
|
||||||
|
*/
|
||||||
|
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
|
||||||
|
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||||
|
return variables.map(::derivative).asBuffer()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
||||||
|
*
|
||||||
|
* The partial derivatives are placed in argument `d` variable
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* ```
|
||||||
|
* val x by symbol // define variable(s) and their values
|
||||||
|
* val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||||
|
* assertEquals(17.0, y.x) // the value of result (y)
|
||||||
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
|
* ```
|
||||||
|
*
|
||||||
|
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||||
|
* @return the result of differentiation.
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||||
|
bindings: Collection<BoundSymbol<T>>,
|
||||||
|
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||||
|
): DerivationResult<T> {
|
||||||
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
|
return AutoDiffContext(this, bindings).derivate(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||||
|
vararg bindings: Pair<Symbol, T>,
|
||||||
|
body: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||||
|
): DerivationResult<T> = withAutoDiff(bindings.map { it.first.bind(it.second) }, body)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents field in context of which functions can be derived.
|
||||||
|
*/
|
||||||
|
public abstract class AutoDiffField<T : Any, F : Field<T>>
|
||||||
|
: Field<BoundSymbol<T>>, ExpressionAlgebra<T, BoundSymbol<T>> {
|
||||||
|
|
||||||
|
public abstract val context: F
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A variable accessing inner state of derivatives.
|
||||||
|
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||||
|
*/
|
||||||
|
public abstract var BoundSymbol<T>.d: T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
|
*
|
||||||
|
* For example, implementation of `sin` function is:
|
||||||
|
*
|
||||||
|
* ```
|
||||||
|
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||||
|
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
||||||
|
|
||||||
|
public inline fun const(block: F.() -> T): BoundSymbol<T> = const(context.block())
|
||||||
|
|
||||||
|
// Overloads for Double constants
|
||||||
|
|
||||||
|
override operator fun Number.plus(b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun BoundSymbol<T>.plus(b: Number): BoundSymbol<T> = b.plus(this)
|
||||||
|
|
||||||
|
override operator fun Number.minus(b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||||
|
|
||||||
|
override operator fun BoundSymbol<T>.minus(b: Number): BoundSymbol<T> =
|
||||||
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatic Differentiation context class.
|
||||||
|
*/
|
||||||
|
private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||||
|
override val context: F,
|
||||||
|
bindings: Collection<BoundSymbol<T>>,
|
||||||
|
) : AutoDiffField<T, F>() {
|
||||||
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
|
private var sp: Int = 0
|
||||||
|
private val derivatives: MutableMap<Any, T> = hashMapOf()
|
||||||
|
override val zero: BoundSymbol<T> get() = const(context.zero)
|
||||||
|
override val one: BoundSymbol<T> get() = const(context.one)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
|
||||||
|
* with respect to this variable.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @property value The value of this variable.
|
||||||
|
*/
|
||||||
|
private class AutoDiffVariableWithDeriv<T : Any>(override val value: T, var d: T) : BoundSymbol<T>
|
||||||
|
|
||||||
|
private val bindings: Map<Any, BoundSymbol<T>> = bindings.associateBy { it.identity }
|
||||||
|
|
||||||
|
override fun bindOrNull(symbol: Symbol): BoundSymbol<T>? = bindings[symbol.identity]
|
||||||
|
|
||||||
|
override fun const(value: T): BoundSymbol<T> = AutoDiffVariableWithDeriv(value, context.zero)
|
||||||
|
|
||||||
|
override var BoundSymbol<T>.d: T
|
||||||
|
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero
|
||||||
|
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
|
// save block to stack for backward pass
|
||||||
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
|
stack[sp++] = block
|
||||||
|
stack[sp++] = value
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
fun runBackwardPass() {
|
||||||
|
while (sp > 0) {
|
||||||
|
val value = stack[--sp]
|
||||||
|
val block = stack[--sp] as F.(Any?) -> Unit
|
||||||
|
context.block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
|
override fun add(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { a.value + b.value }) { z ->
|
||||||
|
a.d += z.d
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { a.value * b.value }) { z ->
|
||||||
|
a.d += z.d * b.value
|
||||||
|
b.d += z.d * a.value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun divide(a: BoundSymbol<T>, b: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { a.value / b.value }) { z ->
|
||||||
|
a.d += z.d / b.value
|
||||||
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: BoundSymbol<T>, k: Number): BoundSymbol<T> =
|
||||||
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
|
a.d += z.d * k.toDouble()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun derivate(function: AutoDiffField<T, F>.() -> BoundSymbol<T>): DerivationResult<T> {
|
||||||
|
val result = function()
|
||||||
|
result.d = context.one // computing derivative w.r.t result
|
||||||
|
runBackwardPass()
|
||||||
|
return DerivationResult(result.value, derivatives, context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
|
*/
|
||||||
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
|
public val field: F,
|
||||||
|
public val function: AutoDiffField<T, F>.() -> BoundSymbol<T>,
|
||||||
|
) : DifferentiableExpression<T> {
|
||||||
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
|
val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
|
return AutoDiffContext(field, bindings).function().value
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the derivative expression with given orders
|
||||||
|
*/
|
||||||
|
public override fun derivative(orders: Map<Symbol, Int>): Expression<T> {
|
||||||
|
val dSymbol = orders.entries.singleOrNull { it.value == 1 }
|
||||||
|
?: error("SimpleAutoDiff supports only first order derivatives")
|
||||||
|
return Expression { arguments ->
|
||||||
|
val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
|
val derivationResult = AutoDiffContext(field, bindings).derivate(function)
|
||||||
|
derivationResult.derivative(dSymbol.key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
|
// x ^ 2
|
||||||
|
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||||
|
x: BoundSymbol<T>,
|
||||||
|
y: Double,
|
||||||
|
): BoundSymbol<T> =
|
||||||
|
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||||
|
x: BoundSymbol<T>,
|
||||||
|
y: Int,
|
||||||
|
): BoundSymbol<T> =
|
||||||
|
pow(x, y.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(
|
||||||
|
x: BoundSymbol<T>,
|
||||||
|
y: BoundSymbol<T>,
|
||||||
|
): BoundSymbol<T> =
|
||||||
|
exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { tan(x.value) }) { z ->
|
||||||
|
val c = cos(x.value)
|
||||||
|
x.d += z.d / (c * c)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { tan(x.value) }) { z ->
|
||||||
|
val c = cosh(x.value)
|
||||||
|
x.d += z.d / (c * c)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: BoundSymbol<T>): BoundSymbol<T> =
|
||||||
|
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||||
|
|
@ -1,266 +0,0 @@
|
|||||||
package kscience.kmath.misc
|
|
||||||
|
|
||||||
import kscience.kmath.linear.Point
|
|
||||||
import kscience.kmath.operations.*
|
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import kotlin.contracts.InvocationKind
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Implementation of backward-mode automatic differentiation.
|
|
||||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
|
||||||
* with respect to this variable.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @property value The value of this variable.
|
|
||||||
*/
|
|
||||||
public open class Variable<T : Any>(public val value: T)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents result of [deriv] call.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @param value the value of result.
|
|
||||||
* @property deriv The mapping of differentiated variables to their derivatives.
|
|
||||||
* @property context The field over [T].
|
|
||||||
*/
|
|
||||||
public class DerivationResult<T : Any>(
|
|
||||||
value: T,
|
|
||||||
public val deriv: Map<Variable<T>, T>,
|
|
||||||
public val context: Field<T>
|
|
||||||
) : Variable<T>(value) {
|
|
||||||
/**
|
|
||||||
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
|
||||||
*/
|
|
||||||
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the divergence.
|
|
||||||
*/
|
|
||||||
public fun div(): T = context { sum(deriv.values) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the gradient for variables in given order.
|
|
||||||
*/
|
|
||||||
public fun grad(vararg variables: Variable<T>): Point<T> {
|
|
||||||
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
|
||||||
return variables.map(::deriv).asBuffer()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
|
||||||
*
|
|
||||||
* The partial derivatives are placed in argument `d` variable
|
|
||||||
*
|
|
||||||
* Example:
|
|
||||||
* ```
|
|
||||||
* val x = Variable(2) // define variable(s) and their values
|
|
||||||
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
|
||||||
* assertEquals(17.0, y.x) // the value of result (y)
|
|
||||||
* assertEquals(9.0, x.d) // dy/dx
|
|
||||||
* ```
|
|
||||||
*
|
|
||||||
* @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to.
|
|
||||||
* @return the result of differentiation.
|
|
||||||
*/
|
|
||||||
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
|
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
|
|
||||||
return (AutoDiffContext(this)) {
|
|
||||||
val result = body()
|
|
||||||
result.d = context.one // computing derivative w.r.t result
|
|
||||||
runBackwardPass()
|
|
||||||
DerivationResult(result.value, derivatives, this@deriv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents field in context of which functions can be derived.
|
|
||||||
*/
|
|
||||||
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
|
||||||
public abstract val context: F
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A variable accessing inner state of derivatives.
|
|
||||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
|
||||||
*/
|
|
||||||
public abstract var Variable<T>.d: T
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
|
||||||
*
|
|
||||||
* For example, implementation of `sin` function is:
|
|
||||||
*
|
|
||||||
* ```
|
|
||||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
|
||||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
|
||||||
* }
|
|
||||||
* ```
|
|
||||||
*/
|
|
||||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public abstract fun variable(value: T): Variable<T>
|
|
||||||
|
|
||||||
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
|
||||||
|
|
||||||
// Overloads for Double constants
|
|
||||||
|
|
||||||
override operator fun Number.plus(b: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
|
|
||||||
b.d += z.d
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
|
||||||
|
|
||||||
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
|
||||||
|
|
||||||
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
|
||||||
derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic Differentiation context class.
|
|
||||||
*/
|
|
||||||
@PublishedApi
|
|
||||||
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
|
||||||
private var sp: Int = 0
|
|
||||||
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
|
|
||||||
override val zero: Variable<T> get() = Variable(context.zero)
|
|
||||||
override val one: Variable<T> get() = Variable(context.one)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A variable coupled with its derivative. For internal use only
|
|
||||||
*/
|
|
||||||
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
|
||||||
|
|
||||||
override fun variable(value: T): Variable<T> =
|
|
||||||
VariableWithDeriv(value, context.zero)
|
|
||||||
|
|
||||||
override var Variable<T>.d: T
|
|
||||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
|
||||||
set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
|
||||||
// save block to stack for backward pass
|
|
||||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
|
||||||
stack[sp++] = block
|
|
||||||
stack[sp++] = value
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun runBackwardPass() {
|
|
||||||
while (sp > 0) {
|
|
||||||
val value = stack[--sp]
|
|
||||||
val block = stack[--sp] as F.(Any?) -> Unit
|
|
||||||
context.block(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
|
||||||
|
|
||||||
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
|
|
||||||
a.d += z.d
|
|
||||||
b.d += z.d
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
|
|
||||||
a.d += z.d * b.value
|
|
||||||
b.d += z.d * a.value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
|
|
||||||
a.d += z.d / b.value
|
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
|
|
||||||
a.d += z.d * k.toDouble()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
|
||||||
|
|
||||||
// x ^ 2
|
|
||||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
|
||||||
|
|
||||||
// x ^ 1/2
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
|
||||||
|
|
||||||
// x ^ y (const)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
|
||||||
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
|
|
||||||
pow(x, y.toDouble())
|
|
||||||
|
|
||||||
// exp(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
|
|
||||||
|
|
||||||
// ln(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
|
|
||||||
|
|
||||||
// x ^ y (any)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
|
||||||
exp(y * ln(x))
|
|
||||||
|
|
||||||
// sin(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
|
||||||
|
|
||||||
// cos(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { tan(x.value) }) { z ->
|
|
||||||
val c = cos(x.value)
|
|
||||||
x.d += z.d / (c * c)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { tan(x.value) }) { z ->
|
|
||||||
val c = cosh(x.value)
|
|
||||||
x.d += z.d / (c * c)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
|
||||||
|
|
@ -6,19 +6,21 @@ import kscience.kmath.operations.RealField
|
|||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFails
|
||||||
|
|
||||||
class ExpressionFieldTest {
|
class ExpressionFieldTest {
|
||||||
|
val x by symbol
|
||||||
@Test
|
@Test
|
||||||
fun testExpression() {
|
fun testExpression() {
|
||||||
val context = FunctionalExpressionField(RealField)
|
val context = FunctionalExpressionField(RealField)
|
||||||
|
|
||||||
val expression = context {
|
val expression = context {
|
||||||
val x = variable("x", 2.0)
|
val x by binding()
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
assertEquals(expression(), 9.0)
|
assertFails { expression()}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -26,33 +28,33 @@ class ExpressionFieldTest {
|
|||||||
val context = FunctionalExpressionField(ComplexField)
|
val context = FunctionalExpressionField(ComplexField)
|
||||||
|
|
||||||
val expression = context {
|
val expression = context {
|
||||||
val x = variable("x", Complex(2.0, 0.0))
|
val x = bind(x)
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
||||||
assertEquals(expression(), Complex(9.0, 0.0))
|
//assertEquals(expression(), Complex(9.0, 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun separateContext() {
|
fun separateContext() {
|
||||||
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
||||||
val x = variable("x")
|
val x by binding()
|
||||||
return x * x + 2 * x + one
|
return x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = FunctionalExpressionField(RealField).expression()
|
val expression = FunctionalExpressionField(RealField).expression()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueExpression() {
|
fun valueExpression() {
|
||||||
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
||||||
val x = variable("x")
|
val x by binding()
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,277 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class SimpleAutoDiffTest {
|
||||||
|
fun d(
|
||||||
|
vararg bindings: Pair<Symbol, Double>,
|
||||||
|
body: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
|
||||||
|
|
||||||
|
fun dx(
|
||||||
|
xBinding: Pair<Symbol, Double>,
|
||||||
|
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>) -> BoundSymbol<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||||
|
|
||||||
|
fun dxy(
|
||||||
|
xBinding: Pair<Symbol, Double>,
|
||||||
|
yBinding: Pair<Symbol, Double>,
|
||||||
|
body: AutoDiffField<Double, RealField>.(x: BoundSymbol<Double>, y: BoundSymbol<Double>) -> BoundSymbol<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
|
||||||
|
body(bind(xBinding.first), bind(yBinding.first))
|
||||||
|
}
|
||||||
|
|
||||||
|
fun diff(block: AutoDiffField<Double, RealField>.() -> BoundSymbol<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||||
|
return SimpleAutoDiffExpression(RealField, block)
|
||||||
|
}
|
||||||
|
|
||||||
|
val x by symbol
|
||||||
|
val y by symbol
|
||||||
|
val z by symbol
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlusX2() {
|
||||||
|
val y = d(x to 3.0) {
|
||||||
|
// diff w.r.t this x at 3
|
||||||
|
val x = bind(x)
|
||||||
|
x + x
|
||||||
|
}
|
||||||
|
assertEquals(6.0, y.value) // y = x + x = 6
|
||||||
|
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlus() {
|
||||||
|
// two variables
|
||||||
|
val z = d(x to 2.0, y to 3.0) {
|
||||||
|
val x = bind(x)
|
||||||
|
val y = bind(y)
|
||||||
|
x + y
|
||||||
|
}
|
||||||
|
assertEquals(5.0, z.value) // z = x + y = 5
|
||||||
|
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||||
|
assertEquals(1.0, z.derivative(y)) // dz/dy = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
// two variables
|
||||||
|
val z = d(x to 7.0, y to 3.0) {
|
||||||
|
val x = bind(x)
|
||||||
|
val y = bind(y)
|
||||||
|
|
||||||
|
x - y
|
||||||
|
}
|
||||||
|
assertEquals(4.0, z.value) // z = x - y = 4
|
||||||
|
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||||
|
assertEquals(-1.0, z.derivative(y)) // dz/dy = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMulX2() {
|
||||||
|
val y = dx(x to 3.0) { x ->
|
||||||
|
// diff w.r.t this x at 3
|
||||||
|
x * x
|
||||||
|
}
|
||||||
|
assertEquals(9.0, y.value) // y = x * x = 9
|
||||||
|
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqr() {
|
||||||
|
val y = dx(x to 3.0) { x -> sqr(x) }
|
||||||
|
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||||
|
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrSqr() {
|
||||||
|
val y = dx(x to 2.0) { x -> sqr(sqr(x)) }
|
||||||
|
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||||
|
assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testX3() {
|
||||||
|
val y = dx(x to 2.0) { x ->
|
||||||
|
// diff w.r.t this x at 2
|
||||||
|
x * x * x
|
||||||
|
}
|
||||||
|
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||||
|
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDiv() {
|
||||||
|
val z = dxy(x to 5.0, y to 2.0) { x, y ->
|
||||||
|
x / y
|
||||||
|
}
|
||||||
|
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||||
|
assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5
|
||||||
|
assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPow3() {
|
||||||
|
val y = dx(x to 2.0) { x ->
|
||||||
|
// diff w.r.t this x at 2
|
||||||
|
pow(x, 3)
|
||||||
|
}
|
||||||
|
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||||
|
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPowFull() {
|
||||||
|
val z = dxy(x to 2.0, y to 3.0) { x, y ->
|
||||||
|
pow(x, y)
|
||||||
|
}
|
||||||
|
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||||
|
assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||||
|
assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testFromPaper() {
|
||||||
|
val y = dx(x to 3.0) { x -> 2 * x + x * x * x }
|
||||||
|
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||||
|
assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testInnerVariable() {
|
||||||
|
val y = dx(x to 1.0) { x ->
|
||||||
|
const(1.0) * x
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLongChain() {
|
||||||
|
val n = 10_000
|
||||||
|
val y = dx(x to 1.0) { x ->
|
||||||
|
var res = const(1.0)
|
||||||
|
for (i in 1..n) res *= x
|
||||||
|
res
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testExample() {
|
||||||
|
val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 }
|
||||||
|
assertEquals(17.0, y.value) // the value of result (y)
|
||||||
|
assertEquals(9.0, y.derivative(x)) // dy/dx
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrt() {
|
||||||
|
val y = dx(x to 16.0) { x -> sqrt(x) }
|
||||||
|
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||||
|
assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSin() {
|
||||||
|
val y = dx(x to PI / 6.0) { x -> sin(x) }
|
||||||
|
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||||
|
assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCos() {
|
||||||
|
val y = dx(x to PI / 6) { x -> cos(x) }
|
||||||
|
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
||||||
|
assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTan() {
|
||||||
|
val y = dx(x to PI / 6) { x -> tan(x) }
|
||||||
|
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
||||||
|
assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAsin() {
|
||||||
|
val y = dx(x to PI / 6) { x -> asin(x) }
|
||||||
|
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
||||||
|
assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAcos() {
|
||||||
|
val y = dx(x to PI / 6) { x -> acos(x) }
|
||||||
|
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
||||||
|
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAtan() {
|
||||||
|
val y = dx(x to PI / 6) { x -> atan(x) }
|
||||||
|
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
||||||
|
assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSinh() {
|
||||||
|
val y = dx(x to 0.0) { x -> sinh(x) }
|
||||||
|
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
||||||
|
assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCosh() {
|
||||||
|
val y = dx(x to 0.0) { x -> cosh(x) }
|
||||||
|
assertApprox(1.0, y.value) //y = cosh(0)
|
||||||
|
assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTanh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> tanh(x) }
|
||||||
|
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
||||||
|
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAsinh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> asinh(x) }
|
||||||
|
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
||||||
|
assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAcosh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> acosh(x) }
|
||||||
|
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
||||||
|
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAtanh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> atanh(x) }
|
||||||
|
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
||||||
|
assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivGrad() {
|
||||||
|
val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y }
|
||||||
|
assertEquals(6.0, res.div())
|
||||||
|
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun assertApprox(a: Double, b: Double) {
|
||||||
|
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||||
|
}
|
||||||
|
}
|
@ -1,261 +0,0 @@
|
|||||||
package kscience.kmath.misc
|
|
||||||
|
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import kotlin.math.PI
|
|
||||||
import kotlin.math.pow
|
|
||||||
import kotlin.math.sqrt
|
|
||||||
import kotlin.test.Test
|
|
||||||
import kotlin.test.assertEquals
|
|
||||||
import kotlin.test.assertTrue
|
|
||||||
|
|
||||||
class AutoDiffTest {
|
|
||||||
inline fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
|
|
||||||
RealField.deriv(body)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPlusX2() {
|
|
||||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
|
||||||
val y = deriv { x + x }
|
|
||||||
assertEquals(6.0, y.value) // y = x + x = 6
|
|
||||||
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPlus() {
|
|
||||||
// two variables
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { x + y }
|
|
||||||
assertEquals(5.0, z.value) // z = x + y = 5
|
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
|
||||||
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testMinus() {
|
|
||||||
// two variables
|
|
||||||
val x = Variable(7.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { x - y }
|
|
||||||
assertEquals(4.0, z.value) // z = x - y = 4
|
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
|
||||||
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testMulX2() {
|
|
||||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
|
||||||
val y = deriv { x * x }
|
|
||||||
assertEquals(9.0, y.value) // y = x * x = 9
|
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqr() {
|
|
||||||
val x = Variable(3.0)
|
|
||||||
val y = deriv { sqr(x) }
|
|
||||||
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqrSqr() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = deriv { sqr(sqr(x)) }
|
|
||||||
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
|
||||||
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testX3() {
|
|
||||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
|
||||||
val y = deriv { x * x * x }
|
|
||||||
assertEquals(8.0, y.value) // y = x * x * x = 8
|
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testDiv() {
|
|
||||||
val x = Variable(5.0)
|
|
||||||
val y = Variable(2.0)
|
|
||||||
val z = deriv { x / y }
|
|
||||||
assertEquals(2.5, z.value) // z = x / y = 2.5
|
|
||||||
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
|
||||||
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPow3() {
|
|
||||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
|
||||||
val y = deriv { pow(x, 3) }
|
|
||||||
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPowFull() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { pow(x, y) }
|
|
||||||
assertApprox(8.0, z.value) // z = x ^ y = 8
|
|
||||||
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
|
||||||
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testFromPaper() {
|
|
||||||
val x = Variable(3.0)
|
|
||||||
val y = deriv { 2 * x + x * x * x }
|
|
||||||
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
|
||||||
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testInnerVariable() {
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = deriv {
|
|
||||||
Variable(1.0) * x
|
|
||||||
}
|
|
||||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
|
||||||
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testLongChain() {
|
|
||||||
val n = 10_000
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = deriv {
|
|
||||||
var res = Variable(1.0)
|
|
||||||
for (i in 1..n) res *= x
|
|
||||||
res
|
|
||||||
}
|
|
||||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
|
||||||
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testExample() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = deriv { sqr(x) + 5 * x + 3 }
|
|
||||||
assertEquals(17.0, y.value) // the value of result (y)
|
|
||||||
assertEquals(9.0, y.deriv(x)) // dy/dx
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqrt() {
|
|
||||||
val x = Variable(16.0)
|
|
||||||
val y = deriv { sqrt(x) }
|
|
||||||
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
|
||||||
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSin() {
|
|
||||||
val x = Variable(PI / 6.0)
|
|
||||||
val y = deriv { sin(x) }
|
|
||||||
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
|
||||||
assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testCos() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { cos(x) }
|
|
||||||
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
|
||||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testTan() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { tan(x) }
|
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
|
||||||
assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAsin() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { asin(x) }
|
|
||||||
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
|
||||||
assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAcos() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { acos(x) }
|
|
||||||
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
|
||||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAtan() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { atan(x) }
|
|
||||||
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
|
||||||
assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSinh() {
|
|
||||||
val x = Variable(0.0)
|
|
||||||
val y = deriv { sinh(x) }
|
|
||||||
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
|
||||||
assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testCosh() {
|
|
||||||
val x = Variable(0.0)
|
|
||||||
val y = deriv { cosh(x) }
|
|
||||||
assertApprox(1.0, y.value) //y = cosh(0)
|
|
||||||
assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testTanh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { tanh(x) }
|
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
|
||||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAsinh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { asinh(x) }
|
|
||||||
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
|
||||||
assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAcosh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { acosh(x) }
|
|
||||||
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
|
||||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAtanh() {
|
|
||||||
val x = Variable(PI / 6.0)
|
|
||||||
val y = deriv { atanh(x) }
|
|
||||||
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
|
||||||
assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testDivGrad() {
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = Variable(2.0)
|
|
||||||
val res = deriv { x * x + y * y }
|
|
||||||
assertEquals(6.0, res.div())
|
|
||||||
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun assertApprox(a: Double, b: Double) {
|
|
||||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user