Refactoring

This commit is contained in:
Alexander Nozik 2020-10-22 11:27:08 +03:00
parent 04d3f4a99f
commit f7614da230
6 changed files with 55 additions and 32 deletions

View File

@ -26,6 +26,9 @@ public class DerivativeStructureField(
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) :
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol {
override val identity: String = symbol.identity override val identity: String = symbol.identity
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
} }
/** /**

View File

@ -0,0 +1,21 @@
package kscience.kmath.expressions
/**
* And object that could be differentiated
*/
public interface Differentiable<T> {
public fun derivative(orders: Map<Symbol, Int>): T
}
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
derivative(mapOf(*orders))
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> {
// public override fun expression(block: A.() -> E): DifferentiableExpression<T>
//}

View File

@ -56,21 +56,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T =
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
/**
* And object that could be differentiated
*/
public interface Differentiable<T> {
public fun derivative(orders: Map<Symbol, Int>): T
}
public interface DifferentiableExpression<T> : Differentiable<Expression<T>>, Expression<T>
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> =
derivative(mapOf(*orders))
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = derivative(StringSymbol(name) to 1)
/** /**
* A context for expression construction * A context for expression construction
@ -96,6 +81,10 @@ public interface ExpressionAlgebra<in T, E> : Algebra<E> {
public fun const(value: T): E public fun const(value: T): E
} }
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
// public fun expression(block: A.() -> E): Expression<T>
//}
/** /**
* Bind a given [Symbol] to this context variable and produce context-specific object. * Bind a given [Symbol] to this context variable and produce context-specific object.
*/ */

View File

@ -7,8 +7,9 @@ import kscience.kmath.operations.*
* *
* @param algebra The algebra to provide for Expressions built. * @param algebra The algebra to provide for Expressions built.
*/ */
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) : public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
ExpressionAlgebra<T, Expression<T>> { public val algebra: A,
) : ExpressionAlgebra<T, Expression<T>> {
/** /**
* Builds an Expression of constant expression which does not depend on arguments. * Builds an Expression of constant expression which does not depend on arguments.
*/ */

View File

@ -15,11 +15,11 @@ import kotlin.contracts.contract
public open class AutoDiffValue<out T>(public val value: T) public open class AutoDiffValue<out T>(public val value: T)
/** /**
* Represents result of [withAutoDiff] call. * Represents result of [simpleAutoDiff] call.
* *
* @param T the non-nullable type of value. * @param T the non-nullable type of value.
* @param value the value of result. * @param value the value of result.
* @property withAutoDiff The mapping of differentiated variables to their derivatives. * @property simpleAutoDiff The mapping of differentiated variables to their derivatives.
* @property context The field over [T]. * @property context The field over [T].
*/ */
public class DerivationResult<T : Any>( public class DerivationResult<T : Any>(
@ -62,7 +62,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to. * @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
* @return the result of differentiation. * @return the result of differentiation.
*/ */
public fun <T : Any, F : Field<T>> F.withAutoDiff( public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>, body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> { ): DerivationResult<T> {
@ -71,10 +71,10 @@ public fun <T : Any, F : Field<T>> F.withAutoDiff(
return AutoDiffContext(this, bindings).derivate(body) return AutoDiffContext(this, bindings).derivate(body)
} }
public fun <T : Any, F : Field<T>> F.withAutoDiff( public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
vararg bindings: Pair<Symbol, T>, vararg bindings: Pair<Symbol, T>,
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>, body: AutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> = withAutoDiff(bindings.toMap(), body) ): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
/** /**
* Represents field in context of which functions can be derived. * Represents field in context of which functions can be derived.
@ -136,7 +136,7 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
override val one: AutoDiffValue<T> get() = const(context.one) override val one: AutoDiffValue<T> get() = const(context.one)
/** /**
* Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable. * with respect to this variable.
* *
* @param T the non-nullable type of value. * @param T the non-nullable type of value.
@ -148,6 +148,8 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
var d: T, var d: T,
) : AutoDiffValue<T>(value), Symbol{ ) : AutoDiffValue<T>(value), Symbol{
override fun toString(): String = identity override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
} }
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate { private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {

View File

@ -10,21 +10,17 @@ import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class SimpleAutoDiffTest { class SimpleAutoDiffTest {
fun d(
vararg bindings: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(bindings = bindings, body)
fun dx( fun dx(
xBinding: Pair<Symbol, Double>, xBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>, body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) } ): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
fun dxy( fun dxy(
xBinding: Pair<Symbol, Double>, xBinding: Pair<Symbol, Double>,
yBinding: Pair<Symbol, Double>, yBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>, body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.withAutoDiff(xBinding, yBinding) { ): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
body(bind(xBinding.first), bind(yBinding.first)) body(bind(xBinding.first), bind(yBinding.first))
} }
@ -38,7 +34,7 @@ class SimpleAutoDiffTest {
@Test @Test
fun testPlusX2() { fun testPlusX2() {
val y = d(x to 3.0) { val y = RealField.simpleAutoDiff(x to 3.0) {
// diff w.r.t this x at 3 // diff w.r.t this x at 3
val x = bind(x) val x = bind(x)
x + x x + x
@ -47,10 +43,21 @@ class SimpleAutoDiffTest {
assertEquals(2.0, y.derivative(x)) // dy/dx = 2 assertEquals(2.0, y.derivative(x)) // dy/dx = 2
} }
@Test
fun testPlusX2Expr() {
val expr = diff{
val x = bind(x)
x + x
}
assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6
assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2
}
@Test @Test
fun testPlus() { fun testPlus() {
// two variables // two variables
val z = d(x to 2.0, y to 3.0) { val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) {
val x = bind(x) val x = bind(x)
val y = bind(y) val y = bind(y)
x + y x + y
@ -63,7 +70,7 @@ class SimpleAutoDiffTest {
@Test @Test
fun testMinus() { fun testMinus() {
// two variables // two variables
val z = d(x to 7.0, y to 3.0) { val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) {
val x = bind(x) val x = bind(x)
val y = bind(y) val y = bind(y)