Feature/diff api #154
@ -26,6 +26,9 @@ public class DerivativeStructureField(
|
|||||||
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||||
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||||
override val identity: String = symbol.identity
|
override val identity: String = symbol.identity
|
||||||
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -0,0 +1,21 @@
|
|||||||
|
|||||||
|
package kscience.kmath.expressions
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
/**
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
* And object that could be differentiated
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
*/
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public interface Differentiable<T> {
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public fun derivative(orders: Map<Symbol, Int>): T
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
}
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
derivative(mapOf(*orders))
|
||||||
Overload with Overload with `vararg symbols: Symbol` for order 1 can be added, too.
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
No sense in that. People will use first, maximum second derivatives. There is an extension for the first. The second one could be added any moment. No sense in that. People will use first, maximum second derivatives. There is an extension for the first. The second one could be added any moment.
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
|||||||
|
//}
|
||||||
1. SAM interface.
2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
Removed Removed
|
@ -56,21 +56,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T =
|
|||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||||
|
|
||||||
/**
|
|
||||||
* And object that could be differentiated
|
|
||||||
*/
|
|
||||||
public interface Differentiable<T> {
|
|
||||||
public fun derivative(orders: Map<Symbol, Int>): T
|
|
||||||
}
|
|
||||||
|
|
||||||
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
|
||||||
derivative(mapOf(*orders))
|
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
@ -96,6 +81,10 @@ public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
|||||||
public fun const(value: T): E
|
public fun const(value: T): E
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
|
||||||
|
// public fun expression(block: A.() -> E): Expression<T>
|
||||||
|
//}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
||||||
*/
|
*/
|
||||||
|
@ -7,8 +7,9 @@ import kscience.kmath.operations.*
|
|||||||
*
|
*
|
||||||
* @param algebra The algebra to provide for Expressions built.
|
* @param algebra The algebra to provide for Expressions built.
|
||||||
*/
|
*/
|
||||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
|
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||||
ExpressionAlgebra<T, Expression<T>> {
|
public val algebra: A,
|
||||||
|
) : ExpressionAlgebra<T, Expression<T>> {
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression which does not depend on arguments.
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
|
@ -15,11 +15,11 @@ import kotlin.contracts.contract
|
|||||||
public open class AutoDiffValue<out T>(public val value: T)
|
public open class AutoDiffValue<out T>(public val value: T)
|
||||||
I am sure that it may be replaced with sealed class to prevent one to extend AutoDiffValue with irrelevant object. I am sure that it may be replaced with sealed class to prevent one to extend AutoDiffValue with irrelevant object.
The idea is that it could be extended anytime. Here The idea is that it could be extended anytime. Here `AutoDiffValue` is just a marker interface. It is possible to even replace it with an inline class, but we need performance measurements to make that change.
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents result of [withAutoDiff] call.
|
* Represents result of [simpleAutoDiff] call.
|
||||||
*
|
*
|
||||||
* @param T the non-nullable type of value.
|
* @param T the non-nullable type of value.
|
||||||
* @param value the value of result.
|
* @param value the value of result.
|
||||||
* @property withAutoDiff The mapping of differentiated variables to their derivatives.
|
* @property simpleAutoDiff The mapping of differentiated variables to their derivatives.
|
||||||
* @property context The field over [T].
|
* @property context The field over [T].
|
||||||
*/
|
*/
|
||||||
public class DerivationResult<T : Any>(
|
public class DerivationResult<T : Any>(
|
||||||
@ -62,7 +62,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
|||||||
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||||
* @return the result of differentiation.
|
* @return the result of differentiation.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> {
|
): DerivationResult<T> {
|
||||||
@ -71,10 +71,10 @@ public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
|||||||
return AutoDiffContext(this, bindings).derivate(body)
|
return AutoDiffContext(this, bindings).derivate(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
vararg bindings: Pair<Symbol, T>,
|
vararg bindings: Pair<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> = withAutoDiff(bindings.toMap(), body)
|
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents field in context of which functions can be derived.
|
* Represents field in context of which functions can be derived.
|
||||||
@ -136,7 +136,7 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
*
|
*
|
||||||
* @param T the non-nullable type of value.
|
* @param T the non-nullable type of value.
|
||||||
@ -148,6 +148,8 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
var d: T,
|
var d: T,
|
||||||
) : AutoDiffValue<T>(value), Symbol{
|
) : AutoDiffValue<T>(value), Symbol{
|
||||||
override fun toString(): String = identity
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
||||||
|
@ -10,21 +10,17 @@ import kotlin.test.assertEquals
|
|||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class SimpleAutoDiffTest {
|
class SimpleAutoDiffTest {
|
||||||
fun d(
|
|
||||||
vararg bindings: Pair<Symbol, Double>,
|
|
||||||
body: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>,
|
|
||||||
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
|
|
||||||
|
|
||||||
fun dx(
|
fun dx(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||||
|
|
||||||
fun dxy(
|
fun dxy(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
yBinding: Pair<Symbol, Double>,
|
yBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
||||||
body(bind(xBinding.first), bind(yBinding.first))
|
body(bind(xBinding.first), bind(yBinding.first))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +34,7 @@ class SimpleAutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPlusX2() {
|
fun testPlusX2() {
|
||||||
val y = d(x to 3.0) {
|
val y = RealField.simpleAutoDiff(x to 3.0) {
|
||||||
// diff w.r.t this x at 3
|
// diff w.r.t this x at 3
|
||||||
val x = bind(x)
|
val x = bind(x)
|
||||||
x + x
|
x + x
|
||||||
@ -47,10 +43,21 @@ class SimpleAutoDiffTest {
|
|||||||
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlusX2Expr() {
|
||||||
|
val expr = diff{
|
||||||
|
val x = bind(x)
|
||||||
|
x + x
|
||||||
|
}
|
||||||
|
assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6
|
||||||
|
assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPlus() {
|
fun testPlus() {
|
||||||
// two variables
|
// two variables
|
||||||
val z = d(x to 2.0, y to 3.0) {
|
val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) {
|
||||||
val x = bind(x)
|
val x = bind(x)
|
||||||
val y = bind(y)
|
val y = bind(y)
|
||||||
x + y
|
x + y
|
||||||
@ -63,7 +70,7 @@ class SimpleAutoDiffTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testMinus() {
|
fun testMinus() {
|
||||||
// two variables
|
// two variables
|
||||||
val z = d(x to 7.0, y to 3.0) {
|
val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) {
|
||||||
val x = bind(x)
|
val x = bind(x)
|
||||||
val y = bind(y)
|
val y = bind(y)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user
Removed
Removed