Feature/diff api #154

Merged
altavir merged 20 commits from feature/diff-api into dev 2020-10-28 13:25:24 +03:00
8 changed files with 214 additions and 175 deletions
Showing only changes of commit 9a147d033e - Show all commits

View File

@ -1,9 +1,6 @@
package kscience.kmath.commons.expressions package kscience.kmath.commons.expressions
import kscience.kmath.expressions.DifferentiableExpression import kscience.kmath.expressions.*
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.ExpressionAlgebra
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.ExtendedField import kscience.kmath.operations.ExtendedField
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
@ -92,6 +89,12 @@ public class DerivativeStructureField(
public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) public override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
return DerivativeStructureExpression(function)
}
}
} }
/** /**

View File

@ -71,12 +71,8 @@ public object CMFit {
public fun Expression<Double>.optimize( public fun Expression<Double>.optimize(
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> { ): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.expression(this)
return problem.optimize()
}
/** /**
* Optimize differentiable expression * Optimize differentiable expression
@ -84,12 +80,7 @@ public fun Expression<Double>.optimize(
public fun DifferentiableExpression<Double>.optimize( public fun DifferentiableExpression<Double>.optimize(
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> { ): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
val problem = CMOptimizationProblem(symbols.toList()).apply(configuration)
problem.diffExpression(this)
return problem.optimize()
}
public fun DifferentiableExpression<Double>.minimize( public fun DifferentiableExpression<Double>.minimize(
vararg startPoint: Pair<Symbol, Double>, vararg startPoint: Pair<Symbol, Double>,

View File

@ -33,7 +33,7 @@ public class CMOptimizationProblem(
public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList() public fun exportOptimizationData(): List<OptimizationData> = optimizationData.values.toList()
public fun initialGuess(map: Map<Symbol, Double>): Unit { public override fun initialGuess(map: Map<Symbol, Double>): Unit {
addOptimizationData(InitialGuess(map.toDoubleArray())) addOptimizationData(InitialGuess(map.toDoubleArray()))
} }
@ -94,10 +94,12 @@ public class CMOptimizationProblem(
return OptimizationResult(point.toMap(), value, setOf(this)) return OptimizationResult(point.toMap(), value, setOf(this))
} }
public companion object { public companion object : OptimizationProblemFactory<Double, CMOptimizationProblem> {
public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4 public const val DEFAULT_RELATIVE_TOLERANCE: Double = 1e-4
public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4 public const val DEFAULT_ABSOLUTE_TOLERANCE: Double = 1e-4
public const val DEFAULT_MAX_ITER: Int = 1000 public const val DEFAULT_MAX_ITER: Int = 1000
override fun build(symbols: List<Symbol>): CMOptimizationProblem = CMOptimizationProblem(symbols)
} }
} }

View File

@ -1,37 +0,0 @@
package kscience.kmath.commons.optimization
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.Symbol
public interface OptimizationFeature
//TODO move to prob/stat
public class OptimizationResult<T>(
public val point: Map<Symbol, T>,
public val value: T,
public val features: Set<OptimizationFeature> = emptySet(),
){
override fun toString(): String {
return "OptimizationResult(point=$point, value=$value)"
}
}
/**
* A configuration builder for optimization problem
*/
public interface OptimizationProblem<T : Any> {
/**
* Set an objective function expression
*/
public fun expression(expression: Expression<Double>): Unit
/**
*
*/
public fun diffExpression(expression: DifferentiableExpression<Double>): Unit
public fun update(result: OptimizationResult<T>)
public fun optimize(): OptimizationResult<T>
}

View File

@ -23,10 +23,9 @@ public fun <T> DifferentiableExpression<T>.derivative(symbol: Symbol): Expressio
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
derivative(StringSymbol(name) to 1) derivative(StringSymbol(name) to 1)
//public interface DifferentiableExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>>: ExpressionBuilder<T,E,A> { /**
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
// public override fun expression(block: A.() -> E): DifferentiableExpression<T> * A [DifferentiableExpression] that defines only first derivatives
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
//} */
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> { public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>? public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
@ -35,4 +34,11 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null val dSymbol = orders.entries.singleOrNull { it.value == 1 }?.key ?: return null
return derivativeOrNull(dSymbol) return derivativeOrNull(dSymbol)
} }
}
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
CommanderTvis commented 2020-10-28 12:56:06 +03:00 (Migrated from github.com)
Review

SAM

SAM
altavir commented 2020-10-28 13:04:11 +03:00 (Migrated from github.com)
Review

I see no reason for that. This interface is mostly implemented by companion objects. Also I plan to add additional methods later.

I see no reason for that. This interface is mostly implemented by companion objects. Also I plan to add additional methods later.
/**
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
*/
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
public fun process(function: A.() -> I): DifferentiableExpression<T>
CommanderTvis commented 2020-10-26 22:57:33 +03:00 (Migrated from github.com)
Review
  1. SAM interface.
  2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
1. SAM interface. 2. Does it have any uses, except the one in DifferentiableExpression? This trait can be extracted to separate into separate interface at any moment.
altavir commented 2020-10-28 09:26:02 +03:00 (Migrated from github.com)
Review

Removed

Removed
} }

View File

@ -47,7 +47,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
} }
/** /**
* Runs differentiation and establishes [AutoDiffField] context inside the block of code. * Runs differentiation and establishes [SimpleAutoDiffField] context inside the block of code.
* *
* The partial derivatives are placed in argument `d` variable * The partial derivatives are placed in argument `d` variable
* *
@ -59,36 +59,90 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
* assertEquals(9.0, x.d) // dy/dx * assertEquals(9.0, x.d) // dy/dx
* ``` * ```
* *
* @param body the action in [AutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to. * @param body the action in [SimpleAutoDiffField] context returning [AutoDiffVariable] to differentiate with respect to.
* @return the result of differentiation. * @return the result of differentiation.
*/ */
public fun <T : Any, F : Field<T>> F.simpleAutoDiff( public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>, body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> { ): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return AutoDiffContext(this, bindings).derivate(body) return SimpleAutoDiffField(this, bindings).derivate(body)
} }
public fun <T : Any, F : Field<T>> F.simpleAutoDiff( public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
vararg bindings: Pair<Symbol, T>, vararg bindings: Pair<Symbol, T>,
body: AutoDiffField<T, F>.() -> AutoDiffValue<T>, body: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body) ): DerivationResult<T> = simpleAutoDiff(bindings.toMap(), body)
/** /**
* Represents field in context of which functions can be derived. * Represents field in context of which functions can be derived.
*/ */
public abstract class AutoDiffField<T : Any, F : Field<T>> public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
: Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> { public val context: F,
bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
public abstract val context: F // this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
/**
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
private class AutoDiffVariableWithDerivative<T : Any>(
override val identity: String,
value: T,
var d: T,
) : AutoDiffValue<T>(value), Symbol {
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
}
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
private fun getDerivative(variable: AutoDiffValue<T>): T =
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
private fun setDerivative(variable: AutoDiffValue<T>, value: T) {
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
}
@Suppress("UNCHECKED_CAST")
private fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
/** /**
* A variable accessing inner state of derivatives. * A variable accessing inner state of derivatives.
* Use this value in inner builders to avoid creating additional derivative bindings. * Use this value in inner builders to avoid creating additional derivative bindings.
*/ */
public abstract var AutoDiffValue<T>.d: T public var AutoDiffValue<T>.d: T
get() = getDerivative(this)
set(value) = setDerivative(this, value)
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
/** /**
* Performs update of derivative after the rest of the formula in the back-pass. * Performs update of derivative after the rest of the formula in the back-pass.
@ -101,9 +155,22 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
* } * }
* ``` * ```
*/ */
public abstract fun <R> derive(value: R, block: F.(R) -> Unit): R @Suppress("UNCHECKED_CAST")
public fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
public inline fun const(block: F.() -> T): AutoDiffValue<T> = const(context.block())
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
}
// Overloads for Double constants // Overloads for Double constants
@ -119,68 +186,7 @@ public abstract class AutoDiffField<T : Any, F : Field<T>>
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> = override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
}
/**
* Automatic Differentiation context class.
*/
private class AutoDiffContext<T : Any, F : Field<T>>(
override val context: F,
bindings: Map<Symbol, T>,
) : AutoDiffField<T, F>() {
// this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
/**
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable.
*
* @param T the non-nullable type of value.
* @property value The value of this variable.
*/
private class AutoDiffVariableWithDeriv<T : Any>(
override val identity: String,
value: T,
var d: T,
) : AutoDiffValue<T>(value), Symbol{
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
}
private val bindings: Map<String, AutoDiffVariableWithDeriv<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDeriv(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): AutoDiffVariableWithDeriv<T>? = bindings[symbol.identity]
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
override var AutoDiffValue<T>.d: T
get() = (this as? AutoDiffVariableWithDeriv)?.d ?: derivatives[this] ?: context.zero
set(value) = if (this is AutoDiffVariableWithDeriv) d = value else derivatives[this] = value
@Suppress("UNCHECKED_CAST")
override fun <R> derive(value: R, block: F.(R) -> Unit): R {
// save block to stack for backward pass
if (sp >= stack.size) stack = stack.copyOf(stack.size * 2)
stack[sp++] = block
stack[sp++] = value
return value
}
@Suppress("UNCHECKED_CAST")
fun runBackwardPass() {
while (sp > 0) {
val value = stack[--sp]
val block = stack[--sp] as F.(Any?) -> Unit
context.block(value)
}
}
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
@ -206,13 +212,6 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
derive(const { k.toDouble() * a.value }) { z -> derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble() a.d += z.d * k.toDouble()
} }
inline fun derivate(function: AutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function()
result.d = context.one // computing derivative w.r.t result
runBackwardPass()
return DerivationResult(result.value, bindings.mapValues { it.value.d }, context)
}
} }
/** /**
@ -220,99 +219,178 @@ private class AutoDiffContext<T : Any, F : Field<T>>(
*/ */
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>( public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F, public val field: F,
public val function: AutoDiffField<T, F>.() -> AutoDiffValue<T>, public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T>() { ) : FirstDerivativeExpression<T>() {
public override operator fun invoke(arguments: Map<Symbol, T>): T { public override operator fun invoke(arguments: Map<Symbol, T>): T {
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
return AutoDiffContext(field, arguments).function().value return SimpleAutoDiffField(field, arguments).function().value
} }
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments -> override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = AutoDiffContext(field, arguments).derivate(function) val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
derivationResult.derivative(symbol) derivationResult.derivative(symbol)
} }
} }
/**
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
*/
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
return SimpleAutoDiffExpression(field, function)
}
}
}
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
// x ^ 2 // x ^ 2
public fun <T : Any, F : Field<T>> AutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : Field<T>> SimpleAutoDiffField<T, F>.sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value } derive(const { x.value * x.value }) { z -> x.d += z.d * 2 * x.value }
// x ^ 1/2 // x ^ 1/2
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sqrt(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value } derive(const { sqrt(x.value) }) { z -> x.d += z.d * 0.5 / z.value }
// x ^ y (const) // x ^ y (const)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>, x: AutoDiffValue<T>,
y: Double, y: Double,
): AutoDiffValue<T> = ): AutoDiffValue<T> =
derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) } derive(const { power(x.value, y) }) { z -> x.d += z.d * y * power(x.value, y - 1) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>, x: AutoDiffValue<T>,
y: Int, y: Int,
): AutoDiffValue<T> = pow(x, y.toDouble()) ): AutoDiffValue<T> = pow(x, y.toDouble())
// exp(x) // exp(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.exp(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { exp(x.value) }) { z -> x.d += z.d * z.value } derive(const { exp(x.value) }) { z -> x.d += z.d * z.value }
// ln(x) // ln(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.ln(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { ln(x.value) }) { z -> x.d += z.d / x.value } derive(const { ln(x.value) }) { z -> x.d += z.d / x.value }
// x ^ y (any) // x ^ y (any)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.pow( public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.pow(
x: AutoDiffValue<T>, x: AutoDiffValue<T>,
y: AutoDiffValue<T>, y: AutoDiffValue<T>,
): AutoDiffValue<T> = ): AutoDiffValue<T> =
exp(y * ln(x)) exp(y * ln(x))
// sin(x) // sin(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) } derive(const { sin(x.value) }) { z -> x.d += z.d * cos(x.value) }
// cos(x) // cos(x)
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) } derive(const { cos(x.value) }) { z -> x.d -= z.d * sin(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z -> derive(const { tan(x.value) }) { z ->
val c = cos(x.value) val c = cos(x.value)
x.d += z.d / (c * c) x.d += z.d / (c * c)
} }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asin(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) } derive(const { asin(x.value) }) { z -> x.d += z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acos(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) } derive(const { acos(x.value) }) { z -> x.d -= z.d / sqrt(one - x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atan(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) } derive(const { atan(x.value) }) { z -> x.d += z.d / (one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.sinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { sin(x.value) }) { z -> x.d += z.d * cosh(x.value) } derive(const { sinh(x.value) }) { z -> x.d += z.d * cosh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.cosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { cos(x.value) }) { z -> x.d += z.d * sinh(x.value) } derive(const { cosh(x.value) }) { z -> x.d += z.d * sinh(x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.tanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { tan(x.value) }) { z -> derive(const { tanh(x.value) }) { z ->
val c = cosh(x.value) val c = cosh(x.value)
x.d += z.d / (c * c) x.d += z.d / (c * c)
} }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.asinh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) } derive(const { asinh(x.value) }) { z -> x.d += z.d / sqrt(one + x.value * x.value) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.acosh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) } derive(const { acosh(x.value) }) { z -> x.d += z.d / (sqrt((x.value - one) * (x.value + one))) }
public fun <T : Any, F : ExtendedField<T>> AutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> = public fun <T : Any, F : ExtendedField<T>> SimpleAutoDiffField<T, F>.atanh(x: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) } derive(const { atanh(x.value) }) { z -> x.d += z.d / (one - x.value * x.value) }
public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
context: F,
bindings: Map<Symbol, T>,
) : ExtendedField<AutoDiffValue<T>>, SimpleAutoDiffField<T, F>(context, bindings) {
// x ^ 2
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sqr(x)
// x ^ 1/2
public override fun sqrt(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sqrt(arg)
// x ^ y (const)
public override fun power(arg: AutoDiffValue<T>, pow: Number): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).pow(arg, pow.toDouble())
// exp(x)
public override fun exp(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).exp(arg)
// ln(x)
public override fun ln(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).ln(arg)
// x ^ y (any)
public fun pow(
x: AutoDiffValue<T>,
y: AutoDiffValue<T>,
): AutoDiffValue<T> = exp(y * ln(x))
// sin(x)
public override fun sin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sin(arg)
// cos(x)
public override fun cos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).cos(arg)
public override fun tan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).tan(arg)
public override fun asin(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).asin(arg)
public override fun acos(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).acos(arg)
public override fun atan(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atan(arg)
public override fun sinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).sinh(arg)
public override fun cosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).cosh(arg)
public override fun tanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).tanh(arg)
public override fun asinh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).asinh(arg)
public override fun acosh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).acosh(arg)
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atanh(arg)
}

View File

@ -8,11 +8,6 @@ import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
//public interface ExpressionBuilder<T, E, A : ExpressionAlgebra<T, E>> {
// public fun expression(block: A.() -> E): Expression<T>
//}
/** /**
* Creates a functional expression with this [Space]. * Creates a functional expression with this [Space].
*/ */

View File

@ -2,6 +2,7 @@ package kscience.kmath.expressions
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
import kscience.kmath.structures.asBuffer import kscience.kmath.structures.asBuffer
import kotlin.math.E
import kotlin.math.PI import kotlin.math.PI
import kotlin.math.pow import kotlin.math.pow
import kotlin.math.sqrt import kotlin.math.sqrt
@ -13,18 +14,18 @@ class SimpleAutoDiffTest {
fun dx( fun dx(
xBinding: Pair<Symbol, Double>, xBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>, body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) } ): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding) { body(bind(xBinding.first)) }
fun dxy( fun dxy(
xBinding: Pair<Symbol, Double>, xBinding: Pair<Symbol, Double>,
yBinding: Pair<Symbol, Double>, yBinding: Pair<Symbol, Double>,
body: AutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>, body: SimpleAutoDiffField<Double, RealField>.(x: AutoDiffValue<Double>, y: AutoDiffValue<Double>) -> AutoDiffValue<Double>,
): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) { ): DerivationResult<Double> = RealField.simpleAutoDiff(xBinding, yBinding) {
body(bind(xBinding.first), bind(yBinding.first)) body(bind(xBinding.first), bind(yBinding.first))
} }
fun diff(block: AutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> { fun diff(block: SimpleAutoDiffField<Double, RealField>.() -> AutoDiffValue<Double>): SimpleAutoDiffExpression<Double, RealField> {
return SimpleAutoDiffExpression(RealField, block) return SimpleAutoDiffExpression(RealField, block)
} }
@ -45,7 +46,7 @@ class SimpleAutoDiffTest {
@Test @Test
fun testPlusX2Expr() { fun testPlusX2Expr() {
val expr = diff{ val expr = diff {
val x = bind(x) val x = bind(x)
x + x x + x
} }
@ -245,9 +246,9 @@ class SimpleAutoDiffTest {
@Test @Test
fun testTanh() { fun testTanh() {
val y = dx(x to PI / 6) { x -> tanh(x) } val y = dx(x to 1.0) { x -> tanh(x) }
assertApprox(1.0 / sqrt(3.0), y.value) // y = tanh(pi/6) assertApprox((E * E - 1) / (E * E + 1), y.value) // y = tanh(pi/6)
assertApprox(1.0 / kotlin.math.cosh(PI / 6.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2 assertApprox(1.0 / kotlin.math.cosh(1.0).pow(2), y.derivative(x)) // dy/dx = sech(pi/6)^2
} }
@Test @Test