diff --git a/.space.kts b/.space.kts
index 9dda0cbf7..d70ad6d59 100644
--- a/.space.kts
+++ b/.space.kts
@@ -1 +1,3 @@
-job("Build") { gradlew("openjdk:11", "build") }
+job("Build") {
+ gradlew("openjdk:11", "build")
+}
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 214730ecc..0652156b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,11 @@
- Automatic README generation for features (#139)
- Native support for `memory`, `core` and `dimensions`
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
+- A separate `Symbol` entity, which is used for global unbound symbol.
+- A `Symbol` indexing scope.
+- Basic optimization API for Commons-math.
+- Chi squared optimization for array-like data in CM
+- `Fitting` utility object in prob/stat
### Changed
- Package changed from `scientifik` to `kscience.kmath`.
@@ -16,6 +21,8 @@
- `Polynomial` secondary constructor made function.
- Kotlin version: 1.3.72 -> 1.4.20-M1
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
+- Full autodiff refactoring based on `Symbol`
+- `kmath-prob` renamed to `kmath-stat`
### Deprecated
diff --git a/README.md b/README.md
index eadc3036b..afab32dcf 100644
--- a/README.md
+++ b/README.md
@@ -53,9 +53,7 @@ can be used for a wide variety of purposes from high performance calculations to
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
to submit a feature request if you want something to be done first.
-
-* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
-
+
## Planned features
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
@@ -101,7 +99,7 @@ can be used for a wide variety of purposes from high performance calculations to
> - [buffers](kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
> - [expressions](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
> - [domains](kmath-core/src/commonMain/kotlin/kscience/kmath/domains) : Domains
-> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation
+> - [autodif](kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
@@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
> **Maturity**: EXPERIMENTAL
+* ### [kmath-ejml](kmath-ejml)
+>
+>
+> **Maturity**: EXPERIMENTAL
+
+
* ### [kmath-for-real](kmath-for-real)
>
>
@@ -147,7 +151,7 @@ can be used for a wide variety of purposes from high performance calculations to
> **Maturity**: EXPERIMENTAL
-* ### [kmath-prob](kmath-prob)
+* ### [kmath-stat](kmath-stat)
>
>
> **Maturity**: EXPERIMENTAL
@@ -178,8 +182,8 @@ repositories{
}
dependencies{
- api("kscience.kmath:kmath-core:${kmathVersion}")
- //api("scientifik:kmath-core:${kmathVersion}") for 0.1.3 and earlier
+ api("kscience.kmath:kmath-core:0.2.0-dev-2")
+ //api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
}
```
@@ -197,4 +201,4 @@ with the same artifact names.
## Contributing
-The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
+The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
\ No newline at end of file
diff --git a/build.gradle.kts b/build.gradle.kts
index 05e2d5979..acb9f3b68 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -2,7 +2,7 @@ plugins {
id("ru.mipt.npm.project")
}
-val kmathVersion: String by extra("0.2.0-dev-2")
+val kmathVersion: String by extra("0.2.0-dev-3")
val bintrayRepo: String by extra("kscience")
val githubProject: String by extra("kmath")
@@ -25,3 +25,7 @@ subprojects {
readme {
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
}
+
+apiValidation{
+ validationDisabled = true
+}
\ No newline at end of file
diff --git a/docs/templates/README-TEMPLATE.md b/docs/templates/README-TEMPLATE.md
index f451adb24..5117e0694 100644
--- a/docs/templates/README-TEMPLATE.md
+++ b/docs/templates/README-TEMPLATE.md
@@ -107,4 +107,4 @@ with the same artifact names.
## Contributing
-The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
+The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
\ No newline at end of file
diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts
index 900da966b..9ba1ec5be 100644
--- a/examples/build.gradle.kts
+++ b/examples/build.gradle.kts
@@ -23,7 +23,7 @@ dependencies {
implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons"))
- implementation(project(":kmath-prob"))
+ implementation(project(":kmath-stat"))
implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions"))
implementation(project(":kmath-ejml"))
diff --git a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt
index 9c0a01961..ef554aeff 100644
--- a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt
+++ b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionBenchmark.kt
@@ -4,7 +4,7 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import kscience.kmath.chains.BlockingRealChain
-import kscience.kmath.prob.*
+import kscience.kmath.stat.*
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
import org.apache.commons.rng.simple.RandomSource
import java.time.Duration
diff --git a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt
index 7d53e5178..6146e17af 100644
--- a/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt
+++ b/examples/src/main/kotlin/kscience/kmath/commons/prob/DistributionDemo.kt
@@ -3,9 +3,9 @@ package kscience.kmath.commons.prob
import kotlinx.coroutines.runBlocking
import kscience.kmath.chains.Chain
import kscience.kmath.chains.collectWithState
-import kscience.kmath.prob.Distribution
-import kscience.kmath.prob.RandomGenerator
-import kscience.kmath.prob.normal
+import kscience.kmath.stat.Distribution
+import kscience.kmath.stat.RandomGenerator
+import kscience.kmath.stat.normal
private data class AveragingChainState(var num: Int = 0, var value: Double = 0.0)
diff --git a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt
index 34b3c9981..e84fd8df3 100644
--- a/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt
+++ b/examples/src/main/kotlin/kscience/kmath/operations/ComplexDemo.kt
@@ -6,8 +6,8 @@ import kscience.kmath.structures.complex
fun main() {
// 2d element
- val element = NDElement.complex(2, 2) { index: IntArray ->
- Complex(index[0].toDouble() - index[1].toDouble(), index[0].toDouble() + index[1].toDouble())
+ val element = NDElement.complex(2, 2) { (i,j) ->
+ Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble())
}
println(element)
diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt
index 483bc530c..5ca75e993 100644
--- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt
+++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt
@@ -14,8 +14,8 @@ import kotlin.contracts.contract
* @author Alexander Nozik
*/
public class MstExpression(public val algebra: Algebra, public val mst: MST) : Expression {
- private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra {
- override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
+ private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra {
+ override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
@@ -27,7 +27,7 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS
error("Numeric nodes are not supported by $this")
}
- override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst)
+ override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst)
}
/**
@@ -37,7 +37,7 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS
*/
public inline fun , E : Algebra> A.mst(
mstAlgebra: E,
- block: E.() -> MST
+ block: E.() -> MST,
): MstExpression = MstExpression(this, mstAlgebra.block())
/**
@@ -116,7 +116,7 @@ public inline fun > FunctionalExpressionField> FunctionalExpressionExtendedField.mstInExtendedField(
- block: MstExtendedField.() -> MST
+ block: MstExtendedField.() -> MST,
): MstExpression {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block)
diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt
index 06f02a94d..a1e482103 100644
--- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt
+++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt
@@ -25,7 +25,7 @@ internal class AsmBuilder internal constructor(
private val classOfT: Class<*>,
private val algebra: Algebra,
private val className: String,
- private val invokeLabel0Visitor: AsmBuilder.() -> Unit
+ private val invokeLabel0Visitor: AsmBuilder.() -> Unit,
) {
/**
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
@@ -379,22 +379,14 @@ internal class AsmBuilder internal constructor(
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be
* provided.
*/
- internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
+ internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
load(invokeArgumentsVar, MAP_TYPE)
aconst(name)
- if (defaultValue != null)
- loadTConstant(defaultValue)
-
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
"getOrFail",
-
- Type.getMethodDescriptor(
- OBJECT_TYPE,
- MAP_TYPE,
- OBJECT_TYPE,
- *OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
+ Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
false
)
@@ -429,7 +421,7 @@ internal class AsmBuilder internal constructor(
method: String,
descriptor: String,
expectedArity: Int,
- opcode: Int = INVOKEINTERFACE
+ opcode: Int = INVOKEINTERFACE,
) {
run loop@{
repeat(expectedArity) {
diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt
index 708b3c2b4..588b9611a 100644
--- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt
+++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt
@@ -2,11 +2,12 @@
package kscience.kmath.asm.internal
+import kscience.kmath.expressions.StringSymbol
+import kscience.kmath.expressions.Symbol
+
/**
- * Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
+ * Gets value with given [key] or throws [NoSuchElementException] whenever it is not present.
*
* @author Iaroslav Postovalov
*/
-@JvmOverloads
-internal fun Map.getOrFail(key: K, default: V? = null): V =
- this[key] ?: default ?: error("Parameter not found: $key")
+internal fun Map.getOrFail(key: String): V = getValue(StringSymbol(key))
diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt
index 0cf1307d1..5eebfe43d 100644
--- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt
+++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt
@@ -1,6 +1,5 @@
package kscience.kmath.asm
-import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInField
import kscience.kmath.ast.mstInRing
import kscience.kmath.ast.mstInSpace
@@ -11,6 +10,7 @@ import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmAlgebras {
+
@Test
fun space() {
val res1 = ByteRing.mstInSpace {
diff --git a/kmath-commons/build.gradle.kts b/kmath-commons/build.gradle.kts
index ed6452ad8..6a44c92f2 100644
--- a/kmath-commons/build.gradle.kts
+++ b/kmath-commons/build.gradle.kts
@@ -6,7 +6,7 @@ description = "Commons math binding for kmath"
dependencies {
api(project(":kmath-core"))
api(project(":kmath-coroutines"))
- api(project(":kmath-prob"))
-// api(project(":kmath-functions"))
+ api(project(":kmath-stat"))
+ api(project(":kmath-functions"))
api("org.apache.commons:commons-math3:3.6.1")
}
diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt
new file mode 100644
index 000000000..c593f5103
--- /dev/null
+++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt
@@ -0,0 +1,115 @@
+package kscience.kmath.commons.expressions
+
+import kscience.kmath.expressions.*
+import kscience.kmath.operations.ExtendedField
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
+
+/**
+ * A field over commons-math [DerivativeStructure].
+ *
+ * @property order The derivation order.
+ * @property bindings The map of bindings values. All bindings are considered free parameters
+ */
+public class DerivativeStructureField(
+ public val order: Int,
+ private val bindings: Map
+) : ExtendedField, ExpressionAlgebra {
+ public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
+ public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
+
+ /**
+ * A class that implements both [DerivativeStructure] and a [Symbol]
+ */
+ public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
+ DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
+ override val identity: String = symbol.identity
+ override fun toString(): String = identity
+ override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
+ override fun hashCode(): Int = identity.hashCode()
+ }
+
+ /**
+ * Identity-based symbol bindings map
+ */
+ private val variables: Map = bindings.entries.associate { (key, value) ->
+ key.identity to DerivativeStructureSymbol(key, value)
+ }
+
+ override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value)
+
+ public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
+
+ public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
+
+ //public fun Number.const(): DerivativeStructure = const(toDouble())
+
+ public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
+ return derivative(mapOf(parameter to order))
+ }
+
+ public fun DerivativeStructure.derivative(orders: Map): Double {
+ return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
+ }
+
+ public fun DerivativeStructure.derivative(vararg orders: Pair): Double = derivative(mapOf(*orders))
+ public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
+
+ public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
+ is Double -> a.multiply(k)
+ is Int -> a.multiply(k)
+ else -> a.multiply(k.toDouble())
+ }
+
+ public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
+ public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
+ public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
+ public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
+ public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
+ public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
+ public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
+ public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
+ public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
+ public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
+ public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
+ public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
+ public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
+ public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
+
+ public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
+ is Double -> arg.pow(pow)
+ is Int -> arg.pow(pow)
+ else -> arg.pow(pow.toDouble())
+ }
+
+ public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
+ public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
+ public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
+
+ public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
+ public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
+ public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
+ public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
+
+ public companion object : AutoDiffProcessor {
+ override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression {
+ return DerivativeStructureExpression(function)
+ }
+ }
+}
+
+/**
+ * A constructs that creates a derivative structure with required order on-demand
+ */
+public class DerivativeStructureExpression(
+ public val function: DerivativeStructureField.() -> DerivativeStructure,
+) : DifferentiableExpression {
+ public override operator fun invoke(arguments: Map): Double =
+ DerivativeStructureField(0, arguments).function().value
+
+ /**
+ * Get the derivative expression with given orders
+ */
+ public override fun derivativeOrNull(orders: Map): Expression = Expression { arguments ->
+ with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
+ }
+}
diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt
deleted file mode 100644
index c39f0d04c..000000000
--- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt
+++ /dev/null
@@ -1,131 +0,0 @@
-package kscience.kmath.commons.expressions
-
-import kscience.kmath.expressions.Expression
-import kscience.kmath.expressions.ExpressionAlgebra
-import kscience.kmath.operations.ExtendedField
-import kscience.kmath.operations.Field
-import kscience.kmath.operations.invoke
-import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
-import kotlin.properties.ReadOnlyProperty
-
-/**
- * A field over commons-math [DerivativeStructure].
- *
- * @property order The derivation order.
- * @property parameters The map of free parameters.
- */
-public class DerivativeStructureField(
- public val order: Int,
- public val parameters: Map
-) : ExtendedField {
- public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
- public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
-
- private val variables: Map = parameters.mapValues { (key, value) ->
- DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
- }
-
- public val variable: ReadOnlyProperty = ReadOnlyProperty { _, property ->
- variables[property.name] ?: error("A variable with name ${property.name} does not exist")
- }
-
- public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
- variables[name] ?: default ?: error("A variable with name $name does not exist")
-
- public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
-
- public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
- return deriv(mapOf(parName to order))
- }
-
- public fun DerivativeStructure.deriv(orders: Map): Double {
- return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
- }
-
- public fun DerivativeStructure.deriv(vararg orders: Pair): Double = deriv(mapOf(*orders))
- public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
-
- public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
- is Double -> a.multiply(k)
- is Int -> a.multiply(k)
- else -> a.multiply(k.toDouble())
- }
-
- public override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
- public override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
- public override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
- public override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
- public override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
- public override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
- public override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
- public override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
- public override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
- public override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
- public override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
- public override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
- public override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
- public override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
-
- public override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
- is Double -> arg.pow(pow)
- is Int -> arg.pow(pow)
- else -> arg.pow(pow.toDouble())
- }
-
- public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
- public override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
- public override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
-
- public override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
- public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
- public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
- public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
-}
-
-/**
- * A constructs that creates a derivative structure with required order on-demand
- */
-public class DiffExpression(public val function: DerivativeStructureField.() -> DerivativeStructure) :
- Expression {
- public override operator fun invoke(arguments: Map): Double = DerivativeStructureField(
- 0,
- arguments
- ).function().value
-
- /**
- * Get the derivative expression with given orders
- * TODO make result [DiffExpression]
- */
- public fun derivative(orders: Map): Expression = Expression { arguments ->
- (DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
- }
-
- //TODO add gradient and maybe other vector operators
-}
-
-public fun DiffExpression.derivative(vararg orders: Pair): Expression = derivative(mapOf(*orders))
-public fun DiffExpression.derivative(name: String): Expression = derivative(name to 1)
-
-/**
- * A context for [DiffExpression] (not to be confused with [DerivativeStructure])
- */
-public object DiffExpressionAlgebra : ExpressionAlgebra, Field {
- public override val zero: DiffExpression = DiffExpression { 0.0.const() }
- public override val one: DiffExpression = DiffExpression { 1.0.const() }
-
- public override fun variable(name: String, default: Double?): DiffExpression =
- DiffExpression { variable(name, default?.const()) }
-
- public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
-
- public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
- DiffExpression { a.function(this) + b.function(this) }
-
- public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
-
- public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
- DiffExpression { a.function(this) * b.function(this) }
-
- public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
- DiffExpression { a.function(this) / b.function(this) }
-}
diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt
new file mode 100644
index 000000000..13f9af7bb
--- /dev/null
+++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/CMOptimizationProblem.kt
@@ -0,0 +1,111 @@
+package kscience.kmath.commons.optimization
+
+import kscience.kmath.expressions.*
+import kscience.kmath.stat.OptimizationFeature
+import kscience.kmath.stat.OptimizationProblem
+import kscience.kmath.stat.OptimizationProblemFactory
+import kscience.kmath.stat.OptimizationResult
+import org.apache.commons.math3.optim.*
+import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
+import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer
+import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction
+import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
+import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
+import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.AbstractSimplex
+import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
+import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
+import kotlin.reflect.KClass
+
+public operator fun PointValuePair.component1(): DoubleArray = point
+public operator fun PointValuePair.component2(): Double = value
+
+public class CMOptimizationProblem(
+ override val symbols: List,
+) : OptimizationProblem, SymbolIndexer, OptimizationFeature {
+ private val optimizationData: HashMap, OptimizationData> = HashMap()
+ private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
+ public var convergenceChecker: ConvergenceChecker = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
+ DEFAULT_ABSOLUTE_TOLERANCE, DEFAULT_MAX_ITER)
+
+ public fun addOptimizationData(data: OptimizationData) {
+ optimizationData[data::class] = data
+ }
+
+ init {
+ addOptimizationData(MaxEval.unlimited())
+ }
+
+ public fun exportOptimizationData(): List = optimizationData.values.toList()
+
+ public override fun initialGuess(map: Map): Unit {
+ addOptimizationData(InitialGuess(map.toDoubleArray()))
+ }
+
+ public override fun expression(expression: Expression): Unit {
+ val objectiveFunction = ObjectiveFunction {
+ val args = it.toMap()
+ expression(args)
+ }
+ addOptimizationData(objectiveFunction)
+ }
+
+ public override fun diffExpression(expression: DifferentiableExpression): Unit {
+ expression(expression)
+ val gradientFunction = ObjectiveFunctionGradient {
+ val args = it.toMap()
+ DoubleArray(symbols.size) { index ->
+ expression.derivative(symbols[index])(args)
+ }
+ }
+ addOptimizationData(gradientFunction)
+ if (optimizatorBuilder == null) {
+ optimizatorBuilder = {
+ NonLinearConjugateGradientOptimizer(
+ NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
+ convergenceChecker
+ )
+ }
+ }
+ }
+
+ public fun simplex(simplex: AbstractSimplex) {
+ addOptimizationData(simplex)
+ //Set optimization builder to simplex if it is not present
+ if (optimizatorBuilder == null) {
+ optimizatorBuilder = { SimplexOptimizer(convergenceChecker) }
+ }
+ }
+
+ public fun simplexSteps(steps: Map) {
+ simplex(NelderMeadSimplex(steps.toDoubleArray()))
+ }
+
+ public fun goal(goalType: GoalType) {
+ addOptimizationData(goalType)
+ }
+
+ public fun optimizer(block: () -> MultivariateOptimizer) {
+ optimizatorBuilder = block
+ }
+
+ override fun update(result: OptimizationResult) {
+ initialGuess(result.point)
+ }
+
+ override fun optimize(): OptimizationResult {
+ val optimizer = optimizatorBuilder?.invoke() ?: error("Optimizer not defined")
+ val (point, value) = optimizer.optimize(*optimizationData.values.toTypedArray())
+ return OptimizationResult(point.toMap(), value, setOf(this))
+ }
+
+ public companion object : OptimizationProblemFactory {
+ public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
+ public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
+ public const val DEFAULT_MAX_ITER: Int = 1000
+
+ override fun build(symbols: List): CMOptimizationProblem = CMOptimizationProblem(symbols)
+ }
+}
+
+public fun CMOptimizationProblem.initialGuess(vararg pairs: Pair): Unit = initialGuess(pairs.toMap())
+public fun CMOptimizationProblem.simplexSteps(vararg pairs: Pair): Unit = simplexSteps(pairs.toMap())
diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt
new file mode 100644
index 000000000..42475db6c
--- /dev/null
+++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/optimization/cmFit.kt
@@ -0,0 +1,70 @@
+package kscience.kmath.commons.optimization
+
+import kscience.kmath.commons.expressions.DerivativeStructureField
+import kscience.kmath.expressions.DifferentiableExpression
+import kscience.kmath.expressions.Expression
+import kscience.kmath.expressions.Symbol
+import kscience.kmath.stat.Fitting
+import kscience.kmath.stat.OptimizationResult
+import kscience.kmath.stat.optimizeWith
+import kscience.kmath.structures.Buffer
+import kscience.kmath.structures.asBuffer
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
+import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
+
+
+/**
+ * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
+ */
+public fun Fitting.chiSquared(
+ x: Buffer,
+ y: Buffer,
+ yErr: Buffer,
+ model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
+): DifferentiableExpression = chiSquared(DerivativeStructureField, x, y, yErr, model)
+
+/**
+ * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
+ */
+public fun Fitting.chiSquared(
+ x: Iterable,
+ y: Iterable,
+ yErr: Iterable,
+ model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
+): DifferentiableExpression = chiSquared(
+ DerivativeStructureField,
+ x.toList().asBuffer(),
+ y.toList().asBuffer(),
+ yErr.toList().asBuffer(),
+ model
+)
+
+
+/**
+ * Optimize expression without derivatives
+ */
+public fun Expression.optimize(
+ vararg symbols: Symbol,
+ configuration: CMOptimizationProblem.() -> Unit,
+): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
+
+
+/**
+ * Optimize differentiable expression
+ */
+public fun DifferentiableExpression.optimize(
+ vararg symbols: Symbol,
+ configuration: CMOptimizationProblem.() -> Unit,
+): OptimizationResult = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
+
+public fun DifferentiableExpression.minimize(
+ vararg startPoint: Pair,
+ configuration: CMOptimizationProblem.() -> Unit = {},
+): OptimizationResult {
+ require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
+ val problem = CMOptimizationProblem(startPoint.map { it.first }).apply(configuration)
+ problem.diffExpression(this)
+ problem.initialGuess(startPoint.toMap())
+ problem.goal(GoalType.MINIMIZE)
+ return problem.optimize()
+}
\ No newline at end of file
diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt
index 58609deae..1eab5f2bd 100644
--- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt
+++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/random/CMRandomGeneratorWrapper.kt
@@ -1,9 +1,10 @@
package kscience.kmath.commons.random
-import kscience.kmath.prob.RandomGenerator
+import kscience.kmath.stat.RandomGenerator
-public class CMRandomGeneratorWrapper(public val factory: (IntArray) -> RandomGenerator) :
- org.apache.commons.math3.random.RandomGenerator {
+public class CMRandomGeneratorWrapper(
+ public val factory: (IntArray) -> RandomGenerator,
+) : org.apache.commons.math3.random.RandomGenerator {
private var generator: RandomGenerator = factory(intArrayOf())
public override fun nextBoolean(): Boolean = generator.nextBoolean()
diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt
similarity index 51%
rename from kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt
rename to kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt
index f905e6818..8886e123f 100644
--- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt
+++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt
@@ -1,6 +1,6 @@
package kscience.kmath.commons.expressions
-import kscience.kmath.expressions.invoke
+import kscience.kmath.expressions.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test
@@ -8,33 +8,37 @@ import kotlin.test.assertEquals
internal inline fun diff(
order: Int,
- vararg parameters: Pair,
- block: DerivativeStructureField.() -> R
+ vararg parameters: Pair,
+ block: DerivativeStructureField.() -> R,
): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
}
internal class AutoDiffTest {
+ private val x by symbol
+ private val y by symbol
+
@Test
fun derivativeStructureFieldTest() {
- val res = diff(3, "x" to 1.0, "y" to 1.0) {
- val x by variable
- val y = variable("y")
+ val res: Double = diff(3, x to 1.0, y to 1.0) {
+ val x = bind(x)//by binding()
+ val y = symbol("y")
val z = x * (-sin(x * y) + y)
- z.deriv("x")
+ z.derivative(x)
}
+ println(res)
}
@Test
fun autoDifTest() {
- val f = DiffExpression {
- val x by variable
- val y by variable
+ val f = DerivativeStructureExpression {
+ val x by binding()
+ val y by binding()
x.pow(2) + 2 * x * y + y.pow(2) + 1
}
- assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
- assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
+ assertEquals(10.0, f(x to 1.0, y to 2.0))
+ assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
}
}
diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt
new file mode 100644
index 000000000..4384a5124
--- /dev/null
+++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt
@@ -0,0 +1,66 @@
+package kscience.kmath.commons.optimization
+
+import kscience.kmath.commons.expressions.DerivativeStructureExpression
+import kscience.kmath.expressions.symbol
+import kscience.kmath.stat.Distribution
+import kscience.kmath.stat.Fitting
+import kscience.kmath.stat.RandomGenerator
+import kscience.kmath.stat.normal
+import kscience.kmath.structures.asBuffer
+import org.junit.jupiter.api.Test
+import kotlin.math.pow
+
+internal class OptimizeTest {
+ val x by symbol
+ val y by symbol
+
+ val normal = DerivativeStructureExpression {
+ exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2)
+ }
+
+ @Test
+ fun testGradientOptimization() {
+ val result = normal.optimize(x, y) {
+ initialGuess(x to 1.0, y to 1.0)
+ //no need to select optimizer. Gradient optimizer is used by default because gradients are provided by function
+ }
+ println(result.point)
+ println(result.value)
+ }
+
+ @Test
+ fun testSimplexOptimization() {
+ val result = normal.optimize(x, y) {
+ initialGuess(x to 1.0, y to 1.0)
+ simplexSteps(x to 2.0, y to 0.5)
+ //this sets simplex optimizer
+ }
+ println(result.point)
+ println(result.value)
+ }
+
+ @Test
+ fun testCmFit() {
+ val a by symbol
+ val b by symbol
+ val c by symbol
+
+ val sigma = 1.0
+ val generator = Distribution.normal(0.0, sigma)
+ val chain = generator.sample(RandomGenerator.default(112667))
+ val x = (1..100).map { it.toDouble() }
+ val y = x.map { it ->
+ it.pow(2) + it + 1 + chain.nextDouble()
+ }
+ val yErr = x.map { sigma }
+ val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
+ val cWithDefault = bindOrNull(c) ?: one
+ bind(a) * x.pow(2) + bind(b) * x + cWithDefault
+ }
+
+ val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
+ println(result)
+ println("Chi2/dof = ${result.value / (x.size - 3)}")
+ }
+
+}
\ No newline at end of file
diff --git a/kmath-core/README.md b/kmath-core/README.md
index 2cf7ed5dc..5501b1d7a 100644
--- a/kmath-core/README.md
+++ b/kmath-core/README.md
@@ -7,12 +7,12 @@ The core features of KMath:
- [buffers](src/commonMain/kotlin/kscience/kmath/structures/Buffers.kt) : One-dimensional structure
- [expressions](src/commonMain/kotlin/kscience/kmath/expressions) : Functional Expressions
- [domains](src/commonMain/kotlin/kscience/kmath/domains) : Domains
- - [autodif](src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt) : Automatic differentiation
+ - [autodif](src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt) : Automatic differentiation
> #### Artifact:
>
-> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`.
+> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
>
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
>
@@ -22,25 +22,28 @@ The core features of KMath:
>
> ```gradle
> repositories {
+> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
+
> }
>
> dependencies {
-> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1'
+> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
> }
> ```
> **Gradle Kotlin DSL:**
>
> ```kotlin
> repositories {
+> maven("https://dl.bintray.com/kotlin/kotlin-eap")
> maven("https://dl.bintray.com/mipt-npm/kscience")
> maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven")
> }
>
> dependencies {
-> implementation("kscience.kmath:kmath-core:0.2.0-dev-1")
+> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
> }
> ```
diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts
index b56151abe..b0849eca5 100644
--- a/kmath-core/build.gradle.kts
+++ b/kmath-core/build.gradle.kts
@@ -41,6 +41,6 @@ readme {
feature(
id = "autodif",
description = "Automatic differentiation",
- ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt"
+ ref = "src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt"
)
}
\ No newline at end of file
diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt
new file mode 100644
index 000000000..4fe73f283
--- /dev/null
+++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt
@@ -0,0 +1,39 @@
+package kscience.kmath.expressions
+
+/**
+ * An expression that provides derivatives
+ */
+public interface DifferentiableExpression : Expression{
+ public fun derivativeOrNull(orders: Map): Expression?
+}
+
+public fun DifferentiableExpression.derivative(orders: Map): Expression =
+ derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided")
+
+public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression =
+ derivative(mapOf(*orders))
+
+public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1)
+
+public fun DifferentiableExpression.derivative(name: String): Expression =
+ derivative(StringSymbol(name) to 1)
+
+/**
+ * A [DifferentiableExpression] that defines only first derivatives
+ */
+public abstract class FirstDerivativeExpression : DifferentiableExpression {
+
+ public abstract fun derivativeOrNull(symbol: Symbol): Expression?
+
+ public override fun derivativeOrNull(orders: Map): Expression? {
+ val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
+ return derivativeOrNull(dSymbol)
+ }
+}
+
+/**
+ * A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
+ */
+public interface AutoDiffProcessor> {
+ public fun process(function: A.() -> I): DifferentiableExpression
+}
\ No newline at end of file
diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt
index 5ade9e3ca..ab9ff0e72 100644
--- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt
+++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt
@@ -1,6 +1,25 @@
package kscience.kmath.expressions
import kscience.kmath.operations.Algebra
+import kotlin.jvm.JvmName
+import kotlin.properties.ReadOnlyProperty
+
+/**
+ * A marker interface for a symbol. A symbol mus have an identity
+ */
+public interface Symbol {
+ /**
+ * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
+ */
+ public val identity: String
+}
+
+/**
+ * A [Symbol] with a [String] identity
+ */
+public inline class StringSymbol(override val identity: String) : Symbol {
+ override fun toString(): String = identity
+}
/**
* An elementary function that could be invoked on a map of arguments
@@ -12,30 +31,69 @@ public fun interface Expression {
* @param arguments the map of arguments.
* @return the value.
*/
- public operator fun invoke(arguments: Map): T
-
- public companion object
+ public operator fun invoke(arguments: Map): T
}
+/**
+ * Invoke an expression without parameters
+ */
+public operator fun Expression.invoke(): T = invoke(emptyMap())
+//This method exists to avoid resolution ambiguity of vararg methods
+
/**
* Calls this expression from arguments.
*
* @param pairs the pair of arguments' names to values.
* @return the value.
*/
-public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs))
+@JvmName("callBySymbol")
+public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs))
+
+@JvmName("callByString")
+public operator fun Expression.invoke(vararg pairs: Pair): T =
+ invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
+
/**
* A context for expression construction
+ *
+ * @param T type of the constants for the expression
+ * @param E type of the actual expression state
*/
-public interface ExpressionAlgebra : Algebra {
+public interface ExpressionAlgebra : Algebra {
+
/**
- * Introduce a variable into expression context
+ * Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
*/
- public fun variable(name: String, default: T? = null): E
+ public fun bindOrNull(symbol: Symbol): E?
+
+ /**
+ * Bind a string to a context using [StringSymbol]
+ */
+ override fun symbol(value: String): E = bind(StringSymbol(value))
/**
* A constant expression which does not depend on arguments
*/
public fun const(value: T): E
}
+
+/**
+ * Bind a given [Symbol] to this context variable and produce context-specific object.
+ */
+public fun ExpressionAlgebra.bind(symbol: Symbol): E =
+ bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
+
+/**
+ * A delegate to create a symbol with a string identity in this scope
+ */
+public val symbol: ReadOnlyProperty = ReadOnlyProperty { thisRef, property ->
+ StringSymbol(property.name)
+}
+
+/**
+ * Bind a symbol by name inside the [ExpressionAlgebra]
+ */
+public fun ExpressionAlgebra.binding(): ReadOnlyProperty = ReadOnlyProperty { _, property ->
+ bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
+}
\ No newline at end of file
diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
index 5b050dd36..0630e8e4b 100644
--- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
+++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
@@ -2,67 +2,43 @@ package kscience.kmath.expressions
import kscience.kmath.operations.*
-internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) :
- Expression {
- override operator fun invoke(arguments: Map): T =
- context.unaryOperation(name, expr.invoke(arguments))
-}
-
-internal class FunctionalBinaryOperation(
- val context: Algebra,
- val name: String,
- val first: Expression,
- val second: Expression
-) : Expression {
- override operator fun invoke(arguments: Map): T =
- context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
-}
-
-internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression {
- override operator fun invoke(arguments: Map): T =
- arguments[name] ?: default ?: error("Parameter not found: $name")
-}
-
-internal class FunctionalConstantExpression(val value: T) : Expression {
- override operator fun invoke(arguments: Map): T = value
-}
-
-internal class FunctionalConstProductExpression(
- val context: Space,
- private val expr: Expression,
- val const: Number
-) : Expression {
- override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const)
-}
-
/**
* A context class for [Expression] construction.
*
* @param algebra The algebra to provide for Expressions built.
*/
-public abstract class FunctionalExpressionAlgebra>(public val algebra: A) :
- ExpressionAlgebra> {
+public abstract class FunctionalExpressionAlgebra>(
+ public val algebra: A,
+) : ExpressionAlgebra> {
/**
* Builds an Expression of constant expression which does not depend on arguments.
*/
- public override fun const(value: T): Expression = FunctionalConstantExpression(value)
+ public override fun const(value: T): Expression = Expression { value }
/**
* Builds an Expression to access a variable.
*/
- public override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default)
+ public override fun bindOrNull(symbol: Symbol): Expression? = Expression { arguments ->
+ arguments[symbol] ?: error("Argument not found: $symbol")
+ }
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
- public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression =
- FunctionalBinaryOperation(algebra, operation, left, right)
+ public override fun binaryOperation(
+ operation: String,
+ left: Expression,
+ right: Expression,
+ ): Expression = Expression { arguments ->
+ algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
+ }
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
- public override fun unaryOperation(operation: String, arg: Expression): Expression =
- FunctionalUnaryOperation(algebra, operation, arg)
+ public override fun unaryOperation(operation: String, arg: Expression): Expression = Expression { arguments ->
+ algebra.unaryOperation(operation, arg.invoke(arguments))
+ }
}
/**
@@ -81,8 +57,9 @@ public open class FunctionalExpressionSpace>(algebra: A) :
/**
* Builds an Expression of multiplication of expression by number.
*/
- public override fun multiply(a: Expression, k: Number): Expression =
- FunctionalConstProductExpression(algebra, a, k)
+ public override fun multiply(a: Expression, k: Number): Expression = Expression { arguments ->
+ algebra.multiply(a.invoke(arguments), k)
+ }
public operator fun Expression.plus(arg: T): Expression = this + const(arg)
public operator fun Expression.minus(arg: T): Expression = this - const(arg)
@@ -118,8 +95,8 @@ public open class FunctionalExpressionRing(algebra: A) : FunctionalExpress
}
public open class FunctionalExpressionField(algebra: A) :
- FunctionalExpressionRing(algebra),
- Field> where A : Field, A : NumericAlgebra {
+ FunctionalExpressionRing(algebra), Field>
+ where A : Field, A : NumericAlgebra {
/**
* Builds an Expression of division an expression by another one.
*/
diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt
new file mode 100644
index 000000000..5a9642690
--- /dev/null
+++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt
@@ -0,0 +1,395 @@
+package kscience.kmath.expressions
+
+import kscience.kmath.linear.Point
+import kscience.kmath.operations.*
+import kscience.kmath.structures.asBuffer
+import kotlin.contracts.InvocationKind
+import kotlin.contracts.contract
+
+/*
+ * Implementation of backward-mode automatic differentiation.
+ * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
+ */
+
+
+public open class AutoDiffValue(public val value: T)
+
+/**
+ * Represents result of [simpleAutoDiff] call.
+ *
+ * @param T the non-nullable type of value.
+ * @param value the value of result.
+ * @property simpleAutoDiff The mapping of differentiated variables to their derivatives.
+ * @property context The field over [T].
+ */
+public class DerivationResult(
+ public val value: T,
+ private val derivativeValues: Map,
+ public val context: Field,
+) {
+ /**
+ * Returns derivative of [variable] or returns [Ring.zero] in [context].
+ */
+ public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
+
+ /**
+ * Computes the divergence.
+ */
+ public fun div(): T = context { sum(derivativeValues.values) }
+}
+
+/**
+ * Computes the gradient for variables in given order.
+ */
+public fun DerivationResult.grad(vararg variables: Symbol): Point {
+ check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
+ return variables.map(::derivative).asBuffer()
+}
+
+/**
+ * Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
+ *
+ * The partial derivatives are placed in argument `d` variable
+ *
+ * Example:
+ * ```
+ * val x by symbol // define variable(s) and their values
+ * val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
+ * assertEquals(17.0, y.x) // the value of result (y)
+ * assertEquals(9.0, x.d) // dy/dx
+ * ```
+ *
+ * @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
+ * @return the result of differentiation.
+ */
+public fun > F.simpleAutoDiff(
+ bindings: Map,
+ body: SimpleAutoDiffField.() -> AutoDiffValue,
+): DerivationResult {
+ contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
+
+ return SimpleAutoDiffField(this, bindings).derivate(body)
+}
+
+public fun > F.simpleAutoDiff(
+ vararg bindings: Pair,
+ body: SimpleAutoDiffField.() -> AutoDiffValue,
+): DerivationResult = simpleAutoDiff(bindings.toMap(), body)
+
+/**
+ * Represents field in context of which functions can be derived.
+ */
+public open class SimpleAutoDiffField>(
+ public val context: F,
+ bindings: Map,
+) : Field>, ExpressionAlgebra> {
+
+ // this stack contains pairs of blocks and values to apply them to
+ private var stack: Array = arrayOfNulls(8)
+ private var sp: Int = 0
+ private val derivatives: MutableMap, T> = hashMapOf()
+
+ /**
+ * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
+ * with respect to this variable.
+ *
+ * @param T the non-nullable type of value.
+ * @property value The value of this variable.
+ */
+ private class AutoDiffVariableWithDerivative(
+ override val identity: String,
+ value: T,
+ var d: T,
+ ) : AutoDiffValue(value), Symbol {
+ override fun toString(): String = identity
+ override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
+ override fun hashCode(): Int = identity.hashCode()
+ }
+
+ private val bindings: Map> = bindings.entries.associate {
+ it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
+ }
+
+ override fun bindOrNull(symbol: Symbol): AutoDiffValue? = bindings[symbol.identity]
+
+ private fun getDerivative(variable: AutoDiffValue): T =
+ (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
+
+ private fun setDerivative(variable: AutoDiffValue, value: T) {
+ if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
+ }
+
+
+ @Suppress("UNCHECKED_CAST")
+ private fun runBackwardPass() {
+ while (sp > 0) {
+ val value = stack[--sp]
+ val block = stack[--sp] as F.(Any?) -> Unit
+ context.block(value)
+ }
+ }
+
+ override val zero: AutoDiffValue get() = const(context.zero)
+ override val one: AutoDiffValue get() = const(context.one)
+
+ override fun const(value: T): AutoDiffValue = AutoDiffValue(value)
+
+ /**
+ * A variable accessing inner state of derivatives.
+ * Use this value in inner builders to avoid creating additional derivative bindings.
+ */
+ public var AutoDiffValue.d: T
+ get() = getDerivative(this)
+ set(value) = setDerivative(this, value)
+
+ public inline fun const(block: F.() -> T): AutoDiffValue = const(context.block())
+
+ /**
+ * Performs update of derivative after the rest of the formula in the back-pass.
+ *
+ * For example, implementation of `sin` function is:
+ *
+ * ```
+ * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
+ * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
+ * }
+ * ```
+ */
+ @Suppress("UNCHECKED_CAST")
+ public fun derive(value: R, block: F.(R) -> Unit): R {
+ // save block to stack for backward pass
+ if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
+ stack[sp++] = block
+ stack[sp++] = value
+ return value
+ }
+
+
+ internal fun derivate(function: SimpleAutoDiffField.() -> AutoDiffValue): DerivationResult {
+ val result = function()
+ result.d = context.one // computing derivative w.r.t result
+ runBackwardPass()
+ return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
+ }
+
+ // Overloads for Double constants
+
+ override operator fun Number.plus(b: AutoDiffValue): AutoDiffValue =
+ derive(const { this@plus.toDouble() * one + b.value }) { z ->
+ b.d += z.d
+ }
+
+ override operator fun AutoDiffValue.plus(b: Number): AutoDiffValue = b.plus(this)
+
+ override operator fun Number.minus(b: AutoDiffValue): AutoDiffValue =
+ derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
+
+ override operator fun AutoDiffValue.minus(b: Number): AutoDiffValue =
+ derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
+
+
+ // Basic math (+, -, *, /)
+
+ override fun add(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue =
+ derive(const { a.value + b.value }) { z ->
+ a.d += z.d
+ b.d += z.d
+ }
+
+ override fun multiply(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue =
+ derive(const { a.value * b.value }) { z ->
+ a.d += z.d * b.value
+ b.d += z.d * a.value
+ }
+
+ override fun divide(a: AutoDiffValue, b: AutoDiffValue): AutoDiffValue