Change DifferentiableExpression API to use ordered symbol list instead of orders map.

This commit is contained in:
Alexander Nozik 2020-10-29 19:35:08 +03:00
parent 4b7bd3d174
commit fbe1ab94a4
4 changed files with 48 additions and 38 deletions

View File

@ -12,46 +12,51 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
*/ */
public class DerivativeStructureField( public class DerivativeStructureField(
public val order: Int, public val order: Int,
private val bindings: Map<Symbol, Double> bindings: Map<Symbol, Double>,
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> { ) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order) } public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, order, 1.0) } public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0, 1.0) }
/** /**
* A class that implements both [DerivativeStructure] and a [Symbol] * A class that implements both [DerivativeStructure] and a [Symbol]
*/ */
public inner class DerivativeStructureSymbol(symbol: Symbol, value: Double) : public inner class DerivativeStructureSymbol(
DerivativeStructure(bindings.size, order, bindings.keys.indexOf(symbol), value), Symbol { size: Int,
index: Int,
symbol: Symbol,
value: Double,
) : DerivativeStructure(size, order, index, value), Symbol {
override val identity: String = symbol.identity override val identity: String = symbol.identity
override fun toString(): String = identity override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode() override fun hashCode(): Int = identity.hashCode()
} }
public val numberOfVariables: Int = bindings.size
/** /**
* Identity-based symbol bindings map * Identity-based symbol bindings map
*/ */
private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.associate { (key, value) -> private val variables: Map<String, DerivativeStructureSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
key.identity to DerivativeStructureSymbol(key, value) key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
} }.toMap()
override fun const(value: Double): DerivativeStructure = DerivativeStructure(bindings.size, order, value) override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, 0, value)
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
//public fun Number.const(): DerivativeStructure = const(toDouble()) override fun symbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value))
public fun DerivativeStructure.derivative(parameter: Symbol, order: Int = 1): Double { public fun DerivativeStructure.derivative(symbols: List<Symbol>): Double {
return derivative(mapOf(parameter to order)) require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
val ordersCount = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
} }
public fun DerivativeStructure.derivative(orders: Map<Symbol, Int>): Double { public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList())
return getPartialDerivative(*bindings.keys.map { orders[it] ?: 0 }.toIntArray())
}
public fun DerivativeStructure.derivative(vararg orders: Pair<Symbol, Int>): Double = derivative(mapOf(*orders))
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) { public override fun multiply(a: DerivativeStructure, k: Number): DerivativeStructure = when (k) {
@ -97,6 +102,7 @@ public class DerivativeStructureField(
} }
} }
/** /**
* A constructs that creates a derivative structure with required order on-demand * A constructs that creates a derivative structure with required order on-demand
*/ */
@ -109,7 +115,7 @@ public class DerivativeStructureExpression(
/** /**
* Get the derivative expression with given orders * Get the derivative expression with given orders
*/ */
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<Double> = Expression { arguments -> public override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
with(DerivativeStructureField(orders.values.maxOrNull() ?: 0, arguments)) { function().derivative(orders) } with(DerivativeStructureField(symbols.size, arguments)) { function().derivative(symbols) }
} }
} }

View File

@ -5,14 +5,15 @@ import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFails
internal inline fun <R> diff( internal inline fun diff(
order: Int, order: Int,
vararg parameters: Pair<Symbol, Double>, vararg parameters: Pair<Symbol, Double>,
block: DerivativeStructureField.() -> R, block: DerivativeStructureField.() -> Unit,
): R { ): Unit {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return DerivativeStructureField(order, mapOf(*parameters)).run(block) DerivativeStructureField(order, mapOf(*parameters)).run(block)
} }
internal class AutoDiffTest { internal class AutoDiffTest {
@ -21,13 +22,16 @@ internal class AutoDiffTest {
@Test @Test
fun derivativeStructureFieldTest() { fun derivativeStructureFieldTest() {
val res: Double = diff(3, x to 1.0, y to 1.0) { diff(2, x to 1.0, y to 1.0) {
val x = bind(x)//by binding() val x = bind(x)//by binding()
val y = symbol("y") val y = symbol("y")
val z = x * (-sin(x * y) + y) val z = x * (-sin(x * y) + y) + 2.0
z.derivative(x) println(z.derivative(x))
println(z.derivative(y,x))
assertEquals(z.derivative(x, y), z.derivative(y, x))
//check that improper order cause failure
assertFails { z.derivative(x,x,y) }
} }
println(res)
} }
@Test @Test
@ -40,5 +44,7 @@ internal class AutoDiffTest {
assertEquals(10.0, f(x to 1.0, y to 2.0)) assertEquals(10.0, f(x to 1.0, y to 2.0))
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0)) assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
} }
} }

View File

@ -4,19 +4,17 @@ package kscience.kmath.expressions
* An expression that provides derivatives * An expression that provides derivatives
*/ */
public interface DifferentiableExpression<T> : Expression<T> { public interface DifferentiableExpression<T> : Expression<T> {
public fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
} }
public fun <T> DifferentiableExpression<T>.derivative(orders: Map<Symbol, Int>): Expression<T> = public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
derivativeOrNull(orders) ?: error("Derivative with orders $orders not provided") derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
public fun <T> DifferentiableExpression<T>.derivative(vararg orders: Pair<Symbol, Int>): Expression<T> = public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
derivative(mapOf(*orders)) derivative(symbols.toList())
public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expression<T> = derivative(symbol to 1)
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
derivative(StringSymbol(name) to 1) derivative(StringSymbol(name))
/** /**
* A [DifferentiableExpression] that defines only first derivatives * A [DifferentiableExpression] that defines only first derivatives
@ -25,8 +23,8 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>? public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
public override fun derivativeOrNull(orders: Map<Symbol, Int>): Expression<T>? { public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null val dSymbol = symbols.firstOrNull() ?: return null
return derivativeOrNull(dSymbol) return derivativeOrNull(dSymbol)
} }
} }

View File

@ -87,7 +87,7 @@ public fun <T, E> ExpressionAlgebra<T, E>.bind(symbol: Symbol): E =
/** /**
* A delegate to create a symbol with a string identity in this scope * A delegate to create a symbol with a string identity in this scope
*/ */
public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty { thisRef, property -> public val symbol: ReadOnlyProperty<Any?, StringSymbol> = ReadOnlyProperty { _, property ->
StringSymbol(property.name) StringSymbol(property.name)
} }