From f7614da230a9e1491263b71b955110382f58a9e4 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 22 Oct 2020 11:27:08 +0300 Subject: [PATCH] Refactoring --- .../DerivativeStructureExpression.kt | 3 +++ .../expressions/DifferentiableExpression.kt | 21 ++++++++++++++++ .../kscience/kmath/expressions/Expression.kt | 19 +++----------- .../FunctionalExpressionAlgebra.kt | 5 ++-- .../kmath/expressions/SimpleAutoDiff.kt | 14 ++++++----- .../kmath/expressions/SimpleAutoDiffTest.kt | 25 ++++++++++++------- 6 files changed, 55 insertions(+), 32 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 2ec69255e..a1ee91419 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -26,6 +26,9 @@ public class DerivativeStructureField( public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { override val identity: String = symbol.identity + override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt new file mode 100644 index 000000000..841531d01 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/DifferentiableExpression.kt @@ -0,0 +1,21 @@ +package kscience.kmath.expressions + +/** + * And object that could be differentiated + */ +public interface Differentiable { + public fun derivative(orders: Map): T +} + +public interface DifferentiableExpression : Differentiable>, Expression + +public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = + derivative(mapOf(*orders)) + +public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1) + +public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) + +//public interface DifferentiableExpressionBuilder>: ExpressionBuilder { +// public override fun expression(block: A.() -> E): DifferentiableExpression +//} diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt index bd83261f7..7da5a2529 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/Expression.kt @@ -56,21 +56,6 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) -/** - * And object that could be differentiated - */ -public interface Differentiable { - public fun derivative(orders: Map): T -} - -public interface DifferentiableExpression : Differentiable>, Expression - -public fun DifferentiableExpression.derivative(vararg orders: Pair): Expression = - derivative(mapOf(*orders)) - -public fun DifferentiableExpression.derivative(symbol: Symbol): Expression = derivative(symbol to 1) - -public fun DifferentiableExpression.derivative(name: String): Expression = derivative(StringSymbol(name) to 1) /** * A context for expression construction @@ -96,6 +81,10 @@ public interface ExpressionAlgebra : Algebra { public fun const(value: T): E } +//public interface ExpressionBuilder> { +// public fun expression(block: A.() -> E): Expression +//} + /** * Bind a given [Symbol] to this context variable and produce context-specific object. */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 9fd15238a..0630e8e4b 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -7,8 +7,9 @@ import kscience.kmath.operations.* * * @param algebra The algebra to provide for Expressions built. */ -public abstract class FunctionalExpressionAlgebra>(public val algebra: A) : - ExpressionAlgebra> { +public abstract class FunctionalExpressionAlgebra>( + public val algebra: A, +) : ExpressionAlgebra> { /** * Builds an Expression of constant expression which does not depend on arguments. */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index af7c8fbf2..e5ea33c81 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -15,11 +15,11 @@ import kotlin.contracts.contract public open class AutoDiffValue(public val value: T) /** - * Represents result of [withAutoDiff] call. + * Represents result of [simpleAutoDiff] call. * * @param T the non-nullable type of value. * @param value the value of result. - * @property withAutoDiff The mapping of differentiated variables to their derivatives. + * @property simpleAutoDiff The mapping of differentiated variables to their derivatives. * @property context The field over [T]. */ public class DerivationResult( @@ -62,7 +62,7 @@ public fun DerivationResult.grad(vararg variables: Symbol): Point> F.withAutoDiff( +public fun > F.simpleAutoDiff( bindings: Map, body: AutoDiffField.() -> AutoDiffValue, ): DerivationResult { @@ -71,10 +71,10 @@ public fun > F.withAutoDiff( return AutoDiffContext(this, bindings).derivate(body) } -public fun > F.withAutoDiff( +public fun > F.simpleAutoDiff( vararg bindings: Pair, body: AutoDiffField.() -> AutoDiffValue, -): DerivationResult = withAutoDiff(bindings.toMap(), body) +): DerivationResult = simpleAutoDiff(bindings.toMap(), body) /** * Represents field in context of which functions can be derived. @@ -136,7 +136,7 @@ private class AutoDiffContext>( override val one: AutoDiffValue get() = const(context.one) /** - * Differentiable variable with value and derivative of differentiation ([withAutoDiff]) result + * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result * with respect to this variable. * * @param T the non-nullable type of value. @@ -148,6 +148,8 @@ private class AutoDiffContext>( var d: T, ) : AutoDiffValue(value), Symbol{ override fun toString(): String = identity + override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + override fun hashCode(): Int = identity.hashCode() } private val bindings: Map> = bindings.entries.associate { diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt index ef4a6a06a..ca8ec1e17 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -10,21 +10,17 @@ import kotlin.test.assertEquals import kotlin.test.assertTrue class SimpleAutoDiffTest { - fun d( - vararg bindings: Pair, - body: AutoDiffField.() -> AutoDiffValue, - ): DerivationResult = RealField.withAutoDiff(bindings = bindings, body) fun dx( xBinding: Pair, body: AutoDiffField.(x: AutoDiffValue) -> AutoDiffValue, - ): DerivationResult = RealField.withAutoDiff(xBinding) { body(bind(xBinding.first)) } + ): DerivationResult = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) } fun dxy( xBinding: Pair, yBinding: Pair, body: AutoDiffField.(x: AutoDiffValue, y: AutoDiffValue) -> AutoDiffValue, - ): DerivationResult = RealField.withAutoDiff(xBinding, yBinding) { + ): DerivationResult = RealField.simpleAutoDiff(xBinding, yBinding) { body(bind(xBinding.first), bind(yBinding.first)) } @@ -38,7 +34,7 @@ class SimpleAutoDiffTest { @Test fun testPlusX2() { - val y = d(x to 3.0) { + val y = RealField.simpleAutoDiff(x to 3.0) { // diff w.r.t this x at 3 val x = bind(x) x + x @@ -47,10 +43,21 @@ class SimpleAutoDiffTest { assertEquals(2.0, y.derivative(x)) // dy/dx = 2 } + @Test + fun testPlusX2Expr() { + val expr = diff{ + val x = bind(x) + x + x + } + assertEquals(6.0, expr(x to 3.0)) // y = x + x = 6 + assertEquals(2.0, expr.derivative(x)(x to 3.0)) // dy/dx = 2 + } + + @Test fun testPlus() { // two variables - val z = d(x to 2.0, y to 3.0) { + val z = RealField.simpleAutoDiff(x to 2.0, y to 3.0) { val x = bind(x) val y = bind(y) x + y @@ -63,7 +70,7 @@ class SimpleAutoDiffTest { @Test fun testMinus() { // two variables - val z = d(x to 7.0, y to 3.0) { + val z = RealField.simpleAutoDiff(x to 7.0, y to 3.0) { val x = bind(x) val y = bind(y)