diff --git a/README.md b/README.md
index 708bd8eb1..cbdf98afb 100644
--- a/README.md
+++ b/README.md
@@ -53,9 +53,7 @@ can be used for a wide variety of purposes from high performance calculations to
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
to submit a feature request if you want something to be done first.
-
-* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
-
+
## Planned features
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
@@ -117,6 +115,12 @@ can be used for a wide variety of purposes from high performance calculations to
> **Maturity**: EXPERIMENTAL
+* ### [kmath-ejml](kmath-ejml)
+>
+>
+> **Maturity**: EXPERIMENTAL
+
+
* ### [kmath-for-real](kmath-for-real)
>
>
@@ -178,8 +182,8 @@ repositories{
}
dependencies{
- api("kscience.kmath:kmath-core:0.2.0-dev-1")
- //api("kscience.kmath:kmath-core-jvm:0.2.0-dev-1") for jvm-specific version
+ api("kscience.kmath:kmath-core:0.2.0-dev-2")
+ //api("kscience.kmath:kmath-core-jvm:0.2.0-dev-2") for jvm-specific version
}
```
diff --git a/build.gradle.kts b/build.gradle.kts
index 239ea1296..74b76d731 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -24,4 +24,8 @@ subprojects {
readme {
readmeTemplate = file("docs/templates/README-TEMPLATE.md")
+}
+
+apiValidation{
+ validationDisabled = true
}
\ No newline at end of file
diff --git a/docs/templates/README-TEMPLATE.md b/docs/templates/README-TEMPLATE.md
index f451adb24..5117e0694 100644
--- a/docs/templates/README-TEMPLATE.md
+++ b/docs/templates/README-TEMPLATE.md
@@ -107,4 +107,4 @@ with the same artifact names.
## Contributing
-The project requires a lot of additional work. Please feel free to contribute in any way and propose new features.
+The project requires a lot of additional work. The most important thing we need is a feedback about what features are required the most. Feel free to open feature issues with requests. We are also welcome to code contributions, especially in issues marked as [waiting for a hero](https://github.com/mipt-npm/kmath/labels/waiting%20for%20a%20hero).
\ No newline at end of file
diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt
index 483bc530c..5ca75e993 100644
--- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt
+++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt
@@ -14,8 +14,8 @@ import kotlin.contracts.contract
* @author Alexander Nozik
*/
public class MstExpression(public val algebra: Algebra, public val mst: MST) : Expression {
- private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra {
- override fun symbol(value: String): T = arguments[value] ?: algebra.symbol(value)
+ private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra {
+ override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
@@ -27,7 +27,7 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS
error("Numeric nodes are not supported by $this")
}
- override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst)
+ override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst)
}
/**
@@ -37,7 +37,7 @@ public class MstExpression(public val algebra: Algebra, public val mst: MS
*/
public inline fun , E : Algebra> A.mst(
mstAlgebra: E,
- block: E.() -> MST
+ block: E.() -> MST,
): MstExpression = MstExpression(this, mstAlgebra.block())
/**
@@ -116,7 +116,7 @@ public inline fun > FunctionalExpressionField> FunctionalExpressionExtendedField.mstInExtendedField(
- block: MstExtendedField.() -> MST
+ block: MstExtendedField.() -> MST,
): MstExpression {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block)
diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt
index 708b3c2b4..09e9a71b0 100644
--- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt
+++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/mapIntrinsics.kt
@@ -2,6 +2,8 @@
package kscience.kmath.asm.internal
+import kscience.kmath.expressions.StringSymbol
+
/**
* Gets value with given [key] or throws [IllegalStateException] whenever it is not present.
*
@@ -9,4 +11,4 @@ package kscience.kmath.asm.internal
*/
@JvmOverloads
internal fun Map.getOrFail(key: K, default: V? = null): V =
- this[key] ?: default ?: error("Parameter not found: $key")
+ this[StringSymbol(key.toString())] ?: default ?: error("Parameter not found: $key")
diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt
index 0cf1307d1..5eebfe43d 100644
--- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt
+++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt
@@ -1,6 +1,5 @@
package kscience.kmath.asm
-import kscience.kmath.asm.compile
import kscience.kmath.ast.mstInField
import kscience.kmath.ast.mstInRing
import kscience.kmath.ast.mstInSpace
@@ -11,6 +10,7 @@ import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmAlgebras {
+
@Test
fun space() {
val res1 = ByteRing.mstInSpace {
diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt
similarity index 50%
rename from kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt
rename to kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt
index 1eca1a773..9a27e40cd 100644
--- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DiffExpression.kt
+++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt
@@ -1,48 +1,57 @@
package kscience.kmath.commons.expressions
+import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.ExpressionAlgebra
+import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.ExtendedField
-import kscience.kmath.operations.Field
-import kscience.kmath.operations.invoke
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
-import kotlin.properties.ReadOnlyProperty
/**
* A field over commons-math [DerivativeStructure].
*
* @property order The derivation order.
- * @property parameters The map of free parameters.
+ * @property bindings The map of bindings values. All bindings are considered free parameters
*/
public class DerivativeStructureField(
public val order: Int,
- public val parameters: Map,
-) : ExtendedField {
- public override val zero: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order) }
- public override val one: DerivativeStructure by lazy { DerivativeStructure(parameters.size, order, 1.0) }
+ private val bindings: Map
+) : ExtendedField, ExpressionAlgebra {
+ public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) }
+ public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) }
- private val variables: Map = parameters.mapValues { (key, value) ->
- DerivativeStructure(parameters.size, order, parameters.keys.indexOf(key), value)
+ /**
+ * A class that implements both [DerivativeStructure] and a [Symbol]
+ */
+ public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
+ DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
+ override val identity: Any = symbol.identity
}
- public val variable: ReadOnlyProperty = ReadOnlyProperty { _, property ->
- variables[property.name] ?: error("A variable with name ${property.name} does not exist")
+ /**
+ * Identity-based symbol bindings map
+ */
+ private val variables: Map = bindings.entries.associate { (key, value) ->
+ key.identity to DerivativeStructureSymbol(key, value)
}
- public fun variable(name: String, default: DerivativeStructure? = null): DerivativeStructure =
- variables[name] ?: default ?: error("A variable with name $name does not exist")
+ override fun const(value: Double): DerivativeStructure = DerivativeStructure(order, bindings.size, value)
- public fun Number.const(): DerivativeStructure = DerivativeStructure(order, parameters.size, toDouble())
+ public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
- public fun DerivativeStructure.deriv(parName: String, order: Int = 1): Double {
- return deriv(mapOf(parName to order))
+ public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
+
+ public fun Number.const(): DerivativeStructure = const(toDouble())
+
+ public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double {
+ return derivative(mapOf(parameter to order))
}
- public fun DerivativeStructure.deriv(orders: Map): Double {
- return getPartialDerivative(*parameters.keys.map { orders[it] ?: 0 }.toIntArray())
+ public fun DerivativeStructure.derivative(orders: Map): Double {
+ return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
}
- public fun DerivativeStructure.deriv(vararg orders: Pair): Double = deriv(mapOf(*orders))
+ public fun DerivativeStructure.derivative(vararg orders: Pair): Double = derivative(mapOf(*orders))
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
@@ -85,48 +94,16 @@ public class DerivativeStructureField(
/**
* A constructs that creates a derivative structure with required order on-demand
*/
-public class DiffExpression(
+public class DerivativeStructureExpression(
public val function: DerivativeStructureField.() -> DerivativeStructure,
-) : Expression {
- public override operator fun invoke(arguments: Map): Double = DerivativeStructureField(
- 0,
- arguments
- ).function().value
+) : DifferentiableExpression {
+ public override operator fun invoke(arguments: Map): Double =
+ DerivativeStructureField(0, arguments).function().value
/**
* Get the derivative expression with given orders
- * TODO make result [DiffExpression]
*/
- public fun derivative(orders: Map): Expression = Expression { arguments ->
- (DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().deriv(orders) }
+ public override fun derivative(orders: Map): Expression = Expression { arguments ->
+ with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) }
}
-
- //TODO add gradient and maybe other vector operators
-}
-
-public fun DiffExpression.derivative(vararg orders: Pair): Expression = derivative(mapOf(*orders))
-public fun DiffExpression.derivative(name: String): Expression = derivative(name to 1)
-
-/**
- * A context for [DiffExpression] (not to be confused with [DerivativeStructure])
- */
-public object DiffExpressionAlgebra : ExpressionAlgebra, Field {
- public override val zero: DiffExpression = DiffExpression { 0.0.const() }
- public override val one: DiffExpression = DiffExpression { 1.0.const() }
-
- public override fun variable(name: String, default: Double?): DiffExpression =
- DiffExpression { variable(name, default?.const()) }
-
- public override fun const(value: Double): DiffExpression = DiffExpression { value.const() }
-
- public override fun add(a: DiffExpression, b: DiffExpression): DiffExpression =
- DiffExpression { a.function(this) + b.function(this) }
-
- public override fun multiply(a: DiffExpression, k: Number): DiffExpression = DiffExpression { a.function(this) * k }
-
- public override fun multiply(a: DiffExpression, b: DiffExpression): DiffExpression =
- DiffExpression { a.function(this) * b.function(this) }
-
- public override fun divide(a: DiffExpression, b: DiffExpression): DiffExpression =
- DiffExpression { a.function(this) / b.function(this) }
}
diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt
similarity index 51%
rename from kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt
rename to kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt
index 197faaf49..8886e123f 100644
--- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/AutoDiffTest.kt
+++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt
@@ -1,6 +1,6 @@
package kscience.kmath.commons.expressions
-import kscience.kmath.expressions.invoke
+import kscience.kmath.expressions.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test
@@ -8,33 +8,37 @@ import kotlin.test.assertEquals
internal inline fun diff(
order: Int,
- vararg parameters: Pair,
- block: DerivativeStructureField.() -> R
+ vararg parameters: Pair,
+ block: DerivativeStructureField.() -> R,
): R {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block)
}
internal class AutoDiffTest {
+ private val x by symbol
+ private val y by symbol
+
@Test
fun derivativeStructureFieldTest() {
- val res: Double = diff(3, "x" to 1.0, "y" to 1.0) {
- val x by variable
- val y = variable("y")
+ val res: Double = diff(3, x to 1.0, y to 1.0) {
+ val x = bind(x)//by binding()
+ val y = symbol("y")
val z = x * (-sin(x * y) + y)
- z.deriv("x")
+ z.derivative(x)
}
+ println(res)
}
@Test
fun autoDifTest() {
- val f = DiffExpression {
- val x by variable
- val y by variable
+ val f = DerivativeStructureExpression {
+ val x by binding()
+ val y by binding()
x.pow(2) + 2 * x * y + y.pow(2) + 1
}
- assertEquals(10.0, f("x" to 1.0, "y" to 2.0))
- assertEquals(6.0, f.derivative("x")("x" to 1.0, "y" to 2.0))
+ assertEquals(10.0, f(x to 1.0, y to 2.0))
+ assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
}
}
diff --git a/kmath-core/README.md b/kmath-core/README.md
index 2cf7ed5dc..6935c0d3c 100644
--- a/kmath-core/README.md
+++ b/kmath-core/README.md
@@ -12,7 +12,7 @@ The core features of KMath:
> #### Artifact:
>
-> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-1`.
+> This module artifact: `kscience.kmath:kmath-core:0.2.0-dev-2`.
>
> Bintray release version: [ ![Download](https://api.bintray.com/packages/mipt-npm/kscience/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/kscience/kmath-core/_latestVersion)
>
@@ -22,25 +22,28 @@ The core features of KMath:
>
> ```gradle
> repositories {
+> maven { url "https://dl.bintray.com/kotlin/kotlin-eap" }
> maven { url 'https://dl.bintray.com/mipt-npm/kscience' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
+
> }
>
> dependencies {
-> implementation 'kscience.kmath:kmath-core:0.2.0-dev-1'
+> implementation 'kscience.kmath:kmath-core:0.2.0-dev-2'
> }
> ```
> **Gradle Kotlin DSL:**
>
> ```kotlin
> repositories {
+> maven("https://dl.bintray.com/kotlin/kotlin-eap")
> maven("https://dl.bintray.com/mipt-npm/kscience")
> maven("https://dl.bintray.com/mipt-npm/dev")
> maven("https://dl.bintray.com/hotkeytlt/maven")
> }
>
> dependencies {
-> implementation("kscience.kmath:kmath-core:0.2.0-dev-1")
+> implementation("kscience.kmath:kmath-core:0.2.0-dev-2")
> }
> ```
diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts
index b56151abe..bd254c39d 100644
--- a/kmath-core/build.gradle.kts
+++ b/kmath-core/build.gradle.kts
@@ -41,6 +41,6 @@ readme {
feature(
id = "autodif",
description = "Automatic differentiation",
- ref = "src/commonMain/kotlin/kscience/kmath/misc/AutoDiff.kt"
+ ref = "src/commonMain/kotlin/kscience/kmath/misc/SimpleAutoDiff.kt"
)
}
\ No newline at end of file
diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt
index 5ade9e3ca..d64eb5a55 100644
--- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt
+++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt
@@ -1,6 +1,26 @@
package kscience.kmath.expressions
import kscience.kmath.operations.Algebra
+import kotlin.jvm.JvmName
+import kotlin.properties.ReadOnlyProperty
+
+/**
+ * A marker interface for a symbol. A symbol mus have an identity
+ */
+public interface Symbol {
+ /**
+ * Identity object for the symbol. Two symbols with the same identity are considered to be the same symbol.
+ * By default uses object identity
+ */
+ public val identity: Any get() = this
+}
+
+/**
+ * A [Symbol] with a [String] identity
+ */
+public inline class StringSymbol(override val identity: String) : Symbol {
+ override fun toString(): String = identity
+}
/**
* An elementary function that could be invoked on a map of arguments
@@ -12,30 +32,81 @@ public fun interface Expression {
* @param arguments the map of arguments.
* @return the value.
*/
- public operator fun invoke(arguments: Map): T
+ public operator fun invoke(arguments: Map): T
public companion object
}
+/**
+ * Invlode an expression without parameters
+ */
+public operator fun Expression.invoke(): T = invoke(emptyMap())
+//This method exists to avoid resolution ambiguity of vararg methods
+
/**
* Calls this expression from arguments.
*
* @param pairs the pair of arguments' names to values.
* @return the value.
*/
-public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs))
+@JvmName("callBySymbol")
+public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs))
+
+@JvmName("callByString")
+public operator fun Expression.invoke(vararg pairs: Pair): T =
+ invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
+
+/**
+ * And object that could be differentiated
+ */
+public interface Differentiable {
+ public fun derivative(orders: Map): T
+}
+
+public interface DifferentiableExpression : Differentiable>, Expression
+
+public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression =
+ derivative(mapOf(*orders))
+
+public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1)
+
+public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1)
/**
* A context for expression construction
+ *
+ * @param T type of the constants for the expression
+ * @param E type of the actual expression state
*/
-public interface ExpressionAlgebra : Algebra {
+public interface ExpressionAlgebra : Algebra {
+
/**
- * Introduce a variable into expression context
+ * Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
*/
- public fun variable(name: String, default: T? = null): E
+ public fun bindOrNull(symbol: Symbol): E?
+
+ /**
+ * Bind a string to a context using [StringSymbol]
+ */
+ override fun symbol(value: String): E = bind(StringSymbol(value))
/**
* A constant expression which does not depend on arguments
*/
public fun const(value: T): E
}
+
+/**
+ * Bind a given [Symbol] to this context variable and produce context-specific object.
+ */
+public fun ExpressionAlgebra.bind(symbol: Symbol): E =
+ bindOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
+
+public val symbol: ReadOnlyProperty = ReadOnlyProperty { _, property ->
+ StringSymbol(property.name)
+}
+
+public fun ExpressionAlgebra.binding(): ReadOnlyProperty =
+ ReadOnlyProperty { _, property ->
+ bind(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
+ }
\ No newline at end of file
diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
index 5b050dd36..9fd15238a 100644
--- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
+++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt
@@ -2,39 +2,6 @@ package kscience.kmath.expressions
import kscience.kmath.operations.*
-internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) :
- Expression {
- override operator fun invoke(arguments: Map): T =
- context.unaryOperation(name, expr.invoke(arguments))
-}
-
-internal class FunctionalBinaryOperation(
- val context: Algebra,
- val name: String,
- val first: Expression,
- val second: Expression
-) : Expression {
- override operator fun invoke(arguments: Map): T =
- context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
-}
-
-internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression {
- override operator fun invoke(arguments: Map): T =
- arguments[name] ?: default ?: error("Parameter not found: $name")
-}
-
-internal class FunctionalConstantExpression(val value: T) : Expression {
- override operator fun invoke(arguments: Map): T = value
-}
-
-internal class FunctionalConstProductExpression(
- val context: Space,
- private val expr: Expression,
- val const: Number
-) : Expression {
- override operator fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const)
-}
-
/**
* A context class for [Expression] construction.
*
@@ -45,24 +12,32 @@ public abstract class FunctionalExpressionAlgebra>(public val
/**
* Builds an Expression of constant expression which does not depend on arguments.
*/
- public override fun const(value: T): Expression = FunctionalConstantExpression(value)
+ public override fun const(value: T): Expression = Expression { value }
/**
* Builds an Expression to access a variable.
*/
- public override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default)
+ public override fun bindOrNull(symbol: Symbol): Expression? = Expression { arguments ->
+ arguments[symbol] ?: error("Argument not found: $symbol")
+ }
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
- public override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression =
- FunctionalBinaryOperation(algebra, operation, left, right)
+ public override fun binaryOperation(
+ operation: String,
+ left: Expression,
+ right: Expression,
+ ): Expression = Expression { arguments ->
+ algebra.binaryOperation(operation, left.invoke(arguments), right.invoke(arguments))
+ }
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
- public override fun unaryOperation(operation: String, arg: Expression): Expression =
- FunctionalUnaryOperation(algebra, operation, arg)
+ public override fun unaryOperation(operation: String, arg: Expression): Expression = Expression { arguments ->
+ algebra.unaryOperation(operation, arg.invoke(arguments))
+ }
}
/**
@@ -81,8 +56,9 @@ public open class FunctionalExpressionSpace>(algebra: A) :
/**
* Builds an Expression of multiplication of expression by number.
*/
- public override fun multiply(a: Expression, k: Number): Expression =
- FunctionalConstProductExpression(algebra, a, k)
+ public override fun multiply(a: Expression, k: Number): Expression = Expression { arguments ->
+ algebra.multiply(a.invoke(arguments), k)
+ }
public operator fun Expression.plus(arg: T): Expression = this + const(arg)
public operator fun Expression.minus(arg: T): Expression = this - const(arg)
@@ -118,8 +94,8 @@ public open class FunctionalExpressionRing(algebra: A) : FunctionalExpress
}
public open class FunctionalExpressionField(algebra: A) :
- FunctionalExpressionRing(algebra),
- Field> where A : Field, A : NumericAlgebra {
+ FunctionalExpressionRing(algebra), Field>
+ where A : Field, A : NumericAlgebra {
/**
* Builds an Expression of division an expression by another one.
*/
diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt
new file mode 100644
index 000000000..5e8fe3e99
--- /dev/null
+++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt
@@ -0,0 +1,329 @@
+package kscience.kmath.expressions
+
+import kscience.kmath.linear.Point
+import kscience.kmath.operations.*
+import kscience.kmath.structures.asBuffer
+import kotlin.contracts.InvocationKind
+import kotlin.contracts.contract
+
+/*
+ * Implementation of backward-mode automatic differentiation.
+ * Initial gist by Roman Elizarov: https://gist.github.com/elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d
+ */
+
+
+/**
+ * A [Symbol] with bound value
+ */
+public interface BoundSymbol : Symbol {
+ public val value: T
+}
+
+/**
+ * Bind a [Symbol] to a [value] and produce [BoundSymbol]
+ */
+public fun Symbol.bind(value: T): BoundSymbol = object : BoundSymbol {
+ override val identity = this@bind.identity
+ override val value: T = value
+}
+
+/**
+ * Represents result of [withAutoDiff] call.
+ *
+ * @param T the non-nullable type of value.
+ * @param value the value of result.
+ * @property withAutoDiff The mapping of differentiated variables to their derivatives.
+ * @property context The field over [T].
+ */
+public class DerivationResult(
+ override val value: T,
+ private val derivativeValues: Map,
+ public val context: Field,
+) : BoundSymbol {
+ /**
+ * Returns derivative of [variable] or returns [Ring.zero] in [context].
+ */
+ public fun derivative(variable: Symbol): T = derivativeValues[variable.identity] ?: context.zero
+
+ /**
+ * Computes the divergence.
+ */
+ public fun div(): T = context { sum(derivativeValues.values) }
+}
+
+/**
+ * Computes the gradient for variables in given order.
+ */
+public fun DerivationResult.grad(vararg variables: Symbol): Point {
+ check(variables.isNotEmpty()) { "Variable order is not provided for gradient construction" }
+ return variables.map(::derivative).asBuffer()
+}
+
+/**
+ * Runs differentiation and establishes [AutoDiffField] context inside the block of code.
+ *
+ * The partial derivatives are placed in argument `d` variable
+ *
+ * Example:
+ * ```
+ * val x by symbol // define variable(s) and their values
+ * val y = RealField.withAutoDiff() { sqr(x) + 5 * x + 3 } // write formulate in deriv context
+ * assertEquals(17.0, y.x) // the value of result (y)
+ * assertEquals(9.0, x.d) // dy/dx
+ * ```
+ *
+ * @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
+ * @return the result of differentiation.
+ */
+public fun > F.withAutoDiff(
+ bindings: Collection>,
+ body: AutoDiffField.() -> BoundSymbol,
+): DerivationResult {
+ contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
+
+ return AutoDiffContext(this, bindings).derivate(body)
+}
+
+public fun > F.withAutoDiff(
+ vararg bindings: Pair,
+ body: AutoDiffField.() -> BoundSymbol,
+): DerivationResult = withAutoDiff(bindings.map { it.first.bind(it.second) }, body)
+
+/**
+ * Represents field in context of which functions can be derived.
+ */
+public abstract class AutoDiffField>
+ : Field>, ExpressionAlgebra> {
+
+ public abstract val context: F
+
+ /**
+ * A variable accessing inner state of derivatives.
+ * Use this value in inner builders to avoid creating additional derivative bindings.
+ */
+ public abstract var BoundSymbol.d: T
+
+ /**
+ * Performs update of derivative after the rest of the formula in the back-pass.
+ *
+ * For example, implementation of `sin` function is:
+ *
+ * ```
+ * fun AD.sin(x: Variable): Variable = derive(Variable(sin(x.x)) { z -> // call derive with function result
+ * x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
+ * }
+ * ```
+ */
+ public abstract fun derive(value: R, block: F.(R) -> Unit): R
+
+ public inline fun const(block: F.() -> T): BoundSymbol = const(context.block())
+
+ // Overloads for Double constants
+
+ override operator fun Number.plus(b: BoundSymbol): BoundSymbol =
+ derive(const { this@plus.toDouble() * one + b.value }) { z ->
+ b.d += z.d
+ }
+
+ override operator fun BoundSymbol.plus(b: Number): BoundSymbol = b.plus(this)
+
+ override operator fun Number.minus(b: BoundSymbol): BoundSymbol =
+ derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
+
+ override operator fun BoundSymbol.minus(b: Number): BoundSymbol =
+ derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
+}
+
+/**
+ * Automatic Differentiation context class.
+ */
+private class AutoDiffContext>(
+ override val context: F,
+ bindings: Collection>,
+) : AutoDiffField() {
+ // this stack contains pairs of blocks and values to apply them to
+ private var stack: Array = arrayOfNulls(8)
+ private var sp: Int = 0
+ private val derivatives: MutableMap = hashMapOf()
+ override val zero: BoundSymbol get() = const(context.zero)
+ override val one: BoundSymbol get() = const(context.one)
+
+ /**
+ * Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
+ * with respect to this variable.
+ *
+ * @param T the non-nullable type of value.
+ * @property value The value of this variable.
+ */
+ private class AutoDiffVariableWithDeriv(override val value: T, var d: T) : BoundSymbol
+
+ private val bindings: Map> = bindings.associateBy { it.identity }
+
+ override fun bindOrNull(symbol: Symbol): BoundSymbol? = bindings[symbol.identity]
+
+ override fun const(value: T): BoundSymbol = AutoDiffVariableWithDeriv(value, context.zero)
+
+ override var BoundSymbol.d: T
+ get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[identity] ?: context.zero
+ set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[identity] = value
+
+ @Suppress("UNCHECKED_CAST")
+ override fun derive(value: R, block: F.(R) -> Unit): R {
+ // save block to stack for backward pass
+ if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
+ stack[sp++] = block
+ stack[sp++] = value
+ return value
+ }
+
+ @Suppress("UNCHECKED_CAST")
+ fun runBackwardPass() {
+ while (sp > 0) {
+ val value = stack[--sp]
+ val block = stack[--sp] as F.(Any?) -> Unit
+ context.block(value)
+ }
+ }
+
+ // Basic math (+, -, *, /)
+
+ override fun add(a: BoundSymbol, b: BoundSymbol): BoundSymbol =
+ derive(const { a.value + b.value }) { z ->
+ a.d += z.d
+ b.d += z.d
+ }
+
+ override fun multiply(a: BoundSymbol, b: BoundSymbol): BoundSymbol =
+ derive(const { a.value * b.value }) { z ->
+ a.d += z.d * b.value
+ b.d += z.d * a.value
+ }
+
+ override fun divide(a: BoundSymbol, b: BoundSymbol): BoundSymbol =
+ derive(const { a.value / b.value }) { z ->
+ a.d += z.d / b.value
+ b.d -= z.d * a.value / (b.value * b.value)
+ }
+
+ override fun multiply(a: BoundSymbol, k: Number): BoundSymbol =
+ derive(const { k.toDouble() * a.value }) { z ->
+ a.d += z.d * k.toDouble()
+ }
+
+ inline fun derivate(function: AutoDiffField.() -> BoundSymbol): DerivationResult {
+ val result = function()
+ result.d = context.one // computing derivative w.r.t result
+ runBackwardPass()
+ return DerivationResult(result.value, derivatives, context)
+ }
+}
+
+/**
+ * A constructs that creates a derivative structure with required order on-demand
+ */
+public class SimpleAutoDiffExpression>(
+ public val field: F,
+ public val function: AutoDiffField.() -> BoundSymbol,
+) : DifferentiableExpression {
+ public override operator fun invoke(arguments: Map): T {
+ val bindings = arguments.entries.map { it.key.bind(it.value) }
+ return AutoDiffContext(field, bindings).function().value
+ }
+
+ /**
+ * Get the derivative expression with given orders
+ */
+ public override fun derivative(orders: Map): Expression {
+ val dSymbol = orders.entries.singleOrNull { it.value == 1 }
+ ?: error("SimpleAutoDiff supports only first order derivatives")
+ return Expression { arguments ->
+ val bindings = arguments.entries.map { it.key.bind(it.value) }
+ val derivationResult = AutoDiffContext(field, bindings).derivate(function)
+ derivationResult.derivative(dSymbol.key)
+ }
+ }
+}
+
+
+// Extensions for differentiation of various basic mathematical functions
+
+// x ^ 2
+public fun > AutoDiffField.sqr(x: BoundSymbol): BoundSymbol =
+ derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
+
+// x ^ 1/2
+public fun > AutoDiffField.sqrt(x: BoundSymbol): BoundSymbol =
+ derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
+
+// x ^ y (const)
+public fun > AutoDiffField.pow(
+ x: BoundSymbol,
+ y: Double,
+): BoundSymbol =
+ derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
+
+public fun > AutoDiffField.pow(
+ x: BoundSymbol,
+ y: Int,
+): BoundSymbol =
+ pow(x, y.toDouble())
+
+// exp(x)
+public fun > AutoDiffField.exp(x: BoundSymbol): BoundSymbol =
+ derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
+
+// ln(x)
+public fun > AutoDiffField.ln(x: BoundSymbol): BoundSymbol =
+ derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
+
+// x ^ y (any)
+public fun > AutoDiffField.pow(
+ x: BoundSymbol,
+ y: BoundSymbol,
+): BoundSymbol =
+ exp(y * ln(x))
+
+// sin(x)
+public fun > AutoDiffField.sin(x: BoundSymbol): BoundSymbol =
+ derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
+
+// cos(x)
+public fun > AutoDiffField.cos(x: BoundSymbol): BoundSymbol =
+ derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
+
+public fun