Merge branch 'dev' into feature/quaternion

# Conflicts:
#	CHANGELOG.md
This commit is contained in:
Iaroslav Postovalov 2020-10-29 15:49:49 +07:00
commit 6d016c87f2
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
52 changed files with 1492 additions and 843 deletions

View File

@ -1 +1,3 @@
job("Build") { gradlew("openjdk:11", "build") } job("Build") {
gradlew("openjdk:11", "build")
}

View File

@ -8,6 +8,11 @@
- Automatic README generation for features (#139) - Automatic README generation for features (#139)
- Native support for `memory`, `core` and `dimensions` - Native support for `memory`, `core` and `dimensions`
- `kmath-ejml` to supply EJML SimpleMatrix wrapper. - `kmath-ejml` to supply EJML SimpleMatrix wrapper.
- A separate `Symbol` entity, which is used for global unbound symbol.
- A `Symbol` indexing scope.
- Basic optimization API for Commons-math.
- Chi squared optimization for array-like data in CM
- `Fitting` utility object in prob/stat
- Basic Quaternion vector support. - Basic Quaternion vector support.
### Changed ### Changed
@ -17,6 +22,8 @@
- `Polynomial` secondary constructor made function. - `Polynomial` secondary constructor made function.
- Kotlin version: 1.3.72 -> 1.4.20-M1 - Kotlin version: 1.3.72 -> 1.4.20-M1
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library. - `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
- Full autodiff refactoring based on `Symbol`
- `kmath-prob` renamed to `kmath-stat`
### Deprecated ### Deprecated

View File

@ -53,9 +53,7 @@ can be used for a wide variety of purposes from high performance calculations to
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/) * **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
to submit a feature request if you want something to be done first. to submit a feature request if you want something to be done first.
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
## Planned features ## Planned features
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks. * **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
@ -101,7 +99,7 @@ can be used for a wide variety of purposes from high performance calculations to
> - [buffers](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure > - [buffers](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
> - [expressions](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions > - [expressions](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
> - [domains](kmath-core/src/commonMain/kotlin/kscience/kmath/domains) : Domains > - [domains](kmath-core/src/commonMain/kotlin/kscience/kmath/domains) : Domains
> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation > - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
<hr/> <hr/>
@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/> <hr/>
* ### [kmath-ejml](kmath-ejml)
>
>
> **Maturity**: EXPERIMENTAL
<hr/>
* ### [kmath-for-real](kmath-for-real) * ### [kmath-for-real](kmath-for-real)
> >
> >
@ -147,7 +151,7 @@ can be used for a wide variety of purposes from high performance calculations to
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
<hr/> <hr/>
* ### [kmath-prob](kmath-prob) * ### [kmath-stat](kmath-stat)
> >
> >
> **Maturity**: EXPERIMENTAL > **Maturity**: EXPERIMENTAL
@ -178,8 +182,8 @@ repositories{
} }
dependencies{ dependencies{
api("kscience.kmath:kmath-core:${kmathVersion}") api("kscience.kmath:kmath-core:0.2.0-dev-2")
//api("scientifik:kmath-core:${kmathVersion}") for 0.1.3 and earlier //api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
} }
``` ```
@ -197,4 +201,4 @@ with the same artifact names.
## Contributing ## Contributing
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features. The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).

View File

@ -2,7 +2,7 @@ plugins {
id("ru.mipt.npm.project") id("ru.mipt.npm.project")
} }
val kmathVersion: String by extra("0.2.0-dev-2") val kmathVersion: String by extra("0.2.0-dev-3")
val bintrayRepo: String by extra("kscience") val bintrayRepo: String by extra("kscience")
val githubProject: String by extra("kmath") val githubProject: String by extra("kmath")
@ -25,3 +25,7 @@ subprojects {
readme { readme {
readmeTemplate = file("docs/templates/README-TEMPLATE.md") readmeTemplate = file("docs/templates/README-TEMPLATE.md")
} }
apiValidation{
validationDisabled = true
}

View File

@ -107,4 +107,4 @@ with the same artifact names.
## Contributing ## Contributing
The project requires a lot of additional work. Please feel free to contribute in any way and propose new features. The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).

View File

@ -23,7 +23,7 @@ dependencies {
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))
implementation(project(":kmath-prob")) implementation(project(":kmath-stat"))
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions")) implementation(project(":kmath-dimensions"))
implementation(project(":kmath-ejml")) implementation(project(":kmath-ejml"))

View File

@ -4,7 +4,7 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kscience.kmath.chains.BlockingRealChain import kscience.kmath.chains.BlockingRealChain
import kscience.kmath.prob.* import kscience.kmath.stat.*
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource import org.apache.commons.rng.simple.RandomSource
import java.time.Duration import java.time.Duration

View File

@ -3,9 +3,9 @@ package kscience.kmath.commons.prob
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain
import kscience.kmath.chains.collectWithState import kscience.kmath.chains.collectWithState
import kscience.kmath.prob.Distribution import kscience.kmath.stat.Distribution
import kscience.kmath.prob.RandomGenerator import kscience.kmath.stat.RandomGenerator
import kscience.kmath.prob.normal import kscience.kmath.stat.normal
private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0) private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)

View File

@ -6,8 +6,8 @@ import kscience.kmath.structures.complex
fun main() { fun main() {
// 2d element // 2d element
val element = NDElement.complex(2, 2) { index: IntArray -> val element = NDElement.complex(2, 2) { (i,j) ->
Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble()) Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble())
} }
println(element) println(element)

View File

@ -14,8 +14,8 @@ import kotlin.contracts.contract
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> { public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<String, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value) override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T = override fun binaryOperation(operation: String, left: T, right: T): T =
@ -27,7 +27,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
error("Numeric nodes are not supported by $this") error("Numeric nodes are not supported by $this")
} }
override operator fun invoke(arguments: Map<String, T>): T = InnerAlgebra(arguments).evaluate(mst) override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
} }
/** /**
@ -37,7 +37,7 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
*/ */
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst( public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E, mstAlgebra: E,
block: E.() -> MST block: E.() -> MST,
): MstExpression<T> = MstExpression(this, mstAlgebra.block()) ): MstExpression<T> = MstExpression(this, mstAlgebra.block())
/** /**
@ -116,7 +116,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField( public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST block: MstExtendedField.() -> MST,
): MstExpression<T> { ): MstExpression<T> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block) return algebra.mstInExtendedField(block)

View File

@ -25,7 +25,7 @@ internal class AsmBuilder<T> internal constructor(
private val classOfT: Class<*>, private val classOfT: Class<*>,
private val algebra: Algebra<T>, private val algebra: Algebra<T>,
private val className: String, private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit,
) { ) {
/** /**
* Internal classloader of [AsmBuilder] with alias to define class from byte array. * Internal classloader of [AsmBuilder] with alias to define class from byte array.
@ -379,22 +379,14 @@ internal class AsmBuilder<T> internal constructor(
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
* provided. * provided.
*/ */
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
load(invokeArgumentsVar, MAP_TYPE) load(invokeArgumentsVar, MAP_TYPE)
aconst(name) aconst(name)
if (defaultValue != null)
loadTConstant(defaultValue)
invokestatic( invokestatic(
MAP_INTRINSICS_TYPE.internalName, MAP_INTRINSICS_TYPE.internalName,
"getOrFail", "getOrFail",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
Type.getMethodDescriptor(
OBJECT_TYPE,
MAP_TYPE,
OBJECT_TYPE,
*OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
false false
) )
@ -429,7 +421,7 @@ internal class AsmBuilder<T> internal constructor(
method: String, method: String,
descriptor: String, descriptor: String,
expectedArity: Int, expectedArity: Int,
opcode: Int = INVOKEINTERFACE opcode: Int = INVOKEINTERFACE,
) { ) {
run loop@{ run loop@{
repeat(expectedArity) { repeat(expectedArity) {

View File

@ -2,11 +2,12 @@
package kscience.kmath.asm.internal package kscience.kmath.asm.internal
import kscience.kmath.expressions.StringSymbol
import kscience.kmath.expressions.Symbol
/** /**
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present. * Gets value with given [key] or throws [NoSuchElementException] whenever it is not present.
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@JvmOverloads internal fun <V> Map<Symbol, V>.getOrFail(key: String): V = getValue(StringSymbol(key))
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
this[key] ?: default ?: error("Parameter not found: $key")

View File

@ -1,6 +1,5 @@
package kscience.kmath.asm package kscience.kmath.asm
import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInField
import kscience.kmath.ast.mstInRing import kscience.kmath.ast.mstInRing
import kscience.kmath.ast.mstInSpace import kscience.kmath.ast.mstInSpace
@ -11,6 +10,7 @@ import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestAsmAlgebras { internal class TestAsmAlgebras {
@Test @Test
fun space() { fun space() {
val res1 = ByteRing.mstInSpace { val res1 = ByteRing.mstInSpace {

View File

@ -6,7 +6,7 @@ description = "Commons math binding for kmath"
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
api(project(":kmath-coroutines")) api(project(":kmath-coroutines"))
api(project(":kmath-prob")) api(project(":kmath-stat"))
// api(project(":kmath-functions")) api(project(":kmath-functions"))
api("org.apache.commons:commons-math3:3.6.1") api("org.apache.commons:commons-math3:3.6.1")
} }

View File

@ -0,0 +1,115 @@
package kscience.kmath.commons.expressions
import kscience.kmath.expressions.*
import kscience.kmath.operations.ExtendedField
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
/**
* A field over commons-math [DerivativeStructure].
*
* @property order The derivation order.
* @property bindings The map of bindings values. All bindings are considered free parameters
*/
public class DerivativeStructureField(
public val order: Int,
private val bindings: Map<Symbol, Double>
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
/**
* A class that implements both [DerivativeStructure] and a [Symbol]
*/
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
override val identity: String = symbol.identity
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
}
/**
* Identity-based symbol bindings map
*/
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) ->
key.identity to DerivativeStructureSymbol(key, value)
}
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
//public fun Number.const(): DerivativeStructure = const(toDouble())
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
return derivative(mapOf(parameter to order))
}
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double {
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
}
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
is Double -> a.multiply(k)
is Int -> a.multiply(k)
else -> a.multiply(k.toDouble())
}
public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow)
is Int -> arg.pow(pow)
else -> arg.pow(pow.toDouble())
}
public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
return DerivativeStructureExpression(function)
}
}
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/
public class DerivativeStructureExpression(
public val function: DerivativeStructureField.() -> DerivativeStructure,
) : DifferentiableExpression<Double> {
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
DerivativeStructureField(0, arguments).function().value
/**
* Get the derivative expression with given orders
*/
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments ->
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
}
}

View File

@ -1,131 +0,0 @@
package kscience.kmath.commons.expressions
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.ExpressionAlgebra
import kscience.kmath.operations.ExtendedField
import kscience.kmath.operations.Field
import kscience.kmath.operations.invoke
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import kotlin.properties.ReadOnlyProperty
/**
* A field over commons-math [DerivativeStructure].
*
* @property order The derivation order.
* @property parameters The map of free parameters.
*/
public class DerivativeStructureField(
public val order: Int,
public val parameters: Map<String, Double>
) : ExtendedField<DerivativeStructure> {
public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
}
public val variable: ReadOnlyProperty<Any?, DerivativeStructure> = ReadOnlyProperty { _, property ->
variables[property.name] ?: error("A variable with name ${property.name} does not exist")
}
public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
variables[name] ?: default ?: error("A variable with name $name does not exist")
public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
return deriv(mapOf(parName to order))
}
public fun DerivativeStructure.deriv(orders: Map<String, Int>): Double {
return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
}
public fun DerivativeStructure.deriv(vararg orders: Pair<String, Int>): Double = deriv(mapOf(*orders))
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
is Double -> a.multiply(k)
is Int -> a.multiply(k)
else -> a.multiply(k.toDouble())
}
public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow)
is Int -> arg.pow(pow)
else -> arg.pow(pow.toDouble())
}
public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/
public class DiffExpression(public val function: DerivativeStructureField.() -> DerivativeStructure) :
Expression<Double> {
public override operator fun invoke(arguments: Map<String, Double>): Double = DerivativeStructureField(
0,
arguments
).function().value
/**
* Get the derivative expression with given orders
* TODO make result [DiffExpression]
*/
public fun derivative(orders: Map<String, Int>): Expression<Double> = Expression { arguments ->
(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
}
//TODO add gradient and maybe other vector operators
}
public fun DiffExpression.derivative(vararg orders: Pair<String, Int>): Expression<Double> = derivative(mapOf(*orders))
public fun DiffExpression.derivative(name: String): Expression<Double> = derivative(name to 1)
/**
* A context for [DiffExpression] (not to be confused with [DerivativeStructure])
*/
public object DiffExpressionAlgebra : ExpressionAlgebra<Double, DiffExpression>, Field<DiffExpression> {
public override val zero: DiffExpression = DiffExpression { 0.0.const() }
public override val one: DiffExpression = DiffExpression { 1.0.const() }
public override fun variable(name: String, default: Double?): DiffExpression =
DiffExpression { variable(name, default?.const()) }
public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) + b.function(this) }
public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) * b.function(this) }
public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
DiffExpression { a.function(this) / b.function(this) }
}

View File

@ -0,0 +1,111 @@
package kscience.kmath.commons.optimization
import kscience.kmath.expressions.*
import kscience.kmath.stat.OptimizationFeature
import kscience.kmath.stat.OptimizationProblem
import kscience.kmath.stat.OptimizationProblemFactory
import kscience.kmath.stat.OptimizationResult
import org.apache.commons.math3.optim.*
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
import kotlin.reflect.KClass
public operator fun PointValuePair.component1(): DoubleArray = point
public operator fun PointValuePair.component2(): Double = value
public class CMOptimizationProblem(
override val symbols: List<Symbol>,
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
public fun addOptimizationData(data: OptimizationData) {
optimizationData[data::class] = data
}
init {
addOptimizationData(MaxEval.unlimited())
}
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
public override fun initialGuess(map: Map<Symbol, Double>): Unit {
addOptimizationData(InitialGuess(map.toDoubleArray()))
}
public override fun expression(expression: Expression<Double>): Unit {
val objectiveFunction = ObjectiveFunction {
val args = it.toMap()
expression(args)
}
addOptimizationData(objectiveFunction)
}
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
expression(expression)
val gradientFunction = ObjectiveFunctionGradient {
val args = it.toMap()
DoubleArray(symbols.size) { index ->
expression.derivative(symbols[index])(args)
}
}
addOptimizationData(gradientFunction)
if (optimizatorBuilder == null) {
optimizatorBuilder = {
NonLinearConjugateGradientOptimizer(
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
convergenceChecker
)
}
}
}
public fun simplex(simplex: AbstractSimplex) {
addOptimizationData(simplex)
//Set optimization builder to simplex if it is not present
if (optimizatorBuilder == null) {
optimizatorBuilder = { SimplexOptimizer(convergenceChecker) }
}
}
public fun simplexSteps(steps: Map<Symbol, Double>) {
simplex(NelderMeadSimplex(steps.toDoubleArray()))
}
public fun goal(goalType: GoalType) {
addOptimizationData(goalType)
}
public fun optimizer(block: () -> MultivariateOptimizer) {
optimizatorBuilder = block
}
override fun update(result: OptimizationResult<Double>) {
initialGuess(result.point)
}
override fun optimize(): OptimizationResult<Double> {
val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
return OptimizationResult(point.toMap(), value, setOf(this))
}
public companion object : OptimizationProblemFactory<Double, CMOptimizationProblem> {
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
public const val DEFAULT_MAX_ITER: Int = 1000
override fun build(symbols: List<Symbol>): CMOptimizationProblem = CMOptimizationProblem(symbols)
}
}
public fun CMOptimizationProblem.initialGuess(vararg pairs: Pair<Symbol, Double>): Unit = initialGuess(pairs.toMap())
public fun CMOptimizationProblem.simplexSteps(vararg pairs: Pair<Symbol, Double>): Unit = simplexSteps(pairs.toMap())

View File

@ -0,0 +1,70 @@
package kscience.kmath.commons.optimization
import kscience.kmath.commons.expressions.DerivativeStructureField
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol
import kscience.kmath.stat.Fitting
import kscience.kmath.stat.OptimizationResult
import kscience.kmath.stat.optimizeWith
import kscience.kmath.structures.Buffer
import kscience.kmath.structures.asBuffer
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
/**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
*/
public fun Fitting.chiSquared(
x: Buffer<Double>,
y: Buffer<Double>,
yErr: Buffer<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model)
/**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
*/
public fun Fitting.chiSquared(
x: Iterable<Double>,
y: Iterable<Double>,
yErr: Iterable<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double> = chiSquared(
DerivativeStructureField,
x.toList().asBuffer(),
y.toList().asBuffer(),
yErr.toList().asBuffer(),
model
)
/**
* Optimize expression without derivatives
*/
public fun Expression<Double>.optimize(
vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
/**
* Optimize differentiable expression
*/
public fun DifferentiableExpression<Double>.optimize(
vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
public fun DifferentiableExpression<Double>.minimize(
vararg startPoint: Pair<Symbol, Double>,
configuration: CMOptimizationProblem.() -> Unit = {},
): OptimizationResult<Double> {
require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration)
problem.diffExpression(this)
problem.initialGuess(startPoint.toMap())
problem.goal(GoalType.MINIMIZE)
return problem.optimize()
}

View File

@ -1,9 +1,10 @@
package kscience.kmath.commons.random package kscience.kmath.commons.random
import kscience.kmath.prob.RandomGenerator import kscience.kmath.stat.RandomGenerator
public class CMRandomGeneratorWrapper(public val factory: (IntArray) -> RandomGenerator) : public class CMRandomGeneratorWrapper(
org.apache.commons.math3.random.RandomGenerator { public val factory: (IntArray) -> RandomGenerator,
) : org.apache.commons.math3.random.RandomGenerator {
private var generator: RandomGenerator = factory(intArrayOf()) private var generator: RandomGenerator = factory(intArrayOf())
public override fun nextBoolean(): Boolean = generator.nextBoolean() public override fun nextBoolean(): Boolean = generator.nextBoolean()

View File

@ -1,6 +1,6 @@
package kscience.kmath.commons.expressions package kscience.kmath.commons.expressions
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.*
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.test.Test import kotlin.test.Test
@ -8,33 +8,37 @@ import kotlin.test.assertEquals
internal inline fun <R> diff( internal inline fun <R> diff(
order: Int, order: Int,
vararg parameters: Pair<String, Double>, vararg parameters: Pair<Symbol, Double>,
block: DerivativeStructureField.() -> R block: DerivativeStructureField.() -> R,
): R { ): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block) return DerivativeStructureField(order, mapOf(*parameters)).run(block)
} }
internal class AutoDiffTest { internal class AutoDiffTest {
private val x by symbol
private val y by symbol
@Test @Test
fun derivativeStructureFieldTest() { fun derivativeStructureFieldTest() {
val res = diff(3, "x" to 1.0, "y" to 1.0) { val res: Double = diff(3, x to 1.0, y to 1.0) {
val x by variable val x = bind(x)//by binding()
val y = variable("y") val y = symbol("y")
val z = x * (-sin(x * y) + y) val z = x * (-sin(x * y) + y)
z.deriv("x") z.derivative(x)
} }
println(res)
} }
@Test @Test
fun autoDifTest() { fun autoDifTest() {
val f = DiffExpression { val f = DerivativeStructureExpression {
val x by variable val x by binding()
val y by variable val y by binding()
x.pow(2) + 2 * x * y + y.pow(2) + 1 x.pow(2) + 2 * x * y + y.pow(2) + 1
} }
assertEquals(10.0, f("x" to 1.0, "y" to 2.0)) assertEquals(10.0, f(x to 1.0, y to 2.0))
assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0)) assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
} }
} }

View File

@ -0,0 +1,66 @@
package kscience.kmath.commons.optimization
import kscience.kmath.commons.expressions.DerivativeStructureExpression
import kscience.kmath.expressions.symbol
import kscience.kmath.stat.Distribution
import kscience.kmath.stat.Fitting
import kscience.kmath.stat.RandomGenerator
import kscience.kmath.stat.normal
import kscience.kmath.structures.asBuffer
import org.junit.jupiter.api.Test
import kotlin.math.pow
internal class OptimizeTest {
val x by symbol
val y by symbol
val normal = DerivativeStructureExpression {
exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2)
}
@Test
fun testGradientOptimization() {
val result = normal.optimize(x, y) {
initialGuess(x to 1.0, y to 1.0)
//no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function
}
println(result.point)
println(result.value)
}
@Test
fun testSimplexOptimization() {
val result = normal.optimize(x, y) {
initialGuess(x to 1.0, y to 1.0)
simplexSteps(x to 2.0, y to 0.5)
//this sets simplex optimizer
}
println(result.point)
println(result.value)
}
@Test
fun testCmFit() {
val a by symbol
val b by symbol
val c by symbol
val sigma = 1.0
val generator = Distribution.normal(0.0, sigma)
val chain = generator.sample(RandomGenerator.default(112667))
val x = (1..100).map { it.toDouble() }
val y = x.map { it ->
it.pow(2) + it + 1 + chain.nextDouble()
}
val yErr = x.map { sigma }
val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
val cWithDefault = bindOrNull(c) ?: one
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
}
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
println(result)
println("Chi2/dof = ${result.value / (x.size - 3)}")
}
}

View File

@ -7,12 +7,12 @@ The core features of KMath:
- [buffers](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure - [buffers](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
- [expressions](src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions - [expressions](src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
- [domains](src/commonMain/kotlin/kscience/kmath/domains) : Domains - [domains](src/commonMain/kotlin/kscience/kmath/domains) : Domains
- [autodif](src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation - [autodif](src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
> #### Artifact: > #### Artifact:
> >
> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`. > This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
> >
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion) > Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
> >
@ -22,25 +22,28 @@ The core features of KMath:
> >
> ```gradle > ```gradle
> repositories { > repositories {
> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' } > maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' } > maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
> } > }
> >
> dependencies { > dependencies {
> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1' > implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
> } > }
> ``` > ```
> **Gradle Kotlin DSL:** > **Gradle Kotlin DSL:**
> >
> ```kotlin > ```kotlin
> repositories { > repositories {
> maven("https://dl.bintray.com/kotlin/kotlin-eap")
> maven("https://dl.bintray.com/mipt-npm/kscience") > maven("https://dl.bintray.com/mipt-npm/kscience")
> maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven") > maven("https://dl.bintray.com/hotkeytlt/maven")
> } > }
> >
> dependencies { > dependencies {
> implementation("kscience.kmath:kmath-core:0.2.0-dev-1") > implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
> } > }
> ``` > ```

View File

@ -41,6 +41,6 @@ readme {
feature( feature(
id = "autodif", id = "autodif",
description = "Automatic differentiation", description = "Automatic differentiation",
ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt" ref = "src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt"
) )
} }

View File

@ -0,0 +1,39 @@
package kscience.kmath.expressions
/**
* An expression that provides derivatives
*/
public interface DifferentiableExpression<T> : Expression<T>{
public fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>?
}
public fun <T> DifferentiableExpression<T>.derivative(orders: Map<Symbol, Int>): Expression<T> =
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
derivative(mapOf(*orders))
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
derivative(StringSymbol(name) to 1)
/**
* A [DifferentiableExpression] that defines only first derivatives
*/
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? {
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
return derivativeOrNull(dSymbol)
}
}
/**
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
*/
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
public fun process(function: A.() -> I): DifferentiableExpression<T>
}

View File

@ -1,6 +1,25 @@
package kscience.kmath.expressions package kscience.kmath.expressions
import kscience.kmath.operations.Algebra import kscience.kmath.operations.Algebra
import kotlin.jvm.JvmName
import kotlin.properties.ReadOnlyProperty
/**
* A marker interface for a symbol. A symbol mus have an identity
*/
public interface Symbol {
/**
* Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
*/
public val identity: String
}
/**
* A [Symbol] with a [String] identity
*/
public inline class StringSymbol(override val identity: String) : Symbol {
override fun toString(): String = identity
}
/** /**
* An elementary function that could be invoked on a map of arguments * An elementary function that could be invoked on a map of arguments
@ -12,30 +31,69 @@ public fun interface Expression<T> {
* @param arguments the map of arguments. * @param arguments the map of arguments.
* @return the value. * @return the value.
*/ */
public operator fun invoke(arguments: Map<String, T>): T public operator fun invoke(arguments: Map<Symbol, T>): T
public companion object
} }
/**
* Invoke an expression without parameters
*/
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
//This method exists to avoid resolution ambiguity of vararg methods
/** /**
* Calls this expression from arguments. * Calls this expression from arguments.
* *
* @param pairs the pair of arguments' names to values. * @param pairs the pair of arguments' names to values.
* @return the value. * @return the value.
*/ */
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke(mapOf(*pairs)) @JvmName("callBySymbol")
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
@JvmName("callByString")
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
/** /**
* A context for expression construction * A context for expression construction
*
* @param T type of the constants for the expression
* @param E type of the actual expression state
*/ */
public interface ExpressionAlgebra<T, E> : Algebra<E> { public interface ExpressionAlgebra<in T, E> : Algebra<E> {
/** /**
* Introduce a variable into expression context * Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
*/ */
public fun variable(name: String, default: T? = null): E public fun bindOrNull(symbol: Symbol): E?
/**
* Bind a string to a context using [StringSymbol]
*/
override fun symbol(value: String): E = bind(StringSymbol(value))
/** /**
* A constant expression which does not depend on arguments * A constant expression which does not depend on arguments
*/ */
public fun const(value: T): E public fun const(value: T): E
} }
/**
* Bind a given [Symbol] to this context variable and produce context-specific object.
*/
public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
/**
* A delegate to create a symbol with a string identity in this scope
*/
public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty { thisRef, property ->
StringSymbol(property.name)
}
/**
* Bind a symbol by name inside the [ExpressionAlgebra]
*/
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> = ReadOnlyProperty { _, property ->
bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
}

View File

@ -2,67 +2,43 @@ package kscience.kmath.expressions
import kscience.kmath.operations.* import kscience.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T =
context.unaryOperation(name, expr.invoke(arguments))
}
internal class FunctionalBinaryOperation<T>(
val context: Algebra<T>,
val name: String,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
}
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = value
}
internal class FunctionalConstProductExpression<T>(
val context: Space<T>,
private val expr: Expression<T>,
val const: Number
) : Expression<T> {
override operator fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
/** /**
* A context class for [Expression] construction. * A context class for [Expression] construction.
* *
* @param algebra The algebra to provide for Expressions built. * @param algebra The algebra to provide for Expressions built.
*/ */
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) : public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
ExpressionAlgebra<T, Expression<T>> { public val algebra: A,
) : ExpressionAlgebra<T, Expression<T>> {
/** /**
* Builds an Expression of constant expression which does not depend on arguments. * Builds an Expression of constant expression which does not depend on arguments.
*/ */
public override fun const(value: T): Expression<T> = FunctionalConstantExpression(value) public override fun const(value: T): Expression<T> = Expression { value }
/** /**
* Builds an Expression to access a variable. * Builds an Expression to access a variable.
*/ */
public override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default) public override fun bindOrNull(symbol: Symbol): Expression<T>? = Expression { arguments ->
arguments[symbol] ?: error("Argument not found: $symbol")
}
/** /**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/ */
public override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> = public override fun binaryOperation(
FunctionalBinaryOperation(algebra, operation, left, right) operation: String,
left: Expression<T>,
right: Expression<T>,
): Expression<T> = Expression { arguments ->
algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
}
/** /**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. * Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/ */
public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = public override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> = Expression { arguments ->
FunctionalUnaryOperation(algebra, operation, arg) algebra.unaryOperation(operation, arg.invoke(arguments))
}
} }
/** /**
@ -81,8 +57,9 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
/** /**
* Builds an Expression of multiplication of expression by number. * Builds an Expression of multiplication of expression by number.
*/ */
public override fun multiply(a: Expression<T>, k: Number): Expression<T> = public override fun multiply(a: Expression<T>, k: Number): Expression<T> = Expression { arguments ->
FunctionalConstProductExpression(algebra, a, k) algebra.multiply(a.invoke(arguments), k)
}
public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg) public operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg) public operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
@ -118,8 +95,8 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
} }
public open class FunctionalExpressionField<T, A>(algebra: A) : public open class FunctionalExpressionField<T, A>(algebra: A) :
FunctionalExpressionRing<T, A>(algebra), FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> { where A : Field<T>, A : NumericAlgebra<T> {
/** /**
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */

View File

@ -0,0 +1,395 @@
package kscience.kmath.expressions
import kscience.kmath.linear.Point
import kscience.kmath.operations.*
import kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/*
* Implementation of backward-mode automatic differentiation.
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
*/
public open class AutoDiffValue<out T>(public val value: T)
/**
* Represents result of [simpleAutoDiff] call.
*
* @param T the non-nullable type of value.
* @param value the value of result.
* @property simpleAutoDiff The mapping of differentiated variables to their derivatives.
* @property context The field over [T].
*/
public class DerivationResult<T : Any>(
public val value: T,
private val derivativeValues: Map<String, T>,
public val context: Field<T>,
) {
/**
* Returns derivative of [variable] or returns [Ring.zero] in [context].
*/
public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
/**
* Computes the divergence.
*/
public fun div(): T = context { sum(derivativeValues.values) }
}
/**
* Computes the gradient for variables in given order.
*/
public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T> {
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
return variables.map(::derivative).asBuffer()
}
/**
* Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
*
* The partial derivatives are placed in argument `d` variable
*
* Example:
* ```
* val x by symbol // define variable(s) and their values
* val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
* assertEquals(17.0, y.x) // the value of result (y)
* assertEquals(9.0, x.d) // dy/dx
* ```
*
* @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
* @return the result of differentiation.
*/
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
bindings: Map<Symbol, T>,
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return SimpleAutoDiffField(this, bindings).derivate(body)
}
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
vararg bindings: Pair<Symbol, T>,
body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
/**
* Represents field in context of which functions can be derived.
*/
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F,
bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
/**
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
private class AutoDiffVariableWithDerivative<T : Any>(
override val identity: String,
value: T,
var d: T,
) : AutoDiffValue<T>(value), Symbol {
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
}
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
private fun getDerivative(variable: AutoDiffValue<T>): T =
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
}
@Suppress("UNCHECKED_CAST")
private fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
/**
* A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings.
*/
public var AutoDiffValue<T>.d: T
get() = getDerivative(this)
set(value) = setDerivative(this, value)
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
/**
* Performs update of derivative after the rest of the formula in the back-pass.
*
* For example, implementation of `sin` function is:
*
* ```
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
* }
* ```
*/
@Suppress("UNCHECKED_CAST")
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
}
// Overloads for Double constants
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d
}
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
// Basic math (+, -, *, /)
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value + b.value }) { z ->
a.d += z.d
b.d += z.d
}
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value * b.value }) { z ->
a.d += z.d * b.value
b.d += z.d * a.value
}
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value / b.value }) { z ->
a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value)
}
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble()
}
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F,
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T>() {
public override operator fun invoke(arguments: Map<Symbol, T>): T {
//val bindings = arguments.entries.map { it.key.bind(it.value) }
return SimpleAutoDiffField(field, arguments).function().value
}
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
//val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
derivationResult.derivative(symbol)
}
}
/**
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
*/
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
return SimpleAutoDiffExpression(field, function)
}
}
}
// Extensions for differentiation of various basic mathematical functions
// x ^ 2
public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>,
y: Double,
): AutoDiffValue<T> =
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>,
y: Int,
): AutoDiffValue<T> = pow(x, y.toDouble())
// exp(x)
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x)
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>,
y: AutoDiffValue<T>,
): AutoDiffValue<T> =
exp(y * ln(x))
// sin(x)
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x)
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z ->
val c = cos(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tanh(x.value) }) { z ->
val c = cosh(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
context: F,
bindings: Map<Symbol, T>,
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
// x ^ 2
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sqr(x)
// x ^ 1/2
public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sqrt(arg)
// x ^ y (const)
public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())
// exp(x)
public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).exp(arg)
// ln(x)
public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).ln(arg)
// x ^ y (any)
public fun pow(
x: AutoDiffValue<T>,
y: AutoDiffValue<T>,
): AutoDiffValue<T> = exp(y * ln(x))
// sin(x)
public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sin(arg)
// cos(x)
public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).cos(arg)
public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).tan(arg)
public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).asin(arg)
public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).acos(arg)
public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atan(arg)
public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sinh(arg)
public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).cosh(arg)
public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).tanh(arg)
public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).asinh(arg)
public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).acosh(arg)
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atanh(arg)
}

View File

@ -0,0 +1,61 @@
package kscience.kmath.expressions
import kscience.kmath.linear.Point
import kscience.kmath.structures.BufferFactory
import kscience.kmath.structures.Structure2D
/**
* An environment to easy transform indexed variables to symbols and back.
* TODO requires multi-receivers to be beutiful
*/
public interface SymbolIndexer {
public val symbols: List<Symbol>
public fun indexOf(symbol: Symbol): Int = symbols.indexOf(symbol)
public operator fun <T> List<T>.get(symbol: Symbol): T {
require(size == symbols.size) { "The input list size for indexer should be ${symbols.size} but $size found" }
return get(this@SymbolIndexer.indexOf(symbol))
}
public operator fun <T> Array<T>.get(symbol: Symbol): T {
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
return get(this@SymbolIndexer.indexOf(symbol))
}
public operator fun DoubleArray.get(symbol: Symbol): Double {
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
return get(this@SymbolIndexer.indexOf(symbol))
}
public operator fun <T> Point<T>.get(symbol: Symbol): T {
require(size == symbols.size) { "The input buffer size for indexer should be ${symbols.size} but $size found" }
return get(this@SymbolIndexer.indexOf(symbol))
}
public fun DoubleArray.toMap(): Map<Symbol, Double> {
require(size == symbols.size) { "The input array size for indexer should be ${symbols.size} but $size found" }
return symbols.indices.associate { symbols[it] to get(it) }
}
public operator fun <T> Structure2D<T>.get(rowSymbol: Symbol, columnSymbol: Symbol): T =
get(indexOf(rowSymbol), indexOf(columnSymbol))
public fun <T> Map<Symbol, T>.toList(): List<T> = symbols.map { getValue(it) }
public fun <T> Map<Symbol, T>.toPoint(bufferFactory: BufferFactory<T>): Point<T> =
bufferFactory(symbols.size) { getValue(symbols[it]) }
public fun Map<Symbol, Double>.toDoubleArray(): DoubleArray = DoubleArray(symbols.size) { getValue(symbols[it]) }
}
public inline class SimpleSymbolIndexer(override val symbols: List<Symbol>) : SymbolIndexer
/**
* Execute the block with symbol indexer based on given symbol order
*/
public inline fun <R> withSymbols(vararg symbols: Symbol, block: SymbolIndexer.() -> R): R =
with(SimpleSymbolIndexer(symbols.toList()), block)
public inline fun <R> withSymbols(symbols: Collection<Symbol>, block: SymbolIndexer.() -> R): R =
with(SimpleSymbolIndexer(symbols.toList()), block)

View File

@ -7,6 +7,7 @@ import kscience.kmath.operations.Space
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
/** /**
* Creates a functional expression with this [Space]. * Creates a functional expression with this [Space].
*/ */

View File

@ -1,266 +0,0 @@
package kscience.kmath.misc
import kscience.kmath.linear.Point
import kscience.kmath.operations.*
import kscience.kmath.structures.asBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/*
* Implementation of backward-mode automatic differentiation.
* Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
*/
/**
* Differentiable variable with value and derivative of differentiation ([deriv]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
public open class Variable<T : Any>(public val value: T)
/**
* Represents result of [deriv] call.
*
* @param T the non-nullable type of value.
* @param value the value of result.
* @property deriv The mapping of differentiated variables to their derivatives.
* @property context The field over [T].
*/
public class DerivationResult<T : Any>(
value: T,
public val deriv: Map<Variable<T>, T>,
public val context: Field<T>
) : Variable<T>(value) {
/**
* Returns derivative of [variable] or returns [Ring.zero] in [context].
*/
public fun deriv(variable: Variable<T>): T = deriv[variable] ?: context.zero
/**
* Computes the divergence.
*/
public fun div(): T = context { sum(deriv.values) }
/**
* Computes the gradient for variables in given order.
*/
public fun grad(vararg variables: Variable<T>): Point<T> {
check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
return variables.map(::deriv).asBuffer()
}
}
/**
* Runs differentiation and establishes [AutoDiffField] context inside the block of code.
*
* The partial derivatives are placed in argument `d` variable
*
* Example:
* ```
* val x = Variable(2) // define variable(s) and their values
* val y = deriv { sqr(x) + 5 * x + 3 } // write formulate in deriv context
* assertEquals(17.0, y.x) // the value of result (y)
* assertEquals(9.0, x.d) // dy/dx
* ```
*
* @param body the action in [AutoDiffField] context returning [Variable] to differentiate with respect to.
* @return the result of differentiation.
*/
public inline fun <T : Any, F : Field<T>> F.deriv(body: AutoDiffField<T, F>.() -> Variable<T>): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return (AutoDiffContext(this)) {
val result = body()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
DerivationResult(result.value, derivatives, this@deriv)
}
}
/**
* Represents field in context of which functions can be derived.
*/
public abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
public abstract val context: F
/**
* A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings.
*/
public abstract var Variable<T>.d: T
/**
* Performs update of derivative after the rest of the formula in the back-pass.
*
* For example, implementation of `sin` function is:
*
* ```
* fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
* x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
* }
* ```
*/
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R
/**
*
*/
public abstract fun variable(value: T): Variable<T>
public inline fun variable(block: F.() -> T): Variable<T> = variable(context.block())
// Overloads for Double constants
override operator fun Number.plus(b: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d
}
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
}
/**
* Automatic Differentiation context class.
*/
@PublishedApi
internal class AutoDiffContext<T : Any, F : Field<T>>(override val context: F) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
val derivatives: MutableMap<Variable<T>, T> = hashMapOf()
override val zero: Variable<T> get() = Variable(context.zero)
override val one: Variable<T> get() = Variable(context.one)
/**
* A variable coupled with its derivative. For internal use only
*/
private class VariableWithDeriv<T : Any>(x: T, var d: T) : Variable<T>(x)
override fun variable(value: T): Variable<T> =
VariableWithDeriv(value, context.zero)
override var Variable<T>.d: T
get() = (this as? VariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) = if (this is VariableWithDeriv) d = value else derivatives[this] = value
@Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
@Suppress("UNCHECKED_CAST")
fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
// Basic math (+, -, *, /)
override fun add(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value + b.value }) { z ->
a.d += z.d
b.d += z.d
}
override fun multiply(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value * b.value }) { z ->
a.d += z.d * b.value
b.d += z.d * a.value
}
override fun divide(a: Variable<T>, b: Variable<T>): Variable<T> = derive(variable { a.value / b.value }) { z ->
a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value)
}
override fun multiply(a: Variable<T>, k: Number): Variable<T> = derive(variable { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble()
}
}
// Extensions for differentiation of various basic mathematical functions
// x ^ 2
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: Variable<T>): Variable<T> =
derive(variable { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: Variable<T>): Variable<T> =
derive(variable { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Double): Variable<T> =
derive(variable { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Int): Variable<T> =
pow(x, y.toDouble())
// exp(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: Variable<T>): Variable<T> =
derive(variable { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: Variable<T>): Variable<T> =
derive(variable { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow(x: Variable<T>, y: Variable<T>): Variable<T> =
exp(y * ln(x))
// sin(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: Variable<T>): Variable<T> =
derive(variable { tan(x.value) }) { z ->
val c = cos(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: Variable<T>): Variable<T> =
derive(variable { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: Variable<T>): Variable<T> =
derive(variable { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: Variable<T>): Variable<T> =
derive(variable { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: Variable<T>): Variable<T> =
derive(variable { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: Variable<T>): Variable<T> =
derive(variable { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: Variable<T>): Variable<T> =
derive(variable { tan(x.value) }) { z ->
val c = cosh(x.value)
x.d += z.d / (c * c)
}
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: Variable<T>): Variable<T> =
derive(variable { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: Variable<T>): Variable<T> =
derive(variable { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: Variable<T>): Variable<T> =
derive(variable { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }

View File

@ -150,6 +150,7 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
} }
} }
/** /**
* Creates a complex number with real part equal to this real. * Creates a complex number with real part equal to this real.
* *

View File

@ -6,19 +6,21 @@ import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFails
class ExpressionFieldTest { class ExpressionFieldTest {
val x by symbol
@Test @Test
fun testExpression() { fun testExpression() {
val context = FunctionalExpressionField(RealField) val context = FunctionalExpressionField(RealField)
val expression = context { val expression = context {
val x = variable("x", 2.0) val x by binding()
x * x + 2 * x + one x * x + 2 * x + one
} }
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression(x to 1.0), 4.0)
assertEquals(expression(), 9.0) assertFails { expression()}
} }
@Test @Test
@ -26,33 +28,33 @@ class ExpressionFieldTest {
val context = FunctionalExpressionField(ComplexField) val context = FunctionalExpressionField(ComplexField)
val expression = context { val expression = context {
val x = variable("x", Complex(2.0, 0.0)) val x = bind(x)
x * x + 2 * x + one x * x + 2 * x + one
} }
assertEquals(expression("x" to Complex(1.0, 0.0)), Complex(4.0, 0.0)) assertEquals(expression(x to Complex(1.0, 0.0)), Complex(4.0, 0.0))
assertEquals(expression(), Complex(9.0, 0.0)) //assertEquals(expression(), Complex(9.0, 0.0))
} }
@Test @Test
fun separateContext() { fun separateContext() {
fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> { fun <T> FunctionalExpressionField<T, *>.expression(): Expression<T> {
val x = variable("x") val x by binding()
return x * x + 2 * x + one return x * x + 2 * x + one
} }
val expression = FunctionalExpressionField(RealField).expression() val expression = FunctionalExpressionField(RealField).expression()
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression(x to 1.0), 4.0)
} }
@Test @Test
fun valueExpression() { fun valueExpression() {
val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = { val expressionBuilder: FunctionalExpressionField<Double, *>.() -> Expression<Double> = {
val x = variable("x") val x by binding()
x * x + 2 * x + one x * x + 2 * x + one
} }
val expression = FunctionalExpressionField(RealField).expressionBuilder() val expression = FunctionalExpressionField(RealField).expressionBuilder()
assertEquals(expression("x" to 1.0), 4.0) assertEquals(expression(x to 1.0), 4.0)
} }
} }

View File

@ -0,0 +1,285 @@
package kscience.kmath.expressions
import kscience.kmath.operations.RealField
import kscience.kmath.structures.asBuffer
import kotlin.math.E
import kotlin.math.PI
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class SimpleAutoDiffTest {
fun dx(
xBinding: Pair<Symbol, Double>,
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
fun dxy(
xBinding: Pair<Symbol, Double>,
yBinding: Pair<Symbol, Double>,
body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
body(bind(xBinding.first), bind(yBinding.first))
}
fun diff(block: SimpleAutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
return SimpleAutoDiffExpression(RealField, block)
}
val x by symbol
val y by symbol
val z by symbol
@Test
fun testPlusX2() {
val y = RealField.simpleAutoDiff(x to 3.0) {
// diff w.r.t this x at 3
val x = bind(x)
x + x
}
assertEquals(6.0, y.value) // y = x + x = 6
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
}
@Test
fun testPlusX2Expr() {
val expr = diff {
val x = bind(x)
x + x
}
assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6
assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2
}
@Test
fun testPlus() {
// two variables
val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) {
val x = bind(x)
val y = bind(y)
x + y
}
assertEquals(5.0, z.value) // z = x + y = 5
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
assertEquals(1.0, z.derivative(y)) // dz/dy = 1
}
@Test
fun testMinus() {
// two variables
val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) {
val x = bind(x)
val y = bind(y)
x - y
}
assertEquals(4.0, z.value) // z = x - y = 4
assertEquals(1.0, z.derivative(x)) // dz/dx = 1
assertEquals(-1.0, z.derivative(y)) // dz/dy = -1
}
@Test
fun testMulX2() {
val y = dx(x to 3.0) { x ->
// diff w.r.t this x at 3
x * x
}
assertEquals(9.0, y.value) // y = x * x = 9
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqr() {
val y = dx(x to 3.0) { x -> sqr(x) }
assertEquals(9.0, y.value) // y = x ^ 2 = 9
assertEquals(6.0, y.derivative(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqrSqr() {
val y = dx(x to 2.0) { x -> sqr(sqr(x)) }
assertEquals(16.0, y.value) // y = x ^ 4 = 16
assertEquals(32.0, y.derivative(x)) // dy/dx = 4 * x^3 = 32
}
@Test
fun testX3() {
val y = dx(x to 2.0) { x ->
// diff w.r.t this x at 2
x * x * x
}
assertEquals(8.0, y.value) // y = x * x * x = 8
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x * x = 12
}
@Test
fun testDiv() {
val z = dxy(x to 5.0, y to 2.0) { x, y ->
x / y
}
assertEquals(2.5, z.value) // z = x / y = 2.5
assertEquals(0.5, z.derivative(x)) // dz/dx = 1 / y = 0.5
assertEquals(-1.25, z.derivative(y)) // dz/dy = -x / y^2 = -1.25
}
@Test
fun testPow3() {
val y = dx(x to 2.0) { x ->
// diff w.r.t this x at 2
pow(x, 3)
}
assertEquals(8.0, y.value) // y = x ^ 3 = 8
assertEquals(12.0, y.derivative(x)) // dy/dx = 3 * x ^ 2 = 12
}
@Test
fun testPowFull() {
val z = dxy(x to 2.0, y to 3.0) { x, y ->
pow(x, y)
}
assertApprox(8.0, z.value) // z = x ^ y = 8
assertApprox(12.0, z.derivative(x)) // dz/dx = y * x ^ (y - 1) = 12
assertApprox(8.0 * kotlin.math.ln(2.0), z.derivative(y)) // dz/dy = x ^ y * ln(x)
}
@Test
fun testFromPaper() {
val y = dx(x to 3.0) { x -> 2 * x + x * x * x }
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
assertEquals(29.0, y.derivative(x)) // dy/dx = 2 + 3 * x * x = 29
}
@Test
fun testInnerVariable() {
val y = dx(x to 1.0) { x ->
const(1.0) * x
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(1.0, y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testLongChain() {
val n = 10_000
val y = dx(x to 1.0) { x ->
var res = const(1.0)
for (i in 1..n) res *= x
res
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(n.toDouble(), y.derivative(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testExample() {
val y = dx(x to 2.0) { x -> sqr(x) + 5 * x + 3 }
assertEquals(17.0, y.value) // the value of result (y)
assertEquals(9.0, y.derivative(x)) // dy/dx
}
@Test
fun testSqrt() {
val y = dx(x to 16.0) { x -> sqrt(x) }
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
assertEquals(1.0 / 8, y.derivative(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
}
@Test
fun testSin() {
val y = dx(x to PI / 6.0) { x -> sin(x) }
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
assertApprox(sqrt(3.0) / 2, y.derivative(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
}
@Test
fun testCos() {
val y = dx(x to PI / 6) { x -> cos(x) }
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
assertApprox(-0.5, y.derivative(x)) // dy/dx = -sin(pi/6) = -0.5
}
@Test
fun testTan() {
val y = dx(x to PI / 6) { x -> tan(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
assertApprox(4.0 / 3.0, y.derivative(x)) // dy/dx = sec(pi/6)^2 = 4/3
}
@Test
fun testAsin() {
val y = dx(x to PI / 6) { x -> asin(x) }
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
assertApprox(6.0 / sqrt(36 - PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(36-pi^2)
}
@Test
fun testAcos() {
val y = dx(x to PI / 6) { x -> acos(x) }
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtan() {
val y = dx(x to PI / 6) { x -> atan(x) }
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
assertApprox(36.0 / (36.0 + PI * PI), y.derivative(x)) // dy/dx = 36/(36+pi^2)
}
@Test
fun testSinh() {
val y = dx(x to 0.0) { x -> sinh(x) }
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
assertApprox(kotlin.math.cosh(0.0), y.derivative(x)) // dy/dx = cosh(0)
}
@Test
fun testCosh() {
val y = dx(x to 0.0) { x -> cosh(x) }
assertApprox(1.0, y.value) //y = cosh(0)
assertApprox(0.0, y.derivative(x)) // dy/dx = sinh(0)
}
@Test
fun testTanh() {
val y = dx(x to 1.0) { x -> tanh(x) }
assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6)
assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
}
@Test
fun testAsinh() {
val y = dx(x to PI / 6) { x -> asinh(x) }
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
assertApprox(6.0 / sqrt(36 + PI * PI), y.derivative(x)) // dy/dx = 6/sqrt(pi^2+36)
}
@Test
fun testAcosh() {
val y = dx(x to PI / 6) { x -> acosh(x) }
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.derivative(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtanh() {
val y = dx(x to PI / 6) { x -> atanh(x) }
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
assertApprox(-36.0 / (PI * PI - 36.0), y.derivative(x)) // dy/dx = -36/(pi^2-36)
}
@Test
fun testDivGrad() {
val res = dxy(x to 1.0, y to 2.0) { x, y -> x * x + y * y }
assertEquals(6.0, res.div())
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
}
private fun assertApprox(a: Double, b: Double) {
if ((a - b) > 1e-10) assertEquals(a, b)
}
}

View File

@ -1,261 +0,0 @@
package kscience.kmath.misc
import kscience.kmath.operations.RealField
import kscience.kmath.structures.asBuffer
import kotlin.math.PI
import kotlin.math.pow
import kotlin.math.sqrt
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class AutoDiffTest {
inline fun deriv(body: AutoDiffField<Double, RealField>.() -> Variable<Double>): DerivationResult<Double> =
RealField.deriv(body)
@Test
fun testPlusX2() {
val x = Variable(3.0) // diff w.r.t this x at 3
val y = deriv { x + x }
assertEquals(6.0, y.value) // y = x + x = 6
assertEquals(2.0, y.deriv(x)) // dy/dx = 2
}
@Test
fun testPlus() {
// two variables
val x = Variable(2.0)
val y = Variable(3.0)
val z = deriv { x + y }
assertEquals(5.0, z.value) // z = x + y = 5
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
assertEquals(1.0, z.deriv(y)) // dz/dy = 1
}
@Test
fun testMinus() {
// two variables
val x = Variable(7.0)
val y = Variable(3.0)
val z = deriv { x - y }
assertEquals(4.0, z.value) // z = x - y = 4
assertEquals(1.0, z.deriv(x)) // dz/dx = 1
assertEquals(-1.0, z.deriv(y)) // dz/dy = -1
}
@Test
fun testMulX2() {
val x = Variable(3.0) // diff w.r.t this x at 3
val y = deriv { x * x }
assertEquals(9.0, y.value) // y = x * x = 9
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqr() {
val x = Variable(3.0)
val y = deriv { sqr(x) }
assertEquals(9.0, y.value) // y = x ^ 2 = 9
assertEquals(6.0, y.deriv(x)) // dy/dx = 2 * x = 7
}
@Test
fun testSqrSqr() {
val x = Variable(2.0)
val y = deriv { sqr(sqr(x)) }
assertEquals(16.0, y.value) // y = x ^ 4 = 16
assertEquals(32.0, y.deriv(x)) // dy/dx = 4 * x^3 = 32
}
@Test
fun testX3() {
val x = Variable(2.0) // diff w.r.t this x at 2
val y = deriv { x * x * x }
assertEquals(8.0, y.value) // y = x * x * x = 8
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x * x = 12
}
@Test
fun testDiv() {
val x = Variable(5.0)
val y = Variable(2.0)
val z = deriv { x / y }
assertEquals(2.5, z.value) // z = x / y = 2.5
assertEquals(0.5, z.deriv(x)) // dz/dx = 1 / y = 0.5
assertEquals(-1.25, z.deriv(y)) // dz/dy = -x / y^2 = -1.25
}
@Test
fun testPow3() {
val x = Variable(2.0) // diff w.r.t this x at 2
val y = deriv { pow(x, 3) }
assertEquals(8.0, y.value) // y = x ^ 3 = 8
assertEquals(12.0, y.deriv(x)) // dy/dx = 3 * x ^ 2 = 12
}
@Test
fun testPowFull() {
val x = Variable(2.0)
val y = Variable(3.0)
val z = deriv { pow(x, y) }
assertApprox(8.0, z.value) // z = x ^ y = 8
assertApprox(12.0, z.deriv(x)) // dz/dx = y * x ^ (y - 1) = 12
assertApprox(8.0 * kotlin.math.ln(2.0), z.deriv(y)) // dz/dy = x ^ y * ln(x)
}
@Test
fun testFromPaper() {
val x = Variable(3.0)
val y = deriv { 2 * x + x * x * x }
assertEquals(33.0, y.value) // y = 2 * x + x * x * x = 33
assertEquals(29.0, y.deriv(x)) // dy/dx = 2 + 3 * x * x = 29
}
@Test
fun testInnerVariable() {
val x = Variable(1.0)
val y = deriv {
Variable(1.0) * x
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(1.0, y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testLongChain() {
val n = 10_000
val x = Variable(1.0)
val y = deriv {
var res = Variable(1.0)
for (i in 1..n) res *= x
res
}
assertEquals(1.0, y.value) // y = x ^ n = 1
assertEquals(n.toDouble(), y.deriv(x)) // dy/dx = n * x ^ (n - 1) = n - 1
}
@Test
fun testExample() {
val x = Variable(2.0)
val y = deriv { sqr(x) + 5 * x + 3 }
assertEquals(17.0, y.value) // the value of result (y)
assertEquals(9.0, y.deriv(x)) // dy/dx
}
@Test
fun testSqrt() {
val x = Variable(16.0)
val y = deriv { sqrt(x) }
assertEquals(4.0, y.value) // y = x ^ 1/2 = 4
assertEquals(1.0 / 8, y.deriv(x)) // dy/dx = 1/2 / x ^ 1/4 = 1/8
}
@Test
fun testSin() {
val x = Variable(PI / 6.0)
val y = deriv { sin(x) }
assertApprox(0.5, y.value) // y = sin(PI/6) = 0.5
assertApprox(sqrt(3.0) / 2, y.deriv(x)) // dy/dx = cos(pi/6) = sqrt(3)/2
}
@Test
fun testCos() {
val x = Variable(PI / 6)
val y = deriv { cos(x) }
assertApprox(sqrt(3.0) / 2, y.value) //y = cos(pi/6) = sqrt(3)/2
assertApprox(-0.5, y.deriv(x)) // dy/dx = -sin(pi/6) = -0.5
}
@Test
fun testTan() {
val x = Variable(PI / 6)
val y = deriv { tan(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tan(pi/6) = 1/sqrt(3)
assertApprox(4.0 / 3.0, y.deriv(x)) // dy/dx = sec(pi/6)^2 = 4/3
}
@Test
fun testAsin() {
val x = Variable(PI / 6)
val y = deriv { asin(x) }
assertApprox(kotlin.math.asin(PI / 6.0), y.value) // y = asin(pi/6)
assertApprox(6.0 / sqrt(36 - PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(36-pi^2)
}
@Test
fun testAcos() {
val x = Variable(PI / 6)
val y = deriv { acos(x) }
assertApprox(kotlin.math.acos(PI / 6.0), y.value) // y = acos(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtan() {
val x = Variable(PI / 6)
val y = deriv { atan(x) }
assertApprox(kotlin.math.atan(PI / 6.0), y.value) // y = atan(pi/6)
assertApprox(36.0 / (36.0 + PI * PI), y.deriv(x)) // dy/dx = 36/(36+pi^2)
}
@Test
fun testSinh() {
val x = Variable(0.0)
val y = deriv { sinh(x) }
assertApprox(kotlin.math.sinh(0.0), y.value) // y = sinh(0)
assertApprox(kotlin.math.cosh(0.0), y.deriv(x)) // dy/dx = cosh(0)
}
@Test
fun testCosh() {
val x = Variable(0.0)
val y = deriv { cosh(x) }
assertApprox(1.0, y.value) //y = cosh(0)
assertApprox(0.0, y.deriv(x)) // dy/dx = sinh(0)
}
@Test
fun testTanh() {
val x = Variable(PI / 6)
val y = deriv { tanh(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6)
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.deriv(x)) // dy/dx = sech(pi/6)^2
}
@Test
fun testAsinh() {
val x = Variable(PI / 6)
val y = deriv { asinh(x) }
assertApprox(kotlin.math.asinh(PI / 6.0), y.value) // y = asinh(pi/6)
assertApprox(6.0 / sqrt(36 + PI * PI), y.deriv(x)) // dy/dx = 6/sqrt(pi^2+36)
}
@Test
fun testAcosh() {
val x = Variable(PI / 6)
val y = deriv { acosh(x) }
assertApprox(kotlin.math.acosh(PI / 6.0), y.value) // y = acosh(pi/6)
assertApprox(-6.0 / sqrt(36.0 - PI * PI), y.deriv(x)) // dy/dx = -6/sqrt(36-pi^2)
}
@Test
fun testAtanh() {
val x = Variable(PI / 6.0)
val y = deriv { atanh(x) }
assertApprox(kotlin.math.atanh(PI / 6.0), y.value) // y = atanh(pi/6)
assertApprox(-36.0 / (PI * PI - 36.0), y.deriv(x)) // dy/dx = -36/(pi^2-36)
}
@Test
fun testDivGrad() {
val x = Variable(1.0)
val y = Variable(2.0)
val res = deriv { x * x + y * y }
assertEquals(6.0, res.div())
assertTrue(res.grad(x, y).contentEquals(doubleArrayOf(2.0, 4.0).asBuffer()))
}
private fun assertApprox(a: Double, b: Double) {
if ((a - b) > 1e-10) assertEquals(a, b)
}
}

View File

@ -8,13 +8,13 @@ import kotlin.contracts.contract
import kotlin.math.max import kotlin.math.max
import kotlin.math.pow import kotlin.math.pow
// TODO make `inline`, when KT-41771 gets fixed
/** /**
* Polynomial coefficients without fixation on specific context they are applied to * Polynomial coefficients without fixation on specific context they are applied to
* @param coefficients constant is the leftmost coefficient * @param coefficients constant is the leftmost coefficient
*/ */
public inline class Polynomial<T : Any>(public val coefficients: List<T>) public inline class Polynomial<T : Any>(public val coefficients: List<T>)
@Suppress("FunctionName")
public fun <T : Any> Polynomial(vararg coefficients: T): Polynomial<T> = Polynomial(coefficients.toList()) public fun <T : Any> Polynomial(vararg coefficients: T): Polynomial<T> = Polynomial(coefficients.toList())
public fun Polynomial<Double>.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) } public fun Polynomial<Double>.value(): Double = coefficients.reduceIndexed { index, acc, d -> acc + d.pow(index) }
@ -33,14 +33,6 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring
res res
} }
/**
* Represent a polynomial as a context-dependent function
*/
public fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, C, T> =
object : MathFunction<T, C, T> {
override fun C.invoke(arg: T): T = value(this, arg)
}
/** /**
* Represent the polynomial as a regular context-less function * Represent the polynomial as a regular context-less function
*/ */
@ -49,7 +41,7 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.asFunction(ring: C): (T) -> T =
/** /**
* An algebra for polynomials * An algebra for polynomials
*/ */
public class PolynomialSpace<T : Any, C : Ring<T>>(public val ring: C) : Space<Polynomial<T>> { public class PolynomialSpace<T : Any, C : Ring<T>>(private val ring: C) : Space<Polynomial<T>> {
public override val zero: Polynomial<T> = Polynomial(emptyList()) public override val zero: Polynomial<T> = Polynomial(emptyList())
public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> { public override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {

View File

@ -1,34 +0,0 @@
package kscience.kmath.functions
import kscience.kmath.operations.Algebra
import kscience.kmath.operations.RealField
// TODO make fun interface when KT-41770 is fixed
/**
* A regular function that could be called only inside specific algebra context
* @param T source type
* @param C source algebra constraint
* @param R result type
*/
public /*fun*/ interface MathFunction<T, C : Algebra<T>, R> {
public operator fun C.invoke(arg: T): R
}
public fun <R> MathFunction<Double, RealField, R>.invoke(arg: Double): R = RealField.invoke(arg)
/**
* A suspendable function defined in algebraic context
*/
// TODO make fun interface, when the new JVM IR is enabled
public interface SuspendableMathFunction<T, C : Algebra<T>, R> {
public suspend operator fun C.invoke(arg: T): R
}
public suspend fun <R> SuspendableMathFunction<Double, RealField, R>.invoke(arg: Double): R = RealField.invoke(arg)
/**
* A parametric function with parameter
*/
public fun interface ParametricFunction<T, P, C : Algebra<T>> {
public operator fun C.invoke(arg: T, parameter: P): T
}

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain
import kscience.kmath.chains.collect import kscience.kmath.chains.collect

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain
import kscience.kmath.chains.SimpleChain import kscience.kmath.chains.SimpleChain

View File

@ -0,0 +1,59 @@
package kscience.kmath.stat
import kscience.kmath.expressions.*
import kscience.kmath.operations.ExtendedField
import kscience.kmath.structures.Buffer
import kscience.kmath.structures.indices
import kotlin.math.pow
public object Fitting {
/**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
*/
public fun <T : Any, I : Any, A> chiSquared(
autoDiff: AutoDiffProcessor<T, I, A>,
x: Buffer<T>,
y: Buffer<T>,
yErr: Buffer<T>,
model: A.(I) -> I,
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
require(x.size == y.size) { "X and y buffers should be of the same size" }
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
return autoDiff.process {
var sum = zero
x.indices.forEach {
val xValue = const(x[it])
val yValue = const(y[it])
val yErrValue = const(yErr[it])
val modelValue = model(xValue)
sum += ((yValue - modelValue) / yErrValue).pow(2)
}
sum
}
}
/**
* Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives
*/
public fun chiSquared(
x: Buffer<Double>,
y: Buffer<Double>,
yErr: Buffer<Double>,
model: Expression<Double>,
xSymbol: Symbol = StringSymbol("x"),
): Expression<Double> {
require(x.size == y.size) { "X and y buffers should be of the same size" }
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
return Expression { arguments ->
x.indices.sumByDouble {
val xValue = x[it]
val yValue = y[it]
val yErrValue = yErr[it]
val modifiedArgs = arguments + (xSymbol to xValue)
val modelValue = model(modifiedArgs)
((yValue - modelValue) / yErrValue).pow(2)
}
}
}
}

View File

@ -0,0 +1,91 @@
package kscience.kmath.stat
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol
public interface OptimizationFeature
public class OptimizationResult<T>(
public val point: Map<Symbol, T>,
public val value: T,
public val features: Set<OptimizationFeature> = emptySet(),
) {
override fun toString(): String {
return "OptimizationResult(point=$point, value=$value)"
}
}
public operator fun <T> OptimizationResult<T>.plus(
feature: OptimizationFeature,
): OptimizationResult<T> = OptimizationResult(point, value, features + feature)
/**
* A configuration builder for optimization problem
*/
public interface OptimizationProblem<T : Any> {
/**
* Define the initial guess for the optimization problem
*/
public fun initialGuess(map: Map<Symbol, T>): Unit
/**
* Set an objective function expression
*/
public fun expression(expression: Expression<T>): Unit
/**
* Set a differentiable expression as objective function as function and gradient provider
*/
public fun diffExpression(expression: DifferentiableExpression<T>): Unit
/**
* Update the problem from previous optimization run
*/
public fun update(result: OptimizationResult<T>)
/**
* Make an optimization run
*/
public fun optimize(): OptimizationResult<T>
}
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
public fun build(symbols: List<Symbol>): P
}
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
symbols: List<Symbol>,
block: P.() -> Unit,
): P = build(symbols).apply(block)
/**
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
*/
public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
factory: OptimizationProblemFactory<T, F>,
vararg symbols: Symbol,
configuration: F.() -> Unit,
): OptimizationResult<T> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = factory(symbols.toList(),configuration)
problem.expression(this)
return problem.optimize()
}
/**
* Optimize differentiable expression using specific [OptimizationProblemFactory]
*/
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith(
factory: OptimizationProblemFactory<T, F>,
vararg symbols: Symbol,
configuration: F.() -> Unit,
): OptimizationResult<T> {
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = factory(symbols.toList(), configuration)
problem.diffExpression(this)
return problem.optimize()
}

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kotlin.random.Random import kotlin.random.Random

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain
import kscience.kmath.chains.ConstantChain import kscience.kmath.chains.ConstantChain

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kscience.kmath.chains.Chain import kscience.kmath.chains.Chain
import kscience.kmath.chains.SimpleChain import kscience.kmath.chains.SimpleChain

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import org.apache.commons.rng.UniformRandomProvider import org.apache.commons.rng.UniformRandomProvider
import org.apache.commons.rng.simple.RandomSource import org.apache.commons.rng.simple.RandomSource

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kscience.kmath.chains.BlockingIntChain import kscience.kmath.chains.BlockingIntChain
import kscience.kmath.chains.BlockingRealChain import kscience.kmath.chains.BlockingRealChain

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kotlinx.coroutines.flow.take import kotlinx.coroutines.flow.take
import kotlinx.coroutines.flow.toList import kotlinx.coroutines.flow.toList

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kotlin.test.Test import kotlin.test.Test

View File

@ -1,4 +1,4 @@
package kscience.kmath.prob package kscience.kmath.stat
import kotlinx.coroutines.flow.drop import kotlinx.coroutines.flow.drop
import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.first

View File

@ -10,8 +10,8 @@ pluginManagement {
maven("https://dl.bintray.com/kotlin/kotlin-dev/") maven("https://dl.bintray.com/kotlin/kotlin-dev/")
} }
val toolsVersion = "0.6.1-dev-1.4.20-M1" val toolsVersion = "0.6.4-dev-1.4.20-M2"
val kotlinVersion = "1.4.20-M1" val kotlinVersion = "1.4.20-M2"
plugins { plugins {
id("kotlinx.benchmark") version "0.2.0-dev-20" id("kotlinx.benchmark") version "0.2.0-dev-20"
@ -34,11 +34,11 @@ include(
":kmath-histograms", ":kmath-histograms",
":kmath-commons", ":kmath-commons",
":kmath-viktor", ":kmath-viktor",
":kmath-prob", ":kmath-stat",
":kmath-dimensions", ":kmath-dimensions",
":kmath-for-real", ":kmath-for-real",
":kmath-geometry", ":kmath-geometry",
":kmath-ast", ":kmath-ast",
":examples", ":kmath-ejml",
":kmath-ejml" ":examples"
) )