Refactoring
This commit is contained in:
parent
04d3f4a99f
commit
f7614da230
@ -26,6 +26,9 @@ public class DerivativeStructureField(
|
||||
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||
override val identity: String = symbol.identity
|
||||
override fun toString(): String = identity
|
||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||
override fun hashCode(): Int = identity.hashCode()
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -0,0 +1,21 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
/**
|
||||
* And object that could be differentiated
|
||||
*/
|
||||
public interface Differentiable<T> {
|
||||
public fun derivative(orders: Map<Symbol, Int>): T
|
||||
}
|
||||
|
||||
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||
derivative(mapOf(*orders))
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||
|
||||
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
||||
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
||||
//}
|
@ -56,21 +56,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T =
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||
|
||||
/**
|
||||
* And object that could be differentiated
|
||||
*/
|
||||
public interface Differentiable<T> {
|
||||
public fun derivative(orders: Map<Symbol, Int>): T
|
||||
}
|
||||
|
||||
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||
derivative(mapOf(*orders))
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||
|
||||
/**
|
||||
* A context for expression construction
|
||||
@ -96,6 +81,10 @@ public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
||||
public fun const(value: T): E
|
||||
}
|
||||
|
||||
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
|
||||
// public fun expression(block: A.() -> E): Expression<T>
|
||||
//}
|
||||
|
||||
/**
|
||||
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
||||
*/
|
||||
|
@ -7,8 +7,9 @@ import kscience.kmath.operations.*
|
||||
*
|
||||
* @param algebra The algebra to provide for Expressions built.
|
||||
*/
|
||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
|
||||
ExpressionAlgebra<T, Expression<T>> {
|
||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||
public val algebra: A,
|
||||
) : ExpressionAlgebra<T, Expression<T>> {
|
||||
/**
|
||||
* Builds an Expression of constant expression which does not depend on arguments.
|
||||
*/
|
||||
|
@ -15,11 +15,11 @@ import kotlin.contracts.contract
|
||||
public open class AutoDiffValue<out T>(public val value: T)
|
||||
|
||||
/**
|
||||
* Represents result of [withAutoDiff] call.
|
||||
* Represents result of [simpleAutoDiff] call.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
* @param value the value of result.
|
||||
* @property withAutoDiff The mapping of differentiated variables to their derivatives.
|
||||
* @property simpleAutoDiff The mapping of differentiated variables to their derivatives.
|
||||
* @property context The field over [T].
|
||||
*/
|
||||
public class DerivationResult<T : Any>(
|
||||
@ -62,7 +62,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
||||
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||
* @return the result of differentiation.
|
||||
*/
|
||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
bindings: Map<Symbol, T>,
|
||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
): DerivationResult<T> {
|
||||
@ -71,10 +71,10 @@ public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||
return AutoDiffContext(this, bindings).derivate(body)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
vararg bindings: Pair<Symbol, T>,
|
||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
): DerivationResult<T> = withAutoDiff(bindings.toMap(), body)
|
||||
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
||||
|
||||
/**
|
||||
* Represents field in context of which functions can be derived.
|
||||
@ -136,7 +136,7 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
|
||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||
* with respect to this variable.
|
||||
*
|
||||
* @param T the non-nullable type of value.
|
||||
@ -148,6 +148,8 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
||||
var d: T,
|
||||
) : AutoDiffValue<T>(value), Symbol{
|
||||
override fun toString(): String = identity
|
||||
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||
override fun hashCode(): Int = identity.hashCode()
|
||||
}
|
||||
|
||||
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
||||
|
@ -10,21 +10,17 @@ import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class SimpleAutoDiffTest {
|
||||
fun d(
|
||||
vararg bindings: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>,
|
||||
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
|
||||
|
||||
fun dx(
|
||||
xBinding: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||
|
||||
fun dxy(
|
||||
xBinding: Pair<Symbol, Double>,
|
||||
yBinding: Pair<Symbol, Double>,
|
||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
|
||||
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
||||
body(bind(xBinding.first), bind(yBinding.first))
|
||||
}
|
||||
|
||||
@ -38,7 +34,7 @@ class SimpleAutoDiffTest {
|
||||
|
||||
@Test
|
||||
fun testPlusX2() {
|
||||
val y = d(x to 3.0) {
|
||||
val y = RealField.simpleAutoDiff(x to 3.0) {
|
||||
// diff w.r.t this x at 3
|
||||
val x = bind(x)
|
||||
x + x
|
||||
@ -47,10 +43,21 @@ class SimpleAutoDiffTest {
|
||||
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPlusX2Expr() {
|
||||
val expr = diff{
|
||||
val x = bind(x)
|
||||
x + x
|
||||
}
|
||||
assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6
|
||||
assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testPlus() {
|
||||
// two variables
|
||||
val z = d(x to 2.0, y to 3.0) {
|
||||
val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) {
|
||||
val x = bind(x)
|
||||
val y = bind(y)
|
||||
x + y
|
||||
@ -63,7 +70,7 @@ class SimpleAutoDiffTest {
|
||||
@Test
|
||||
fun testMinus() {
|
||||
// two variables
|
||||
val z = d(x to 7.0, y to 3.0) {
|
||||
val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) {
|
||||
val x = bind(x)
|
||||
val y = bind(y)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user