forked from kscience/kmath
Merge branch 'dev' into feature/quaternion
# Conflicts: # CHANGELOG.md
This commit is contained in:
commit
6d016c87f2
@ -1 +1,3 @@
|
|||||||
job("Build") { gradlew("openjdk:11", "build") }
|
job("Build") {
|
||||||
|
gradlew("openjdk:11", "build")
|
||||||
|
}
|
||||||
|
@ -8,6 +8,11 @@
|
|||||||
- Automatic README generation for features (#139)
|
- Automatic README generation for features (#139)
|
||||||
- Native support for `memory`, `core` and `dimensions`
|
- Native support for `memory`, `core` and `dimensions`
|
||||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
|
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
|
||||||
|
- A separate `Symbol` entity, which is used for global unbound symbol.
|
||||||
|
- A `Symbol` indexing scope.
|
||||||
|
- Basic optimization API for Commons-math.
|
||||||
|
- Chi squared optimization for array-like data in CM
|
||||||
|
- `Fitting` utility object in prob/stat
|
||||||
- Basic Quaternion vector support.
|
- Basic Quaternion vector support.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
@ -17,6 +22,8 @@
|
|||||||
- `Polynomial` secondary constructor made function.
|
- `Polynomial` secondary constructor made function.
|
||||||
- Kotlin version: 1.3.72 -> 1.4.20-M1
|
- Kotlin version: 1.3.72 -> 1.4.20-M1
|
||||||
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
||||||
|
- Full autodiff refactoring based on `Symbol`
|
||||||
|
- `kmath-prob` renamed to `kmath-stat`
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
18
README.md
18
README.md
@ -54,8 +54,6 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||||
to submit a feature request if you want something to be done first.
|
to submit a feature request if you want something to be done first.
|
||||||
|
|
||||||
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
|
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
||||||
@ -101,7 +99,7 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
> - [buffers](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
|
> - [buffers](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
|
||||||
> - [expressions](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
|
> - [expressions](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
|
||||||
> - [domains](kmath-core/src/commonMain/kotlin/kscience/kmath/domains) : Domains
|
> - [domains](kmath-core/src/commonMain/kotlin/kscience/kmath/domains) : Domains
|
||||||
> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation
|
> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
|
||||||
|
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
|
* ### [kmath-ejml](kmath-ejml)
|
||||||
|
>
|
||||||
|
>
|
||||||
|
> **Maturity**: EXPERIMENTAL
|
||||||
|
<hr/>
|
||||||
|
|
||||||
* ### [kmath-for-real](kmath-for-real)
|
* ### [kmath-for-real](kmath-for-real)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
@ -147,7 +151,7 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
* ### [kmath-prob](kmath-prob)
|
* ### [kmath-stat](kmath-stat)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
@ -178,8 +182,8 @@ repositories{
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies{
|
dependencies{
|
||||||
api("kscience.kmath:kmath-core:${kmathVersion}")
|
api("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||||
//api("scientifik:kmath-core:${kmathVersion}") for 0.1.3 and earlier
|
//api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -197,4 +201,4 @@ with the same artifact names.
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
|
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
@ -2,7 +2,7 @@ plugins {
|
|||||||
id("ru.mipt.npm.project")
|
id("ru.mipt.npm.project")
|
||||||
}
|
}
|
||||||
|
|
||||||
val kmathVersion: String by extra("0.2.0-dev-2")
|
val kmathVersion: String by extra("0.2.0-dev-3")
|
||||||
val bintrayRepo: String by extra("kscience")
|
val bintrayRepo: String by extra("kscience")
|
||||||
val githubProject: String by extra("kmath")
|
val githubProject: String by extra("kmath")
|
||||||
|
|
||||||
@ -25,3 +25,7 @@ subprojects {
|
|||||||
readme {
|
readme {
|
||||||
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiValidation{
|
||||||
|
validationDisabled = true
|
||||||
|
}
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -107,4 +107,4 @@ with the same artifact names.
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
|
The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
|
@ -23,7 +23,7 @@ dependencies {
|
|||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
implementation(project(":kmath-prob"))
|
implementation(project(":kmath-stat"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation(project(":kmath-ejml"))
|
implementation(project(":kmath-ejml"))
|
||||||
|
@ -4,7 +4,7 @@ import kotlinx.coroutines.Dispatchers
|
|||||||
import kotlinx.coroutines.async
|
import kotlinx.coroutines.async
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import kscience.kmath.chains.BlockingRealChain
|
import kscience.kmath.chains.BlockingRealChain
|
||||||
import kscience.kmath.prob.*
|
import kscience.kmath.stat.*
|
||||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
|
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
|
||||||
import org.apache.commons.rng.simple.RandomSource
|
import org.apache.commons.rng.simple.RandomSource
|
||||||
import java.time.Duration
|
import java.time.Duration
|
||||||
|
@ -3,9 +3,9 @@ package kscience.kmath.commons.prob
|
|||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import kscience.kmath.chains.collectWithState
|
import kscience.kmath.chains.collectWithState
|
||||||
import kscience.kmath.prob.Distribution
|
import kscience.kmath.stat.Distribution
|
||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.stat.RandomGenerator
|
||||||
import kscience.kmath.prob.normal
|
import kscience.kmath.stat.normal
|
||||||
|
|
||||||
private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
|
||||||
|
|
||||||
|
@ -6,8 +6,8 @@ import kscience.kmath.structures.complex
|
|||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
// 2d element
|
// 2d element
|
||||||
val element = NDElement.complex(2, 2) { index: IntArray ->
|
val element = NDElement.complex(2, 2) { (i,j) ->
|
||||||
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
|
Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble())
|
||||||
}
|
}
|
||||||
println(element)
|
println(element)
|
||||||
|
|
||||||
|
@ -14,8 +14,8 @@ import kotlin.contracts.contract
|
|||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
||||||
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> {
|
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||||
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
|
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
@ -27,7 +27,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
error("Numeric nodes are not supported by $this")
|
error("Numeric nodes are not supported by $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -37,7 +37,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
mstAlgebra: E,
|
mstAlgebra: E,
|
||||||
block: E.() -> MST
|
block: E.() -> MST,
|
||||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -116,7 +116,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||||
block: MstExtendedField.() -> MST
|
block: MstExtendedField.() -> MST,
|
||||||
): MstExpression<T> {
|
): MstExpression<T> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInExtendedField(block)
|
return algebra.mstInExtendedField(block)
|
||||||
|
@ -25,7 +25,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
private val classOfT: Class<*>,
|
private val classOfT: Class<*>,
|
||||||
private val algebra: Algebra<T>,
|
private val algebra: Algebra<T>,
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit,
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||||
@ -379,22 +379,14 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
|
||||||
* provided.
|
* provided.
|
||||||
*/
|
*/
|
||||||
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
|
||||||
load(invokeArgumentsVar, MAP_TYPE)
|
load(invokeArgumentsVar, MAP_TYPE)
|
||||||
aconst(name)
|
aconst(name)
|
||||||
|
|
||||||
if (defaultValue != null)
|
|
||||||
loadTConstant(defaultValue)
|
|
||||||
|
|
||||||
invokestatic(
|
invokestatic(
|
||||||
MAP_INTRINSICS_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
"getOrFail",
|
"getOrFail",
|
||||||
|
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
|
||||||
Type.getMethodDescriptor(
|
|
||||||
OBJECT_TYPE,
|
|
||||||
MAP_TYPE,
|
|
||||||
OBJECT_TYPE,
|
|
||||||
*OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
|
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -429,7 +421,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
method: String,
|
method: String,
|
||||||
descriptor: String,
|
descriptor: String,
|
||||||
expectedArity: Int,
|
expectedArity: Int,
|
||||||
opcode: Int = INVOKEINTERFACE
|
opcode: Int = INVOKEINTERFACE,
|
||||||
) {
|
) {
|
||||||
run loop@{
|
run loop@{
|
||||||
repeat(expectedArity) {
|
repeat(expectedArity) {
|
||||||
|
@ -2,11 +2,12 @@
|
|||||||
|
|
||||||
package kscience.kmath.asm.internal
|
package kscience.kmath.asm.internal
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.StringSymbol
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
|
* Gets value with given [key] or throws [NoSuchElementException] whenever it is not present.
|
||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@JvmOverloads
|
internal fun <V> Map<Symbol, V>.getOrFail(key: String): V = getValue(StringSymbol(key))
|
||||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
|
||||||
this[key] ?: default ?: error("Parameter not found: $key")
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
package kscience.kmath.asm
|
package kscience.kmath.asm
|
||||||
|
|
||||||
import kscience.kmath.asm.compile
|
|
||||||
import kscience.kmath.ast.mstInField
|
import kscience.kmath.ast.mstInField
|
||||||
import kscience.kmath.ast.mstInRing
|
import kscience.kmath.ast.mstInRing
|
||||||
import kscience.kmath.ast.mstInSpace
|
import kscience.kmath.ast.mstInSpace
|
||||||
@ -11,6 +10,7 @@ import kotlin.test.Test
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class TestAsmAlgebras {
|
internal class TestAsmAlgebras {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun space() {
|
fun space() {
|
||||||
val res1 = ByteRing.mstInSpace {
|
val res1 = ByteRing.mstInSpace {
|
||||||
|
@ -6,7 +6,7 @@ description = "Commons math binding for kmath"
|
|||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
api(project(":kmath-prob"))
|
api(project(":kmath-stat"))
|
||||||
// api(project(":kmath-functions"))
|
api(project(":kmath-functions"))
|
||||||
api("org.apache.commons:commons-math3:3.6.1")
|
api("org.apache.commons:commons-math3:3.6.1")
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,115 @@
|
|||||||
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.*
|
||||||
|
import kscience.kmath.operations.ExtendedField
|
||||||
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field over commons-math [DerivativeStructure].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @property bindings The map of bindings values. All bindings are considered free parameters
|
||||||
|
*/
|
||||||
|
public class DerivativeStructureField(
|
||||||
|
public val order: Int,
|
||||||
|
private val bindings: Map<Symbol, Double>
|
||||||
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||||
|
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
|
||||||
|
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||||
|
*/
|
||||||
|
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||||
|
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||||
|
override val identity: String = symbol.identity
|
||||||
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Identity-based symbol bindings map
|
||||||
|
*/
|
||||||
|
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
|
||||||
|
key.identity to DerivativeStructureSymbol(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
|
||||||
|
|
||||||
|
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
|
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
|
||||||
|
|
||||||
|
//public fun Number.const(): DerivativeStructure = const(toDouble())
|
||||||
|
|
||||||
|
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
|
||||||
|
return derivative(mapOf(parameter to order))
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
|
||||||
|
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
|
||||||
|
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||||
|
|
||||||
|
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
||||||
|
is Double -> a.multiply(k)
|
||||||
|
is Int -> a.multiply(k)
|
||||||
|
else -> a.multiply(k.toDouble())
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
|
||||||
|
public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
||||||
|
public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
||||||
|
public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
||||||
|
public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
|
||||||
|
public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
|
||||||
|
public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
|
||||||
|
public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
|
||||||
|
public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
|
||||||
|
public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
|
||||||
|
public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
|
||||||
|
public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
|
||||||
|
public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
|
||||||
|
public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
|
||||||
|
|
||||||
|
public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
|
||||||
|
is Double -> arg.pow(pow)
|
||||||
|
is Int -> arg.pow(pow)
|
||||||
|
else -> arg.pow(pow.toDouble())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
|
||||||
|
public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
|
||||||
|
public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
||||||
|
|
||||||
|
public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
||||||
|
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||||
|
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
|
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||||
|
|
||||||
|
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
||||||
|
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
||||||
|
return DerivativeStructureExpression(function)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
|
*/
|
||||||
|
public class DerivativeStructureExpression(
|
||||||
|
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||||
|
) : DifferentiableExpression<Double> {
|
||||||
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
|
DerivativeStructureField(0, arguments).function().value
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the derivative expression with given orders
|
||||||
|
*/
|
||||||
|
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
|
||||||
|
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
|
||||||
|
}
|
||||||
|
}
|
@ -1,131 +0,0 @@
|
|||||||
package kscience.kmath.commons.expressions
|
|
||||||
|
|
||||||
import kscience.kmath.expressions.Expression
|
|
||||||
import kscience.kmath.expressions.ExpressionAlgebra
|
|
||||||
import kscience.kmath.operations.ExtendedField
|
|
||||||
import kscience.kmath.operations.Field
|
|
||||||
import kscience.kmath.operations.invoke
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
|
||||||
import kotlin.properties.ReadOnlyProperty
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A field over commons-math [DerivativeStructure].
|
|
||||||
*
|
|
||||||
* @property order The derivation order.
|
|
||||||
* @property parameters The map of free parameters.
|
|
||||||
*/
|
|
||||||
public class DerivativeStructureField(
|
|
||||||
public val order: Int,
|
|
||||||
public val parameters: Map<String, Double>
|
|
||||||
) : ExtendedField<DerivativeStructure> {
|
|
||||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
|
|
||||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
|
|
||||||
|
|
||||||
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
|
|
||||||
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
|
|
||||||
}
|
|
||||||
|
|
||||||
public val variable: ReadOnlyProperty<Any?, DerivativeStructure> = ReadOnlyProperty { _, property ->
|
|
||||||
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
|
|
||||||
variables[name] ?: default ?: error("A variable with name $name does not exist")
|
|
||||||
|
|
||||||
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
|
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
|
|
||||||
return deriv(mapOf(parName to order))
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
|
|
||||||
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
|
|
||||||
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
|
||||||
|
|
||||||
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
|
|
||||||
is Double -> a.multiply(k)
|
|
||||||
is Int -> a.multiply(k)
|
|
||||||
else -> a.multiply(k.toDouble())
|
|
||||||
}
|
|
||||||
|
|
||||||
public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
|
|
||||||
public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
|
||||||
public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
|
||||||
public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
|
||||||
public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
|
|
||||||
public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
|
|
||||||
public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
|
|
||||||
public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
|
|
||||||
public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
|
|
||||||
public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
|
|
||||||
public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
|
|
||||||
public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
|
|
||||||
public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
|
|
||||||
public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
|
|
||||||
|
|
||||||
public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
|
|
||||||
is Double -> arg.pow(pow)
|
|
||||||
is Int -> arg.pow(pow)
|
|
||||||
else -> arg.pow(pow.toDouble())
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
|
|
||||||
public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
|
|
||||||
public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
|
||||||
|
|
||||||
public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
|
||||||
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
|
||||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
|
||||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A constructs that creates a derivative structure with required order on-demand
|
|
||||||
*/
|
|
||||||
public class DiffExpression(public val function: DerivativeStructureField.() -> DerivativeStructure) :
|
|
||||||
Expression<Double> {
|
|
||||||
public override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
|
|
||||||
0,
|
|
||||||
arguments
|
|
||||||
).function().value
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the derivative expression with given orders
|
|
||||||
* TODO make result [DiffExpression]
|
|
||||||
*/
|
|
||||||
public fun derivative(orders: Map<String, Int>): Expression<Double> = Expression { arguments ->
|
|
||||||
(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
|
|
||||||
}
|
|
||||||
|
|
||||||
//TODO add gradient and maybe other vector operators
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
|
|
||||||
public fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
|
|
||||||
*/
|
|
||||||
public object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
|
|
||||||
public override val zero: DiffExpression = DiffExpression { 0.0.const() }
|
|
||||||
public override val one: DiffExpression = DiffExpression { 1.0.const() }
|
|
||||||
|
|
||||||
public override fun variable(name: String, default: Double?): DiffExpression =
|
|
||||||
DiffExpression { variable(name, default?.const()) }
|
|
||||||
|
|
||||||
public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
|
|
||||||
|
|
||||||
public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) + b.function(this) }
|
|
||||||
|
|
||||||
public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
|
|
||||||
|
|
||||||
public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) * b.function(this) }
|
|
||||||
|
|
||||||
public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
|
|
||||||
DiffExpression { a.function(this) / b.function(this) }
|
|
||||||
}
|
|
@ -0,0 +1,111 @@
|
|||||||
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.*
|
||||||
|
import kscience.kmath.stat.OptimizationFeature
|
||||||
|
import kscience.kmath.stat.OptimizationProblem
|
||||||
|
import kscience.kmath.stat.OptimizationProblemFactory
|
||||||
|
import kscience.kmath.stat.OptimizationResult
|
||||||
|
import org.apache.commons.math3.optim.*
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||||
|
public operator fun PointValuePair.component2(): Double = value
|
||||||
|
|
||||||
|
public class CMOptimizationProblem(
|
||||||
|
override val symbols: List<Symbol>,
|
||||||
|
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
||||||
|
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||||
|
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||||
|
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||||
|
DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
|
||||||
|
|
||||||
|
public fun addOptimizationData(data: OptimizationData) {
|
||||||
|
optimizationData[data::class] = data
|
||||||
|
}
|
||||||
|
|
||||||
|
init {
|
||||||
|
addOptimizationData(MaxEval.unlimited())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
|
||||||
|
|
||||||
|
public override fun initialGuess(map: Map<Symbol, Double>): Unit {
|
||||||
|
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun expression(expression: Expression<Double>): Unit {
|
||||||
|
val objectiveFunction = ObjectiveFunction {
|
||||||
|
val args = it.toMap()
|
||||||
|
expression(args)
|
||||||
|
}
|
||||||
|
addOptimizationData(objectiveFunction)
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
|
||||||
|
expression(expression)
|
||||||
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
|
val args = it.toMap()
|
||||||
|
DoubleArray(symbols.size) { index ->
|
||||||
|
expression.derivative(symbols[index])(args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
addOptimizationData(gradientFunction)
|
||||||
|
if (optimizatorBuilder == null) {
|
||||||
|
optimizatorBuilder = {
|
||||||
|
NonLinearConjugateGradientOptimizer(
|
||||||
|
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
||||||
|
convergenceChecker
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun simplex(simplex: AbstractSimplex) {
|
||||||
|
addOptimizationData(simplex)
|
||||||
|
//Set optimization builder to simplex if it is not present
|
||||||
|
if (optimizatorBuilder == null) {
|
||||||
|
optimizatorBuilder = { SimplexOptimizer(convergenceChecker) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun simplexSteps(steps: Map<Symbol, Double>) {
|
||||||
|
simplex(NelderMeadSimplex(steps.toDoubleArray()))
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun goal(goalType: GoalType) {
|
||||||
|
addOptimizationData(goalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun optimizer(block: () -> MultivariateOptimizer) {
|
||||||
|
optimizatorBuilder = block
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun update(result: OptimizationResult<Double>) {
|
||||||
|
initialGuess(result.point)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun optimize(): OptimizationResult<Double> {
|
||||||
|
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
|
||||||
|
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
|
||||||
|
return OptimizationResult(point.toMap(), value, setOf(this))
|
||||||
|
}
|
||||||
|
|
||||||
|
public companion object : OptimizationProblemFactory<Double, CMOptimizationProblem> {
|
||||||
|
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
|
||||||
|
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
|
||||||
|
public const val DEFAULT_MAX_ITER: Int = 1000
|
||||||
|
|
||||||
|
override fun build(symbols: List<Symbol>): CMOptimizationProblem = CMOptimizationProblem(symbols)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun CMOptimizationProblem.initialGuess(vararg pairs: Pair<Symbol, Double>): Unit = initialGuess(pairs.toMap())
|
||||||
|
public fun CMOptimizationProblem.simplexSteps(vararg pairs: Pair<Symbol, Double>): Unit = simplexSteps(pairs.toMap())
|
@ -0,0 +1,70 @@
|
|||||||
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
|
import kscience.kmath.commons.expressions.DerivativeStructureField
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.Expression
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
|
import kscience.kmath.stat.Fitting
|
||||||
|
import kscience.kmath.stat.OptimizationResult
|
||||||
|
import kscience.kmath.stat.optimizeWith
|
||||||
|
import kscience.kmath.structures.Buffer
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
|
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
|
*/
|
||||||
|
public fun Fitting.chiSquared(
|
||||||
|
x: Buffer<Double>,
|
||||||
|
y: Buffer<Double>,
|
||||||
|
yErr: Buffer<Double>,
|
||||||
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
|
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
|
*/
|
||||||
|
public fun Fitting.chiSquared(
|
||||||
|
x: Iterable<Double>,
|
||||||
|
y: Iterable<Double>,
|
||||||
|
yErr: Iterable<Double>,
|
||||||
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
|
): DifferentiableExpression<Double> = chiSquared(
|
||||||
|
DerivativeStructureField,
|
||||||
|
x.toList().asBuffer(),
|
||||||
|
y.toList().asBuffer(),
|
||||||
|
yErr.toList().asBuffer(),
|
||||||
|
model
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize expression without derivatives
|
||||||
|
*/
|
||||||
|
public fun Expression<Double>.optimize(
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize differentiable expression
|
||||||
|
*/
|
||||||
|
public fun DifferentiableExpression<Double>.optimize(
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
|
|
||||||
|
public fun DifferentiableExpression<Double>.minimize(
|
||||||
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
|
configuration: CMOptimizationProblem.() -> Unit = {},
|
||||||
|
): OptimizationResult<Double> {
|
||||||
|
require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
|
val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration)
|
||||||
|
problem.diffExpression(this)
|
||||||
|
problem.initialGuess(startPoint.toMap())
|
||||||
|
problem.goal(GoalType.MINIMIZE)
|
||||||
|
return problem.optimize()
|
||||||
|
}
|
@ -1,9 +1,10 @@
|
|||||||
package kscience.kmath.commons.random
|
package kscience.kmath.commons.random
|
||||||
|
|
||||||
import kscience.kmath.prob.RandomGenerator
|
import kscience.kmath.stat.RandomGenerator
|
||||||
|
|
||||||
public class CMRandomGeneratorWrapper(public val factory: (IntArray) -> RandomGenerator) :
|
public class CMRandomGeneratorWrapper(
|
||||||
org.apache.commons.math3.random.RandomGenerator {
|
public val factory: (IntArray) -> RandomGenerator,
|
||||||
|
) : org.apache.commons.math3.random.RandomGenerator {
|
||||||
private var generator: RandomGenerator = factory(intArrayOf())
|
private var generator: RandomGenerator = factory(intArrayOf())
|
||||||
|
|
||||||
public override fun nextBoolean(): Boolean = generator.nextBoolean()
|
public override fun nextBoolean(): Boolean = generator.nextBoolean()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package kscience.kmath.commons.expressions
|
package kscience.kmath.commons.expressions
|
||||||
|
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.*
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -8,33 +8,37 @@ import kotlin.test.assertEquals
|
|||||||
|
|
||||||
internal inline fun <R> diff(
|
internal inline fun <R> diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<String, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField.() -> R
|
block: DerivativeStructureField.() -> R,
|
||||||
): R {
|
): R {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
|
private val x by symbol
|
||||||
|
private val y by symbol
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun derivativeStructureFieldTest() {
|
fun derivativeStructureFieldTest() {
|
||||||
val res = diff(3, "x" to 1.0, "y" to 1.0) {
|
val res: Double = diff(3, x to 1.0, y to 1.0) {
|
||||||
val x by variable
|
val x = bind(x)//by binding()
|
||||||
val y = variable("y")
|
val y = symbol("y")
|
||||||
val z = x * (-sin(x * y) + y)
|
val z = x * (-sin(x * y) + y)
|
||||||
z.deriv("x")
|
z.derivative(x)
|
||||||
}
|
}
|
||||||
|
println(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun autoDifTest() {
|
fun autoDifTest() {
|
||||||
val f = DiffExpression {
|
val f = DerivativeStructureExpression {
|
||||||
val x by variable
|
val x by binding()
|
||||||
val y by variable
|
val y by binding()
|
||||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
|
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||||
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
|
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -0,0 +1,66 @@
|
|||||||
|
package kscience.kmath.commons.optimization
|
||||||
|
|
||||||
|
import kscience.kmath.commons.expressions.DerivativeStructureExpression
|
||||||
|
import kscience.kmath.expressions.symbol
|
||||||
|
import kscience.kmath.stat.Distribution
|
||||||
|
import kscience.kmath.stat.Fitting
|
||||||
|
import kscience.kmath.stat.RandomGenerator
|
||||||
|
import kscience.kmath.stat.normal
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import org.junit.jupiter.api.Test
|
||||||
|
import kotlin.math.pow
|
||||||
|
|
||||||
|
internal class OptimizeTest {
|
||||||
|
val x by symbol
|
||||||
|
val y by symbol
|
||||||
|
|
||||||
|
val normal = DerivativeStructureExpression {
|
||||||
|
exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testGradientOptimization() {
|
||||||
|
val result = normal.optimize(x, y) {
|
||||||
|
initialGuess(x to 1.0, y to 1.0)
|
||||||
|
//no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function
|
||||||
|
}
|
||||||
|
println(result.point)
|
||||||
|
println(result.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSimplexOptimization() {
|
||||||
|
val result = normal.optimize(x, y) {
|
||||||
|
initialGuess(x to 1.0, y to 1.0)
|
||||||
|
simplexSteps(x to 2.0, y to 0.5)
|
||||||
|
//this sets simplex optimizer
|
||||||
|
}
|
||||||
|
println(result.point)
|
||||||
|
println(result.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCmFit() {
|
||||||
|
val a by symbol
|
||||||
|
val b by symbol
|
||||||
|
val c by symbol
|
||||||
|
|
||||||
|
val sigma = 1.0
|
||||||
|
val generator = Distribution.normal(0.0, sigma)
|
||||||
|
val chain = generator.sample(RandomGenerator.default(112667))
|
||||||
|
val x = (1..100).map { it.toDouble() }
|
||||||
|
val y = x.map { it ->
|
||||||
|
it.pow(2) + it + 1 + chain.nextDouble()
|
||||||
|
}
|
||||||
|
val yErr = x.map { sigma }
|
||||||
|
val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
||||||
|
val cWithDefault = bindOrNull(c) ?: one
|
||||||
|
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
||||||
|
println(result)
|
||||||
|
println("Chi2/dof = ${result.value / (x.size - 3)}")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -7,12 +7,12 @@ The core features of KMath:
|
|||||||
- [buffers](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
|
- [buffers](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
|
||||||
- [expressions](src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
|
- [expressions](src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
|
||||||
- [domains](src/commonMain/kotlin/kscience/kmath/domains) : Domains
|
- [domains](src/commonMain/kotlin/kscience/kmath/domains) : Domains
|
||||||
- [autodif](src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation
|
- [autodif](src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
|
||||||
|
|
||||||
|
|
||||||
> #### Artifact:
|
> #### Artifact:
|
||||||
>
|
>
|
||||||
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`.
|
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
|
||||||
>
|
>
|
||||||
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
|
||||||
>
|
>
|
||||||
@ -22,25 +22,28 @@ The core features of KMath:
|
|||||||
>
|
>
|
||||||
> ```gradle
|
> ```gradle
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||||
|
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1'
|
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
> **Gradle Kotlin DSL:**
|
> **Gradle Kotlin DSL:**
|
||||||
>
|
>
|
||||||
> ```kotlin
|
> ```kotlin
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
> maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
> maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation("kscience.kmath:kmath-core:0.2.0-dev-1")
|
> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
|
@ -41,6 +41,6 @@ readme {
|
|||||||
feature(
|
feature(
|
||||||
id = "autodif",
|
id = "autodif",
|
||||||
description = "Automatic differentiation",
|
description = "Automatic differentiation",
|
||||||
ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt"
|
ref = "src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt"
|
||||||
)
|
)
|
||||||
}
|
}
|
@ -0,0 +1,39 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An expression that provides derivatives
|
||||||
|
*/
|
||||||
|
public interface DifferentiableExpression<T> : Expression<T>{
|
||||||
|
public fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>?
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(orders: Map<Symbol, Int>): Expression<T> =
|
||||||
|
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||||
|
derivative(mapOf(*orders))
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||||
|
derivative(StringSymbol(name) to 1)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
|
*/
|
||||||
|
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||||
|
|
||||||
|
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||||
|
|
||||||
|
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? {
|
||||||
|
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
|
||||||
|
return derivativeOrNull(dSymbol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||||
|
*/
|
||||||
|
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
||||||
|
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
||||||
|
}
|
@ -1,6 +1,25 @@
|
|||||||
package kscience.kmath.expressions
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
import kscience.kmath.operations.Algebra
|
import kscience.kmath.operations.Algebra
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
import kotlin.properties.ReadOnlyProperty
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A marker interface for a symbol. A symbol mus have an identity
|
||||||
|
*/
|
||||||
|
public interface Symbol {
|
||||||
|
/**
|
||||||
|
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
|
||||||
|
*/
|
||||||
|
public val identity: String
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A [Symbol] with a [String] identity
|
||||||
|
*/
|
||||||
|
public inline class StringSymbol(override val identity: String) : Symbol {
|
||||||
|
override fun toString(): String = identity
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments
|
||||||
@ -12,30 +31,69 @@ public fun interface Expression<T> {
|
|||||||
* @param arguments the map of arguments.
|
* @param arguments the map of arguments.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
public operator fun invoke(arguments: Map<String, T>): T
|
public operator fun invoke(arguments: Map<Symbol, T>): T
|
||||||
|
|
||||||
public companion object
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invoke an expression without parameters
|
||||||
|
*/
|
||||||
|
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||||
|
//This method exists to avoid resolution ambiguity of vararg methods
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression from arguments.
|
* Calls this expression from arguments.
|
||||||
*
|
*
|
||||||
* @param pairs the pair of arguments' names to values.
|
* @param pairs the pair of arguments' names to values.
|
||||||
* @return the value.
|
* @return the value.
|
||||||
*/
|
*/
|
||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs))
|
@JvmName("callBySymbol")
|
||||||
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
||||||
|
|
||||||
|
@JvmName("callByString")
|
||||||
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||||
|
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
|
*
|
||||||
|
* @param T type of the constants for the expression
|
||||||
|
* @param E type of the actual expression state
|
||||||
*/
|
*/
|
||||||
public interface ExpressionAlgebra<T, E> : Algebra<E> {
|
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Introduce a variable into expression context
|
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
|
||||||
*/
|
*/
|
||||||
public fun variable(name: String, default: T? = null): E
|
public fun bindOrNull(symbol: Symbol): E?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a string to a context using [StringSymbol]
|
||||||
|
*/
|
||||||
|
override fun symbol(value: String): E = bind(StringSymbol(value))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A constant expression which does not depend on arguments
|
* A constant expression which does not depend on arguments
|
||||||
*/
|
*/
|
||||||
public fun const(value: T): E
|
public fun const(value: T): E
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
||||||
|
*/
|
||||||
|
public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
|
||||||
|
bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A delegate to create a symbol with a string identity in this scope
|
||||||
|
*/
|
||||||
|
public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty { thisRef, property ->
|
||||||
|
StringSymbol(property.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Bind a symbol by name inside the [ExpressionAlgebra]
|
||||||
|
*/
|
||||||
|
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> = ReadOnlyProperty { _, property ->
|
||||||
|
bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
|
||||||
|
}
|
@ -2,67 +2,43 @@ package kscience.kmath.expressions
|
|||||||
|
|
||||||
import kscience.kmath.operations.*
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.unaryOperation(name, expr.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalBinaryOperation<T>(
|
|
||||||
val context: Algebra<T>,
|
|
||||||
val name: String,
|
|
||||||
val first: Expression<T>,
|
|
||||||
val second: Expression<T>
|
|
||||||
) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T =
|
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = value
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class FunctionalConstProductExpression<T>(
|
|
||||||
val context: Space<T>,
|
|
||||||
private val expr: Expression<T>,
|
|
||||||
val const: Number
|
|
||||||
) : Expression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context class for [Expression] construction.
|
* A context class for [Expression] construction.
|
||||||
*
|
*
|
||||||
* @param algebra The algebra to provide for Expressions built.
|
* @param algebra The algebra to provide for Expressions built.
|
||||||
*/
|
*/
|
||||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
|
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||||
ExpressionAlgebra<T, Expression<T>> {
|
public val algebra: A,
|
||||||
|
) : ExpressionAlgebra<T, Expression<T>> {
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression which does not depend on arguments.
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
public override fun const(value: T): Expression<T> = Expression { value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression to access a variable.
|
* Builds an Expression to access a variable.
|
||||||
*/
|
*/
|
||||||
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
public override fun bindOrNull(symbol: Symbol): Expression<T>? = Expression { arguments ->
|
||||||
|
arguments[symbol] ?: error("Argument not found: $symbol")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
*/
|
*/
|
||||||
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
public override fun binaryOperation(
|
||||||
FunctionalBinaryOperation(algebra, operation, left, right)
|
operation: String,
|
||||||
|
left: Expression<T>,
|
||||||
|
right: Expression<T>,
|
||||||
|
): Expression<T> = Expression { arguments ->
|
||||||
|
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
*/
|
*/
|
||||||
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
|
||||||
FunctionalUnaryOperation(algebra, operation, arg)
|
algebra.unaryOperation(operation, arg.invoke(arguments))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -81,8 +57,9 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of multiplication of expression by number.
|
* Builds an Expression of multiplication of expression by number.
|
||||||
*/
|
*/
|
||||||
public override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
public override fun multiply(a: Expression<T>, k: Number): Expression<T> = Expression { arguments ->
|
||||||
FunctionalConstProductExpression(algebra, a, k)
|
algebra.multiply(a.invoke(arguments), k)
|
||||||
|
}
|
||||||
|
|
||||||
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||||
@ -118,8 +95,8 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
|||||||
}
|
}
|
||||||
|
|
||||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||||
FunctionalExpressionRing<T, A>(algebra),
|
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
|
||||||
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
where A : Field<T>, A : NumericAlgebra<T> {
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of division an expression by another one.
|
* Builds an Expression of division an expression by another one.
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,395 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implementation of backward-mode automatic differentiation.
|
||||||
|
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
public open class AutoDiffValue<out T>(public val value: T)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents result of [simpleAutoDiff] call.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @param value the value of result.
|
||||||
|
* @property simpleAutoDiff The mapping of differentiated variables to their derivatives.
|
||||||
|
* @property context The field over [T].
|
||||||
|
*/
|
||||||
|
public class DerivationResult<T : Any>(
|
||||||
|
public val value: T,
|
||||||
|
private val derivativeValues: Map<String, T>,
|
||||||
|
public val context: Field<T>,
|
||||||
|
) {
|
||||||
|
/**
|
||||||
|
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
||||||
|
*/
|
||||||
|
public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the divergence.
|
||||||
|
*/
|
||||||
|
public fun div(): T = context { sum(derivativeValues.values) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the gradient for variables in given order.
|
||||||
|
*/
|
||||||
|
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
|
||||||
|
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
||||||
|
return variables.map(::derivative).asBuffer()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
|
||||||
|
*
|
||||||
|
* The partial derivatives are placed in argument `d` variable
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* ```
|
||||||
|
* val x by symbol // define variable(s) and their values
|
||||||
|
* val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
||||||
|
* assertEquals(17.0, y.x) // the value of result (y)
|
||||||
|
* assertEquals(9.0, x.d) // dy/dx
|
||||||
|
* ```
|
||||||
|
*
|
||||||
|
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||||
|
* @return the result of differentiation.
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
|
): DerivationResult<T> {
|
||||||
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
|
return SimpleAutoDiffField(this, bindings).derivate(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
|
vararg bindings: Pair<Symbol, T>,
|
||||||
|
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
|
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents field in context of which functions can be derived.
|
||||||
|
*/
|
||||||
|
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||||
|
public val context: F,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||||
|
|
||||||
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
|
private var sp: Int = 0
|
||||||
|
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
|
* with respect to this variable.
|
||||||
|
*
|
||||||
|
* @param T the non-nullable type of value.
|
||||||
|
* @property value The value of this variable.
|
||||||
|
*/
|
||||||
|
private class AutoDiffVariableWithDerivative<T : Any>(
|
||||||
|
override val identity: String,
|
||||||
|
value: T,
|
||||||
|
var d: T,
|
||||||
|
) : AutoDiffValue<T>(value), Symbol {
|
||||||
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||||
|
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||||
|
|
||||||
|
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||||
|
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||||
|
|
||||||
|
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
|
||||||
|
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
private fun runBackwardPass() {
|
||||||
|
while (sp > 0) {
|
||||||
|
val value = stack[--sp]
|
||||||
|
val block = stack[--sp] as F.(Any?) -> Unit
|
||||||
|
context.block(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||||
|
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||||
|
|
||||||
|
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A variable accessing inner state of derivatives.
|
||||||
|
* Use this value in inner builders to avoid creating additional derivative bindings.
|
||||||
|
*/
|
||||||
|
public var AutoDiffValue<T>.d: T
|
||||||
|
get() = getDerivative(this)
|
||||||
|
set(value) = setDerivative(this, value)
|
||||||
|
|
||||||
|
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs update of derivative after the rest of the formula in the back-pass.
|
||||||
|
*
|
||||||
|
* For example, implementation of `sin` function is:
|
||||||
|
*
|
||||||
|
* ```
|
||||||
|
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
||||||
|
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
||||||
|
// save block to stack for backward pass
|
||||||
|
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
||||||
|
stack[sp++] = block
|
||||||
|
stack[sp++] = value
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||||
|
val result = function()
|
||||||
|
result.d = context.one // computing derivative w.r.t result
|
||||||
|
runBackwardPass()
|
||||||
|
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overloads for Double constants
|
||||||
|
|
||||||
|
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
||||||
|
|
||||||
|
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||||
|
|
||||||
|
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||||
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
|
|
||||||
|
|
||||||
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
|
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { a.value + b.value }) { z ->
|
||||||
|
a.d += z.d
|
||||||
|
b.d += z.d
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { a.value * b.value }) { z ->
|
||||||
|
a.d += z.d * b.value
|
||||||
|
b.d += z.d * a.value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { a.value / b.value }) { z ->
|
||||||
|
a.d += z.d / b.value
|
||||||
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
||||||
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
|
a.d += z.d * k.toDouble()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A constructs that creates a derivative structure with required order on-demand
|
||||||
|
*/
|
||||||
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
|
public val field: F,
|
||||||
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
|
) : FirstDerivativeExpression<T>() {
|
||||||
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||||
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
|
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
||||||
|
derivationResult.derivative(symbol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||||
|
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||||
|
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
||||||
|
return SimpleAutoDiffExpression(field, function)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
|
// x ^ 2
|
||||||
|
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
|
x: AutoDiffValue<T>,
|
||||||
|
y: Double,
|
||||||
|
): AutoDiffValue<T> =
|
||||||
|
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
|
x: AutoDiffValue<T>,
|
||||||
|
y: Int,
|
||||||
|
): AutoDiffValue<T> = pow(x, y.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
|
||||||
|
x: AutoDiffValue<T>,
|
||||||
|
y: AutoDiffValue<T>,
|
||||||
|
): AutoDiffValue<T> =
|
||||||
|
exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { tan(x.value) }) { z ->
|
||||||
|
val c = cos(x.value)
|
||||||
|
x.d += z.d / (c * c)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { tanh(x.value) }) { z ->
|
||||||
|
val c = cosh(x.value)
|
||||||
|
x.d += z.d / (c * c)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
||||||
|
|
||||||
|
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
||||||
|
|
||||||
|
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
||||||
|
context: F,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
|
||||||
|
// x ^ 2
|
||||||
|
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sqr(x)
|
||||||
|
|
||||||
|
// x ^ 1/2
|
||||||
|
public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sqrt(arg)
|
||||||
|
|
||||||
|
// x ^ y (const)
|
||||||
|
public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())
|
||||||
|
|
||||||
|
// exp(x)
|
||||||
|
public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).exp(arg)
|
||||||
|
|
||||||
|
// ln(x)
|
||||||
|
public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).ln(arg)
|
||||||
|
|
||||||
|
// x ^ y (any)
|
||||||
|
public fun pow(
|
||||||
|
x: AutoDiffValue<T>,
|
||||||
|
y: AutoDiffValue<T>,
|
||||||
|
): AutoDiffValue<T> = exp(y * ln(x))
|
||||||
|
|
||||||
|
// sin(x)
|
||||||
|
public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sin(arg)
|
||||||
|
|
||||||
|
// cos(x)
|
||||||
|
public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).cos(arg)
|
||||||
|
|
||||||
|
public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).tan(arg)
|
||||||
|
|
||||||
|
public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).asin(arg)
|
||||||
|
|
||||||
|
public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).acos(arg)
|
||||||
|
|
||||||
|
public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).atan(arg)
|
||||||
|
|
||||||
|
public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).sinh(arg)
|
||||||
|
|
||||||
|
public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).cosh(arg)
|
||||||
|
|
||||||
|
public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).tanh(arg)
|
||||||
|
|
||||||
|
public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).asinh(arg)
|
||||||
|
|
||||||
|
public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).acosh(arg)
|
||||||
|
|
||||||
|
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
|
(this as SimpleAutoDiffField<T, F>).atanh(arg)
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.structures.BufferFactory
|
||||||
|
import kscience.kmath.structures.Structure2D
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An environment to easy transform indexed variables to symbols and back.
|
||||||
|
* TODO requires multi-receivers to be beutiful
|
||||||
|
*/
|
||||||
|
public interface SymbolIndexer {
|
||||||
|
public val symbols: List<Symbol>
|
||||||
|
public fun indexOf(symbol: Symbol): Int = symbols.indexOf(symbol)
|
||||||
|
|
||||||
|
public operator fun <T> List<T>.get(symbol: Symbol): T {
|
||||||
|
require(size == symbols.size) { "The input list size for indexer should be ${symbols.size} but $size found" }
|
||||||
|
return get(this@SymbolIndexer.indexOf(symbol))
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun <T> Array<T>.get(symbol: Symbol): T {
|
||||||
|
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||||
|
return get(this@SymbolIndexer.indexOf(symbol))
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun DoubleArray.get(symbol: Symbol): Double {
|
||||||
|
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||||
|
return get(this@SymbolIndexer.indexOf(symbol))
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun <T> Point<T>.get(symbol: Symbol): T {
|
||||||
|
require(size == symbols.size) { "The input buffer size for indexer should be ${symbols.size} but $size found" }
|
||||||
|
return get(this@SymbolIndexer.indexOf(symbol))
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DoubleArray.toMap(): Map<Symbol, Double> {
|
||||||
|
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
|
||||||
|
return symbols.indices.associate { symbols[it] to get(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun <T> Structure2D<T>.get(rowSymbol: Symbol, columnSymbol: Symbol): T =
|
||||||
|
get(indexOf(rowSymbol), indexOf(columnSymbol))
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T> Map<Symbol, T>.toList(): List<T> = symbols.map { getValue(it) }
|
||||||
|
|
||||||
|
public fun <T> Map<Symbol, T>.toPoint(bufferFactory: BufferFactory<T>): Point<T> =
|
||||||
|
bufferFactory(symbols.size) { getValue(symbols[it]) }
|
||||||
|
|
||||||
|
public fun Map<Symbol, Double>.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline class SimpleSymbolIndexer(override val symbols: List<Symbol>) : SymbolIndexer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute the block with symbol indexer based on given symbol order
|
||||||
|
*/
|
||||||
|
public inline fun <R> withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R =
|
||||||
|
with(SimpleSymbolIndexer(symbols.toList()), block)
|
||||||
|
|
||||||
|
public inline fun <R> withSymbols(symbols: Collection<Symbol>, block: SymbolIndexer.() -> R): R =
|
||||||
|
with(SimpleSymbolIndexer(symbols.toList()), block)
|
@ -7,6 +7,7 @@ import kscience.kmath.operations.Space
|
|||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a functional expression with this [Space].
|
* Creates a functional expression with this [Space].
|
||||||
*/
|
*/
|
||||||
|
@ -1,266 +0,0 @@
|
|||||||
package kscience.kmath.misc
|
|
||||||
|
|
||||||
import kscience.kmath.linear.Point
|
|
||||||
import kscience.kmath.operations.*
|
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import kotlin.contracts.InvocationKind
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Implementation of backward-mode automatic differentiation.
|
|
||||||
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Differentiable variable with value and derivative of differentiation ([deriv]) result
|
|
||||||
* with respect to this variable.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @property value The value of this variable.
|
|
||||||
*/
|
|
||||||
public open class Variable<T : Any>(public val value: T)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents result of [deriv] call.
|
|
||||||
*
|
|
||||||
* @param T the non-nullable type of value.
|
|
||||||
* @param value the value of result.
|
|
||||||
* @property deriv The mapping of differentiated variables to their derivatives.
|
|
||||||
* @property context The field over [T].
|
|
||||||
*/
|
|
||||||
public class DerivationResult<T : Any>(
|
|
||||||
value: T,
|
|
||||||
public val deriv: Map<Variable<T>, T>,
|
|
||||||
public val context: Field<T>
|
|
||||||
) : Variable<T>(value) {
|
|
||||||
/**
|
|
||||||
* Returns derivative of [variable] or returns [Ring.zero] in [context].
|
|
||||||
*/
|
|
||||||
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the divergence.
|
|
||||||
*/
|
|
||||||
public fun div(): T = context { sum(deriv.values) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the gradient for variables in given order.
|
|
||||||
*/
|
|
||||||
public fun grad(vararg variables: Variable<T>): Point<T> {
|
|
||||||
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
|
|
||||||
return variables.map(::deriv).asBuffer()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
|
|
||||||
*
|
|
||||||
* The partial derivatives are placed in argument `d` variable
|
|
||||||
*
|
|
||||||
* Example:
|
|
||||||
* ```
|
|
||||||
* val x = Variable(2) // define variable(s) and their values
|
|
||||||
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
|
|
||||||
* assertEquals(17.0, y.x) // the value of result (y)
|
|
||||||
* assertEquals(9.0, x.d) // dy/dx
|
|
||||||
* ```
|
|
||||||
*
|
|
||||||
* @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to.
|
|
||||||
* @return the result of differentiation.
|
|
||||||
*/
|
|
||||||
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
|
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
|
|
||||||
return (AutoDiffContext(this)) {
|
|
||||||
val result = body()
|
|
||||||
result.d = context.one // computing derivative w.r.t result
|
|
||||||
runBackwardPass()
|
|
||||||
DerivationResult(result.value, derivatives, this@deriv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents field in context of which functions can be derived.
|
|
||||||
*/
|
|
||||||
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
|
|
||||||
public abstract val context: F
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A variable accessing inner state of derivatives.
|
|
||||||
* Use this value in inner builders to avoid creating additional derivative bindings.
|
|
||||||
*/
|
|
||||||
public abstract var Variable<T>.d: T
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs update of derivative after the rest of the formula in the back-pass.
|
|
||||||
*
|
|
||||||
* For example, implementation of `sin` function is:
|
|
||||||
*
|
|
||||||
* ```
|
|
||||||
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
|
|
||||||
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
|
|
||||||
* }
|
|
||||||
* ```
|
|
||||||
*/
|
|
||||||
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public abstract fun variable(value: T): Variable<T>
|
|
||||||
|
|
||||||
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
|
|
||||||
|
|
||||||
// Overloads for Double constants
|
|
||||||
|
|
||||||
override operator fun Number.plus(b: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
|
|
||||||
b.d += z.d
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
|
|
||||||
|
|
||||||
override operator fun Number.minus(b: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
|
||||||
|
|
||||||
override operator fun Variable<T>.minus(b: Number): Variable<T> =
|
|
||||||
derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic Differentiation context class.
|
|
||||||
*/
|
|
||||||
@PublishedApi
|
|
||||||
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
|
||||||
private var sp: Int = 0
|
|
||||||
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
|
|
||||||
override val zero: Variable<T> get() = Variable(context.zero)
|
|
||||||
override val one: Variable<T> get() = Variable(context.one)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A variable coupled with its derivative. For internal use only
|
|
||||||
*/
|
|
||||||
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
|
|
||||||
|
|
||||||
override fun variable(value: T): Variable<T> =
|
|
||||||
VariableWithDeriv(value, context.zero)
|
|
||||||
|
|
||||||
override var Variable<T>.d: T
|
|
||||||
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
|
|
||||||
set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
|
|
||||||
// save block to stack for backward pass
|
|
||||||
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
|
|
||||||
stack[sp++] = block
|
|
||||||
stack[sp++] = value
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
fun runBackwardPass() {
|
|
||||||
while (sp > 0) {
|
|
||||||
val value = stack[--sp]
|
|
||||||
val block = stack[--sp] as F.(Any?) -> Unit
|
|
||||||
context.block(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
|
||||||
|
|
||||||
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
|
|
||||||
a.d += z.d
|
|
||||||
b.d += z.d
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
|
|
||||||
a.d += z.d * b.value
|
|
||||||
b.d += z.d * a.value
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
|
|
||||||
a.d += z.d / b.value
|
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
|
|
||||||
a.d += z.d * k.toDouble()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
|
||||||
|
|
||||||
// x ^ 2
|
|
||||||
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
|
|
||||||
|
|
||||||
// x ^ 1/2
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
|
|
||||||
|
|
||||||
// x ^ y (const)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
|
|
||||||
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
|
|
||||||
pow(x, y.toDouble())
|
|
||||||
|
|
||||||
// exp(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
|
|
||||||
|
|
||||||
// ln(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
|
|
||||||
|
|
||||||
// x ^ y (any)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
|
|
||||||
exp(y * ln(x))
|
|
||||||
|
|
||||||
// sin(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
|
|
||||||
|
|
||||||
// cos(x)
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { tan(x.value) }) { z ->
|
|
||||||
val c = cos(x.value)
|
|
||||||
x.d += z.d / (c * c)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { tan(x.value) }) { z ->
|
|
||||||
val c = cosh(x.value)
|
|
||||||
x.d += z.d / (c * c)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
|
|
||||||
|
|
||||||
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
|
|
||||||
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
|
|
||||||
|
|
@ -150,6 +150,7 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a complex number with real part equal to this real.
|
* Creates a complex number with real part equal to this real.
|
||||||
*
|
*
|
||||||
|
@ -6,19 +6,21 @@ import kscience.kmath.operations.RealField
|
|||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFails
|
||||||
|
|
||||||
class ExpressionFieldTest {
|
class ExpressionFieldTest {
|
||||||
|
val x by symbol
|
||||||
@Test
|
@Test
|
||||||
fun testExpression() {
|
fun testExpression() {
|
||||||
val context = FunctionalExpressionField(RealField)
|
val context = FunctionalExpressionField(RealField)
|
||||||
|
|
||||||
val expression = context {
|
val expression = context {
|
||||||
val x = variable("x", 2.0)
|
val x by binding()
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
assertEquals(expression(), 9.0)
|
assertFails { expression()}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -26,33 +28,33 @@ class ExpressionFieldTest {
|
|||||||
val context = FunctionalExpressionField(ComplexField)
|
val context = FunctionalExpressionField(ComplexField)
|
||||||
|
|
||||||
val expression = context {
|
val expression = context {
|
||||||
val x = variable("x", Complex(2.0, 0.0))
|
val x = bind(x)
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0))
|
||||||
assertEquals(expression(), Complex(9.0, 0.0))
|
//assertEquals(expression(), Complex(9.0, 0.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun separateContext() {
|
fun separateContext() {
|
||||||
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
|
||||||
val x = variable("x")
|
val x by binding()
|
||||||
return x * x + 2 * x + one
|
return x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = FunctionalExpressionField(RealField).expression()
|
val expression = FunctionalExpressionField(RealField).expression()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueExpression() {
|
fun valueExpression() {
|
||||||
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
|
||||||
val x = variable("x")
|
val x by binding()
|
||||||
x * x + 2 * x + one
|
x * x + 2 * x + one
|
||||||
}
|
}
|
||||||
|
|
||||||
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
val expression = FunctionalExpressionField(RealField).expressionBuilder()
|
||||||
assertEquals(expression("x" to 1.0), 4.0)
|
assertEquals(expression(x to 1.0), 4.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,285 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.E
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class SimpleAutoDiffTest {
|
||||||
|
|
||||||
|
fun dx(
|
||||||
|
xBinding: Pair<Symbol, Double>,
|
||||||
|
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||||
|
|
||||||
|
fun dxy(
|
||||||
|
xBinding: Pair<Symbol, Double>,
|
||||||
|
yBinding: Pair<Symbol, Double>,
|
||||||
|
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
||||||
|
body(bind(xBinding.first), bind(yBinding.first))
|
||||||
|
}
|
||||||
|
|
||||||
|
fun diff(block: SimpleAutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
|
||||||
|
return SimpleAutoDiffExpression(RealField, block)
|
||||||
|
}
|
||||||
|
|
||||||
|
val x by symbol
|
||||||
|
val y by symbol
|
||||||
|
val z by symbol
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlusX2() {
|
||||||
|
val y = RealField.simpleAutoDiff(x to 3.0) {
|
||||||
|
// diff w.r.t this x at 3
|
||||||
|
val x = bind(x)
|
||||||
|
x + x
|
||||||
|
}
|
||||||
|
assertEquals(6.0, y.value) // y = x + x = 6
|
||||||
|
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlusX2Expr() {
|
||||||
|
val expr = diff {
|
||||||
|
val x = bind(x)
|
||||||
|
x + x
|
||||||
|
}
|
||||||
|
assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6
|
||||||
|
assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlus() {
|
||||||
|
// two variables
|
||||||
|
val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) {
|
||||||
|
val x = bind(x)
|
||||||
|
val y = bind(y)
|
||||||
|
x + y
|
||||||
|
}
|
||||||
|
assertEquals(5.0, z.value) // z = x + y = 5
|
||||||
|
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||||
|
assertEquals(1.0, z.derivative(y)) // dz/dy = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMinus() {
|
||||||
|
// two variables
|
||||||
|
val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) {
|
||||||
|
val x = bind(x)
|
||||||
|
val y = bind(y)
|
||||||
|
|
||||||
|
x - y
|
||||||
|
}
|
||||||
|
assertEquals(4.0, z.value) // z = x - y = 4
|
||||||
|
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
|
||||||
|
assertEquals(-1.0, z.derivative(y)) // dz/dy = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testMulX2() {
|
||||||
|
val y = dx(x to 3.0) { x ->
|
||||||
|
// diff w.r.t this x at 3
|
||||||
|
x * x
|
||||||
|
}
|
||||||
|
assertEquals(9.0, y.value) // y = x * x = 9
|
||||||
|
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqr() {
|
||||||
|
val y = dx(x to 3.0) { x -> sqr(x) }
|
||||||
|
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
||||||
|
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrSqr() {
|
||||||
|
val y = dx(x to 2.0) { x -> sqr(sqr(x)) }
|
||||||
|
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
||||||
|
assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testX3() {
|
||||||
|
val y = dx(x to 2.0) { x ->
|
||||||
|
// diff w.r.t this x at 2
|
||||||
|
x * x * x
|
||||||
|
}
|
||||||
|
assertEquals(8.0, y.value) // y = x * x * x = 8
|
||||||
|
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDiv() {
|
||||||
|
val z = dxy(x to 5.0, y to 2.0) { x, y ->
|
||||||
|
x / y
|
||||||
|
}
|
||||||
|
assertEquals(2.5, z.value) // z = x / y = 2.5
|
||||||
|
assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5
|
||||||
|
assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPow3() {
|
||||||
|
val y = dx(x to 2.0) { x ->
|
||||||
|
// diff w.r.t this x at 2
|
||||||
|
pow(x, 3)
|
||||||
|
}
|
||||||
|
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
||||||
|
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPowFull() {
|
||||||
|
val z = dxy(x to 2.0, y to 3.0) { x, y ->
|
||||||
|
pow(x, y)
|
||||||
|
}
|
||||||
|
assertApprox(8.0, z.value) // z = x ^ y = 8
|
||||||
|
assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12
|
||||||
|
assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testFromPaper() {
|
||||||
|
val y = dx(x to 3.0) { x -> 2 * x + x * x * x }
|
||||||
|
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
||||||
|
assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testInnerVariable() {
|
||||||
|
val y = dx(x to 1.0) { x ->
|
||||||
|
const(1.0) * x
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLongChain() {
|
||||||
|
val n = 10_000
|
||||||
|
val y = dx(x to 1.0) { x ->
|
||||||
|
var res = const(1.0)
|
||||||
|
for (i in 1..n) res *= x
|
||||||
|
res
|
||||||
|
}
|
||||||
|
assertEquals(1.0, y.value) // y = x ^ n = 1
|
||||||
|
assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testExample() {
|
||||||
|
val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 }
|
||||||
|
assertEquals(17.0, y.value) // the value of result (y)
|
||||||
|
assertEquals(9.0, y.derivative(x)) // dy/dx
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSqrt() {
|
||||||
|
val y = dx(x to 16.0) { x -> sqrt(x) }
|
||||||
|
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
||||||
|
assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSin() {
|
||||||
|
val y = dx(x to PI / 6.0) { x -> sin(x) }
|
||||||
|
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
||||||
|
assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCos() {
|
||||||
|
val y = dx(x to PI / 6) { x -> cos(x) }
|
||||||
|
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
||||||
|
assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTan() {
|
||||||
|
val y = dx(x to PI / 6) { x -> tan(x) }
|
||||||
|
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
||||||
|
assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAsin() {
|
||||||
|
val y = dx(x to PI / 6) { x -> asin(x) }
|
||||||
|
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
||||||
|
assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAcos() {
|
||||||
|
val y = dx(x to PI / 6) { x -> acos(x) }
|
||||||
|
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
||||||
|
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAtan() {
|
||||||
|
val y = dx(x to PI / 6) { x -> atan(x) }
|
||||||
|
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
||||||
|
assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSinh() {
|
||||||
|
val y = dx(x to 0.0) { x -> sinh(x) }
|
||||||
|
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
||||||
|
assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCosh() {
|
||||||
|
val y = dx(x to 0.0) { x -> cosh(x) }
|
||||||
|
assertApprox(1.0, y.value) //y = cosh(0)
|
||||||
|
assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTanh() {
|
||||||
|
val y = dx(x to 1.0) { x -> tanh(x) }
|
||||||
|
assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6)
|
||||||
|
assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAsinh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> asinh(x) }
|
||||||
|
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
||||||
|
assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAcosh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> acosh(x) }
|
||||||
|
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
||||||
|
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAtanh() {
|
||||||
|
val y = dx(x to PI / 6) { x -> atanh(x) }
|
||||||
|
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
||||||
|
assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDivGrad() {
|
||||||
|
val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y }
|
||||||
|
assertEquals(6.0, res.div())
|
||||||
|
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun assertApprox(a: Double, b: Double) {
|
||||||
|
if ((a - b) > 1e-10) assertEquals(a, b)
|
||||||
|
}
|
||||||
|
}
|
@ -1,261 +0,0 @@
|
|||||||
package kscience.kmath.misc
|
|
||||||
|
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import kotlin.math.PI
|
|
||||||
import kotlin.math.pow
|
|
||||||
import kotlin.math.sqrt
|
|
||||||
import kotlin.test.Test
|
|
||||||
import kotlin.test.assertEquals
|
|
||||||
import kotlin.test.assertTrue
|
|
||||||
|
|
||||||
class AutoDiffTest {
|
|
||||||
inline fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
|
|
||||||
RealField.deriv(body)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPlusX2() {
|
|
||||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
|
||||||
val y = deriv { x + x }
|
|
||||||
assertEquals(6.0, y.value) // y = x + x = 6
|
|
||||||
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPlus() {
|
|
||||||
// two variables
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { x + y }
|
|
||||||
assertEquals(5.0, z.value) // z = x + y = 5
|
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
|
||||||
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testMinus() {
|
|
||||||
// two variables
|
|
||||||
val x = Variable(7.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { x - y }
|
|
||||||
assertEquals(4.0, z.value) // z = x - y = 4
|
|
||||||
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
|
|
||||||
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testMulX2() {
|
|
||||||
val x = Variable(3.0) // diff w.r.t this x at 3
|
|
||||||
val y = deriv { x * x }
|
|
||||||
assertEquals(9.0, y.value) // y = x * x = 9
|
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqr() {
|
|
||||||
val x = Variable(3.0)
|
|
||||||
val y = deriv { sqr(x) }
|
|
||||||
assertEquals(9.0, y.value) // y = x ^ 2 = 9
|
|
||||||
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqrSqr() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = deriv { sqr(sqr(x)) }
|
|
||||||
assertEquals(16.0, y.value) // y = x ^ 4 = 16
|
|
||||||
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testX3() {
|
|
||||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
|
||||||
val y = deriv { x * x * x }
|
|
||||||
assertEquals(8.0, y.value) // y = x * x * x = 8
|
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testDiv() {
|
|
||||||
val x = Variable(5.0)
|
|
||||||
val y = Variable(2.0)
|
|
||||||
val z = deriv { x / y }
|
|
||||||
assertEquals(2.5, z.value) // z = x / y = 2.5
|
|
||||||
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
|
|
||||||
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPow3() {
|
|
||||||
val x = Variable(2.0) // diff w.r.t this x at 2
|
|
||||||
val y = deriv { pow(x, 3) }
|
|
||||||
assertEquals(8.0, y.value) // y = x ^ 3 = 8
|
|
||||||
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testPowFull() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = Variable(3.0)
|
|
||||||
val z = deriv { pow(x, y) }
|
|
||||||
assertApprox(8.0, z.value) // z = x ^ y = 8
|
|
||||||
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
|
|
||||||
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testFromPaper() {
|
|
||||||
val x = Variable(3.0)
|
|
||||||
val y = deriv { 2 * x + x * x * x }
|
|
||||||
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
|
|
||||||
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testInnerVariable() {
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = deriv {
|
|
||||||
Variable(1.0) * x
|
|
||||||
}
|
|
||||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
|
||||||
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testLongChain() {
|
|
||||||
val n = 10_000
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = deriv {
|
|
||||||
var res = Variable(1.0)
|
|
||||||
for (i in 1..n) res *= x
|
|
||||||
res
|
|
||||||
}
|
|
||||||
assertEquals(1.0, y.value) // y = x ^ n = 1
|
|
||||||
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testExample() {
|
|
||||||
val x = Variable(2.0)
|
|
||||||
val y = deriv { sqr(x) + 5 * x + 3 }
|
|
||||||
assertEquals(17.0, y.value) // the value of result (y)
|
|
||||||
assertEquals(9.0, y.deriv(x)) // dy/dx
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSqrt() {
|
|
||||||
val x = Variable(16.0)
|
|
||||||
val y = deriv { sqrt(x) }
|
|
||||||
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
|
|
||||||
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSin() {
|
|
||||||
val x = Variable(PI / 6.0)
|
|
||||||
val y = deriv { sin(x) }
|
|
||||||
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
|
|
||||||
assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testCos() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { cos(x) }
|
|
||||||
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
|
|
||||||
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testTan() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { tan(x) }
|
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
|
|
||||||
assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAsin() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { asin(x) }
|
|
||||||
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
|
|
||||||
assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAcos() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { acos(x) }
|
|
||||||
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
|
|
||||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAtan() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { atan(x) }
|
|
||||||
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
|
|
||||||
assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testSinh() {
|
|
||||||
val x = Variable(0.0)
|
|
||||||
val y = deriv { sinh(x) }
|
|
||||||
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
|
|
||||||
assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testCosh() {
|
|
||||||
val x = Variable(0.0)
|
|
||||||
val y = deriv { cosh(x) }
|
|
||||||
assertApprox(1.0, y.value) //y = cosh(0)
|
|
||||||
assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testTanh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { tanh(x) }
|
|
||||||
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
|
|
||||||
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAsinh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { asinh(x) }
|
|
||||||
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
|
|
||||||
assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAcosh() {
|
|
||||||
val x = Variable(PI / 6)
|
|
||||||
val y = deriv { acosh(x) }
|
|
||||||
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
|
|
||||||
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testAtanh() {
|
|
||||||
val x = Variable(PI / 6.0)
|
|
||||||
val y = deriv { atanh(x) }
|
|
||||||
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
|
|
||||||
assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testDivGrad() {
|
|
||||||
val x = Variable(1.0)
|
|
||||||
val y = Variable(2.0)
|
|
||||||
val res = deriv { x * x + y * y }
|
|
||||||
assertEquals(6.0, res.div())
|
|
||||||
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun assertApprox(a: Double, b: Double) {
|
|
||||||
if ((a - b) > 1e-10) assertEquals(a, b)
|
|
||||||
}
|
|
||||||
}
|
|
@ -8,13 +8,13 @@ import kotlin.contracts.contract
|
|||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
// TODO make `inline`, when KT-41771 gets fixed
|
|
||||||
/**
|
/**
|
||||||
* Polynomial coefficients without fixation on specific context they are applied to
|
* Polynomial coefficients without fixation on specific context they are applied to
|
||||||
* @param coefficients constant is the leftmost coefficient
|
* @param coefficients constant is the leftmost coefficient
|
||||||
*/
|
*/
|
||||||
public inline class Polynomial<T : Any>(public val coefficients: List<T>)
|
public inline class Polynomial<T : Any>(public val coefficients: List<T>)
|
||||||
|
|
||||||
|
@Suppress("FunctionName")
|
||||||
public fun <T : Any> Polynomial(vararg coefficients: T): Polynomial<T> = Polynomial(coefficients.toList())
|
public fun <T : Any> Polynomial(vararg coefficients: T): Polynomial<T> = Polynomial(coefficients.toList())
|
||||||
|
|
||||||
public fun Polynomial<Double>.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) }
|
public fun Polynomial<Double>.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) }
|
||||||
@ -33,14 +33,6 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring
|
|||||||
res
|
res
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represent a polynomial as a context-dependent function
|
|
||||||
*/
|
|
||||||
public fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, C, T> =
|
|
||||||
object : MathFunction<T, C, T> {
|
|
||||||
override fun C.invoke(arg: T): T = value(this, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represent the polynomial as a regular context-less function
|
* Represent the polynomial as a regular context-less function
|
||||||
*/
|
*/
|
||||||
@ -49,7 +41,7 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.asFunction(ring: C): (T) -> T =
|
|||||||
/**
|
/**
|
||||||
* An algebra for polynomials
|
* An algebra for polynomials
|
||||||
*/
|
*/
|
||||||
public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<Polynomial<T>> {
|
public class PolynomialSpace<T : Any, C : Ring<T>>(private val ring: C) : Space<Polynomial<T>> {
|
||||||
public override val zero: Polynomial<T> = Polynomial(emptyList())
|
public override val zero: Polynomial<T> = Polynomial(emptyList())
|
||||||
|
|
||||||
public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
package kscience.kmath.functions
|
|
||||||
|
|
||||||
import kscience.kmath.operations.Algebra
|
|
||||||
import kscience.kmath.operations.RealField
|
|
||||||
|
|
||||||
// TODO make fun interface when KT-41770 is fixed
|
|
||||||
/**
|
|
||||||
* A regular function that could be called only inside specific algebra context
|
|
||||||
* @param T source type
|
|
||||||
* @param C source algebra constraint
|
|
||||||
* @param R result type
|
|
||||||
*/
|
|
||||||
public /*fun*/ interface MathFunction<T, C : Algebra<T>, R> {
|
|
||||||
public operator fun C.invoke(arg: T): R
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <R> MathFunction<Double, RealField, R>.invoke(arg: Double): R = RealField.invoke(arg)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A suspendable function defined in algebraic context
|
|
||||||
*/
|
|
||||||
// TODO make fun interface, when the new JVM IR is enabled
|
|
||||||
public interface SuspendableMathFunction<T, C : Algebra<T>, R> {
|
|
||||||
public suspend operator fun C.invoke(arg: T): R
|
|
||||||
}
|
|
||||||
|
|
||||||
public suspend fun <R> SuspendableMathFunction<Double, RealField, R>.invoke(arg: Double): R = RealField.invoke(arg)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A parametric function with parameter
|
|
||||||
*/
|
|
||||||
public fun interface ParametricFunction<T, P, C : Algebra<T>> {
|
|
||||||
public operator fun C.invoke(arg: T, parameter: P): T
|
|
||||||
}
|
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import kscience.kmath.chains.collect
|
import kscience.kmath.chains.collect
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import kscience.kmath.chains.SimpleChain
|
import kscience.kmath.chains.SimpleChain
|
@ -0,0 +1,59 @@
|
|||||||
|
package kscience.kmath.stat
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.*
|
||||||
|
import kscience.kmath.operations.ExtendedField
|
||||||
|
import kscience.kmath.structures.Buffer
|
||||||
|
import kscience.kmath.structures.indices
|
||||||
|
import kotlin.math.pow
|
||||||
|
|
||||||
|
public object Fitting {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
|
*/
|
||||||
|
public fun <T : Any, I : Any, A> chiSquared(
|
||||||
|
autoDiff: AutoDiffProcessor<T, I, A>,
|
||||||
|
x: Buffer<T>,
|
||||||
|
y: Buffer<T>,
|
||||||
|
yErr: Buffer<T>,
|
||||||
|
model: A.(I) -> I,
|
||||||
|
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
return autoDiff.process {
|
||||||
|
var sum = zero
|
||||||
|
x.indices.forEach {
|
||||||
|
val xValue = const(x[it])
|
||||||
|
val yValue = const(y[it])
|
||||||
|
val yErrValue = const(yErr[it])
|
||||||
|
val modelValue = model(xValue)
|
||||||
|
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||||
|
}
|
||||||
|
sum
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives
|
||||||
|
*/
|
||||||
|
public fun chiSquared(
|
||||||
|
x: Buffer<Double>,
|
||||||
|
y: Buffer<Double>,
|
||||||
|
yErr: Buffer<Double>,
|
||||||
|
model: Expression<Double>,
|
||||||
|
xSymbol: Symbol = StringSymbol("x"),
|
||||||
|
): Expression<Double> {
|
||||||
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
return Expression { arguments ->
|
||||||
|
x.indices.sumByDouble {
|
||||||
|
val xValue = x[it]
|
||||||
|
val yValue = y[it]
|
||||||
|
val yErrValue = yErr[it]
|
||||||
|
val modifiedArgs = arguments + (xSymbol to xValue)
|
||||||
|
val modelValue = model(modifiedArgs)
|
||||||
|
((yValue - modelValue) / yErrValue).pow(2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,91 @@
|
|||||||
|
package kscience.kmath.stat
|
||||||
|
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.Expression
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
|
|
||||||
|
public interface OptimizationFeature
|
||||||
|
|
||||||
|
public class OptimizationResult<T>(
|
||||||
|
public val point: Map<Symbol, T>,
|
||||||
|
public val value: T,
|
||||||
|
public val features: Set<OptimizationFeature> = emptySet(),
|
||||||
|
) {
|
||||||
|
override fun toString(): String {
|
||||||
|
return "OptimizationResult(point=$point, value=$value)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun <T> OptimizationResult<T>.plus(
|
||||||
|
feature: OptimizationFeature,
|
||||||
|
): OptimizationResult<T> = OptimizationResult(point, value, features + feature)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A configuration builder for optimization problem
|
||||||
|
*/
|
||||||
|
public interface OptimizationProblem<T : Any> {
|
||||||
|
/**
|
||||||
|
* Define the initial guess for the optimization problem
|
||||||
|
*/
|
||||||
|
public fun initialGuess(map: Map<Symbol, T>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set an objective function expression
|
||||||
|
*/
|
||||||
|
public fun expression(expression: Expression<T>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set a differentiable expression as objective function as function and gradient provider
|
||||||
|
*/
|
||||||
|
public fun diffExpression(expression: DifferentiableExpression<T>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the problem from previous optimization run
|
||||||
|
*/
|
||||||
|
public fun update(result: OptimizationResult<T>)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Make an optimization run
|
||||||
|
*/
|
||||||
|
public fun optimize(): OptimizationResult<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
||||||
|
public fun build(symbols: List<Symbol>): P
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
|
||||||
|
symbols: List<Symbol>,
|
||||||
|
block: P.() -> Unit,
|
||||||
|
): P = build(symbols).apply(block)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
|
||||||
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: F.() -> Unit,
|
||||||
|
): OptimizationResult<T> {
|
||||||
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
|
val problem = factory(symbols.toList(),configuration)
|
||||||
|
problem.expression(this)
|
||||||
|
return problem.optimize()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith(
|
||||||
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: F.() -> Unit,
|
||||||
|
): OptimizationResult<T> {
|
||||||
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
|
val problem = factory(symbols.toList(), configuration)
|
||||||
|
problem.diffExpression(this)
|
||||||
|
return problem.optimize()
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
|
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import kscience.kmath.chains.ConstantChain
|
import kscience.kmath.chains.ConstantChain
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kotlinx.coroutines.CoroutineDispatcher
|
import kotlinx.coroutines.CoroutineDispatcher
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kscience.kmath.chains.Chain
|
import kscience.kmath.chains.Chain
|
||||||
import kscience.kmath.chains.SimpleChain
|
import kscience.kmath.chains.SimpleChain
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import org.apache.commons.rng.UniformRandomProvider
|
import org.apache.commons.rng.UniformRandomProvider
|
||||||
import org.apache.commons.rng.simple.RandomSource
|
import org.apache.commons.rng.simple.RandomSource
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kscience.kmath.chains.BlockingIntChain
|
import kscience.kmath.chains.BlockingIntChain
|
||||||
import kscience.kmath.chains.BlockingRealChain
|
import kscience.kmath.chains.BlockingRealChain
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kotlinx.coroutines.flow.take
|
import kotlinx.coroutines.flow.take
|
||||||
import kotlinx.coroutines.flow.toList
|
import kotlinx.coroutines.flow.toList
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kotlinx.coroutines.runBlocking
|
import kotlinx.coroutines.runBlocking
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
@ -1,4 +1,4 @@
|
|||||||
package kscience.kmath.prob
|
package kscience.kmath.stat
|
||||||
|
|
||||||
import kotlinx.coroutines.flow.drop
|
import kotlinx.coroutines.flow.drop
|
||||||
import kotlinx.coroutines.flow.first
|
import kotlinx.coroutines.flow.first
|
@ -10,8 +10,8 @@ pluginManagement {
|
|||||||
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
||||||
}
|
}
|
||||||
|
|
||||||
val toolsVersion = "0.6.1-dev-1.4.20-M1"
|
val toolsVersion = "0.6.4-dev-1.4.20-M2"
|
||||||
val kotlinVersion = "1.4.20-M1"
|
val kotlinVersion = "1.4.20-M2"
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
||||||
@ -34,11 +34,11 @@ include(
|
|||||||
":kmath-histograms",
|
":kmath-histograms",
|
||||||
":kmath-commons",
|
":kmath-commons",
|
||||||
":kmath-viktor",
|
":kmath-viktor",
|
||||||
":kmath-prob",
|
":kmath-stat",
|
||||||
":kmath-dimensions",
|
":kmath-dimensions",
|
||||||
":kmath-for-real",
|
":kmath-for-real",
|
||||||
":kmath-geometry",
|
":kmath-geometry",
|
||||||
":kmath-ast",
|
":kmath-ast",
|
||||||
":examples",
|
":kmath-ejml",
|
||||||
":kmath-ejml"
|
":examples"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user