forked from kscience/kmath
Refactoring
This commit is contained in:
parent
04d3f4a99f
commit
f7614da230
@ -26,6 +26,9 @@ public class DerivativeStructureField(
|
|||||||
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
|
||||||
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
|
||||||
override val identity: String = symbol.identity
|
override val identity: String = symbol.identity
|
||||||
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -0,0 +1,21 @@
|
|||||||
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
|
/**
|
||||||
|
* And object that could be differentiated
|
||||||
|
*/
|
||||||
|
public interface Differentiable<T> {
|
||||||
|
public fun derivative(orders: Map<Symbol, Int>): T
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
||||||
|
derivative(mapOf(*orders))
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
||||||
|
|
||||||
|
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
||||||
|
|
||||||
|
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
|
||||||
|
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
|
||||||
|
//}
|
@ -56,21 +56,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T =
|
|||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||||
|
|
||||||
/**
|
|
||||||
* And object that could be differentiated
|
|
||||||
*/
|
|
||||||
public interface Differentiable<T> {
|
|
||||||
public fun derivative(orders: Map<Symbol, Int>): T
|
|
||||||
}
|
|
||||||
|
|
||||||
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
|
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
|
|
||||||
derivative(mapOf(*orders))
|
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
|
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
@ -96,6 +81,10 @@ public interface ExpressionAlgebra<in T, E> : Algebra<E> {
|
|||||||
public fun const(value: T): E
|
public fun const(value: T): E
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
|
||||||
|
// public fun expression(block: A.() -> E): Expression<T>
|
||||||
|
//}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
* Bind a given [Symbol] to this context variable and produce context-specific object.
|
||||||
*/
|
*/
|
||||||
|
@ -7,8 +7,9 @@ import kscience.kmath.operations.*
|
|||||||
*
|
*
|
||||||
* @param algebra The algebra to provide for Expressions built.
|
* @param algebra The algebra to provide for Expressions built.
|
||||||
*/
|
*/
|
||||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
|
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||||
ExpressionAlgebra<T, Expression<T>> {
|
public val algebra: A,
|
||||||
|
) : ExpressionAlgebra<T, Expression<T>> {
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression which does not depend on arguments.
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
|
@ -15,11 +15,11 @@ import kotlin.contracts.contract
|
|||||||
public open class AutoDiffValue<out T>(public val value: T)
|
public open class AutoDiffValue<out T>(public val value: T)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents result of [withAutoDiff] call.
|
* Represents result of [simpleAutoDiff] call.
|
||||||
*
|
*
|
||||||
* @param T the non-nullable type of value.
|
* @param T the non-nullable type of value.
|
||||||
* @param value the value of result.
|
* @param value the value of result.
|
||||||
* @property withAutoDiff The mapping of differentiated variables to their derivatives.
|
* @property simpleAutoDiff The mapping of differentiated variables to their derivatives.
|
||||||
* @property context The field over [T].
|
* @property context The field over [T].
|
||||||
*/
|
*/
|
||||||
public class DerivationResult<T : Any>(
|
public class DerivationResult<T : Any>(
|
||||||
@ -62,7 +62,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
|||||||
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
|
||||||
* @return the result of differentiation.
|
* @return the result of differentiation.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> {
|
): DerivationResult<T> {
|
||||||
@ -71,10 +71,10 @@ public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
|||||||
return AutoDiffContext(this, bindings).derivate(body)
|
return AutoDiffContext(this, bindings).derivate(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : Field<T>> F.withAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
vararg bindings: Pair<Symbol, T>,
|
vararg bindings: Pair<Symbol, T>,
|
||||||
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
): DerivationResult<T> = withAutoDiff(bindings.toMap(), body)
|
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents field in context of which functions can be derived.
|
* Represents field in context of which functions can be derived.
|
||||||
@ -136,7 +136,7 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
*
|
*
|
||||||
* @param T the non-nullable type of value.
|
* @param T the non-nullable type of value.
|
||||||
@ -148,6 +148,8 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
|
|||||||
var d: T,
|
var d: T,
|
||||||
) : AutoDiffValue<T>(value), Symbol{
|
) : AutoDiffValue<T>(value), Symbol{
|
||||||
override fun toString(): String = identity
|
override fun toString(): String = identity
|
||||||
|
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
|
||||||
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
|
||||||
|
@ -10,21 +10,17 @@ import kotlin.test.assertEquals
|
|||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class SimpleAutoDiffTest {
|
class SimpleAutoDiffTest {
|
||||||
fun d(
|
|
||||||
vararg bindings: Pair<Symbol, Double>,
|
|
||||||
body: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>,
|
|
||||||
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
|
|
||||||
|
|
||||||
fun dx(
|
fun dx(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
|
||||||
|
|
||||||
fun dxy(
|
fun dxy(
|
||||||
xBinding: Pair<Symbol, Double>,
|
xBinding: Pair<Symbol, Double>,
|
||||||
yBinding: Pair<Symbol, Double>,
|
yBinding: Pair<Symbol, Double>,
|
||||||
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
|
||||||
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) {
|
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
|
||||||
body(bind(xBinding.first), bind(yBinding.first))
|
body(bind(xBinding.first), bind(yBinding.first))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,7 +34,7 @@ class SimpleAutoDiffTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPlusX2() {
|
fun testPlusX2() {
|
||||||
val y = d(x to 3.0) {
|
val y = RealField.simpleAutoDiff(x to 3.0) {
|
||||||
// diff w.r.t this x at 3
|
// diff w.r.t this x at 3
|
||||||
val x = bind(x)
|
val x = bind(x)
|
||||||
x + x
|
x + x
|
||||||
@ -47,10 +43,21 @@ class SimpleAutoDiffTest {
|
|||||||
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
assertEquals(2.0, y.derivative(x)) // dy/dx = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testPlusX2Expr() {
|
||||||
|
val expr = diff{
|
||||||
|
val x = bind(x)
|
||||||
|
x + x
|
||||||
|
}
|
||||||
|
assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6
|
||||||
|
assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testPlus() {
|
fun testPlus() {
|
||||||
// two variables
|
// two variables
|
||||||
val z = d(x to 2.0, y to 3.0) {
|
val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) {
|
||||||
val x = bind(x)
|
val x = bind(x)
|
||||||
val y = bind(y)
|
val y = bind(y)
|
||||||
x + y
|
x + y
|
||||||
@ -63,7 +70,7 @@ class SimpleAutoDiffTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testMinus() {
|
fun testMinus() {
|
||||||
// two variables
|
// two variables
|
||||||
val z = d(x to 7.0, y to 3.0) {
|
val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) {
|
||||||
val x = bind(x)
|
val x = bind(x)
|
||||||
val y = bind(y)
|
val y = bind(y)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user