forked from kscience/kmath
Grand derivative refactoring. Phase 2
This commit is contained in:
parent
5846f42141
commit
f5fe53a9f2
@ -0,0 +1,437 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.*
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
import space.kscience.kmath.structures.asBuffer
|
||||||
|
import kotlin.math.max
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class representing both the value and the differentials of a function.
|
||||||
|
*
|
||||||
|
* This class is the workhorse of the differentiation package.
|
||||||
|
*
|
||||||
|
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper
|
||||||
|
* [Doubly Recursive Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
|
||||||
|
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
|
||||||
|
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
|
||||||
|
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
|
||||||
|
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
|
||||||
|
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
|
||||||
|
*
|
||||||
|
* Derived from
|
||||||
|
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface DS<T, A : Ring<T>> {
|
||||||
|
public val derivativeAlgebra: DSAlgebra<T, A>
|
||||||
|
public val data: Buffer<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a partial derivative.
|
||||||
|
*
|
||||||
|
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
|
||||||
|
* @return partial derivative.
|
||||||
|
* @see value
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
|
||||||
|
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The value part of the derivative structure.
|
||||||
|
*
|
||||||
|
* @see getPartialDerivative
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public val <T, A : Ring<T>> DS<T, A>.value: T get() = data[0]
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : ExpressionAlgebra<T, DS<T, A>> {
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
||||||
|
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||||
|
buffer[0] = value
|
||||||
|
if (compiler.order > 0) {
|
||||||
|
// the derivative of the variable with respect to itself is 1.
|
||||||
|
|
||||||
|
val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply {
|
||||||
|
set(index, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
buffer[indexOfDerivative] = algebra.one
|
||||||
|
}
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
protected inner class DSImpl internal constructor(
|
||||||
|
override val data: Buffer<T>,
|
||||||
|
) : DS<T, A> {
|
||||||
|
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun DS(data: Buffer<T>): DS<T, A> = DSImpl(data)
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance representing a variable.
|
||||||
|
*
|
||||||
|
* Instances built using this constructor are considered to be the free variables with respect to which
|
||||||
|
* differentials are computed. As such, their differential with respect to themselves is +1.
|
||||||
|
*/
|
||||||
|
public fun variable(
|
||||||
|
index: Int,
|
||||||
|
value: T,
|
||||||
|
): DS<T, A> {
|
||||||
|
require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" }
|
||||||
|
return DS(bufferForVariable(index, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance from all its derivatives.
|
||||||
|
*
|
||||||
|
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
|
||||||
|
*/
|
||||||
|
public fun ofDerivatives(
|
||||||
|
vararg derivatives: T,
|
||||||
|
): DS<T, A> {
|
||||||
|
require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" }
|
||||||
|
val data = derivatives.asBuffer()
|
||||||
|
|
||||||
|
return DS(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class implementing both [DS] and [Symbol].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public inner class DSSymbol internal constructor(
|
||||||
|
index: Int,
|
||||||
|
symbol: Symbol,
|
||||||
|
value: T,
|
||||||
|
) : Symbol by symbol, DS<T, A> {
|
||||||
|
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||||
|
override val data: Buffer<T> = bufferForVariable(index, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public val numberOfVariables: Int = bindings.size
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the compiler for number of free parameters and order.
|
||||||
|
*
|
||||||
|
* @return cached rules set.
|
||||||
|
*/
|
||||||
|
@PublishedApi
|
||||||
|
internal val compiler: DSCompiler<T, A> by lazy {
|
||||||
|
// get the cached compilers
|
||||||
|
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
||||||
|
|
||||||
|
// we need to create more compilers
|
||||||
|
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
||||||
|
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
||||||
|
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
||||||
|
|
||||||
|
if (cache != null) {
|
||||||
|
// preserve the already created compilers
|
||||||
|
for (i in cache.indices) {
|
||||||
|
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the array in increasing diagonal order
|
||||||
|
for (diag in 0..numberOfVariables + order) {
|
||||||
|
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
||||||
|
val p: Int = diag - o
|
||||||
|
if (newCache[p][o] == null) {
|
||||||
|
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
||||||
|
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
||||||
|
|
||||||
|
newCache[p][o] = DSCompiler(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
p,
|
||||||
|
o,
|
||||||
|
valueCompiler,
|
||||||
|
derivativeCompiler,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return@lazy newCache[numberOfVariables][order]!!
|
||||||
|
}
|
||||||
|
|
||||||
|
private val variables: Map<Symbol, DSSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
|
||||||
|
key to DSSymbol(
|
||||||
|
index,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
}.toMap()
|
||||||
|
|
||||||
|
|
||||||
|
public override fun const(value: T): DS<T, A> {
|
||||||
|
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||||
|
buffer[0] = value
|
||||||
|
|
||||||
|
return DS(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun bindSymbolOrNull(value: String): DSSymbol? = variables[StringSymbol(value)]
|
||||||
|
|
||||||
|
override fun bindSymbol(value: String): DSSymbol =
|
||||||
|
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
|
||||||
|
|
||||||
|
public fun bindSymbolOrNull(symbol: Symbol): DSSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
|
public fun bindSymbol(symbol: Symbol): DSSymbol =
|
||||||
|
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
|
||||||
|
|
||||||
|
public fun DS<T, A>.derivative(symbols: List<Symbol>): T {
|
||||||
|
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||||
|
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DS<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A ring over [DS].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public open class DSRing<T, A>(
|
||||||
|
algebra: A,
|
||||||
|
bufferFactory: MutableBufferFactory<T>,
|
||||||
|
order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : DSAlgebra<T, A>(algebra, bufferFactory, order, bindings),
|
||||||
|
Ring<DS<T, A>>, ScaleOperations<DS<T, A>>,
|
||||||
|
NumericAlgebra<DS<T, A>>,
|
||||||
|
NumbersAddOps<DS<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||||
|
|
||||||
|
override fun bindSymbolOrNull(value: String): DSSymbol? =
|
||||||
|
super<DSAlgebra>.bindSymbolOrNull(value)
|
||||||
|
|
||||||
|
override fun DS<T, A>.unaryMinus(): DS<T, A> = mapData { -it }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a copy of given [Buffer] and modify it according to [block]
|
||||||
|
*/
|
||||||
|
protected inline fun DS<T, A>.transformDataBuffer(block: A.(MutableBuffer<T>) -> Unit): DS<T, A> {
|
||||||
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
val newData = bufferFactory(compiler.size) { data[it] }
|
||||||
|
algebra.block(newData)
|
||||||
|
return DS(newData)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun DS<T, A>.mapData(block: A.(T) -> T): DS<T, A> {
|
||||||
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
val newData: Buffer<T> = data.map(bufferFactory) {
|
||||||
|
algebra.block(it)
|
||||||
|
}
|
||||||
|
return DS(newData)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun DS<T, A>.mapDataIndexed(block: (Int, T) -> T): DS<T, A> {
|
||||||
|
require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
val newData: Buffer<T> = data.mapIndexed(bufferFactory, block)
|
||||||
|
return DS(newData)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val zero: DS<T, A> by lazy {
|
||||||
|
const(algebra.zero)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val one: DS<T, A> by lazy {
|
||||||
|
const(algebra.one)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||||
|
|
||||||
|
override fun add(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||||
|
require(right.derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" }
|
||||||
|
compiler.add(left.data, 0, right.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun scale(a: DS<T, A>, value: Double): DS<T, A> = a.mapData {
|
||||||
|
it.times(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(
|
||||||
|
left: DS<T, A>,
|
||||||
|
right: DS<T, A>,
|
||||||
|
): DS<T, A> = left.transformDataBuffer { result ->
|
||||||
|
compiler.multiply(left.data, 0, right.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// override fun DS<T, A>.minus(arg: DS): DS<T, A> = transformDataBuffer { result ->
|
||||||
|
// subtract(data, 0, arg.data, 0, result, 0)
|
||||||
|
// }
|
||||||
|
|
||||||
|
override operator fun DS<T, A>.plus(other: Number): DS<T, A> = transformDataBuffer {
|
||||||
|
it[0] += number(other)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// override operator fun DS<T, A>.minus(other: Number): DS<T, A> =
|
||||||
|
// this + (-other.toDouble())
|
||||||
|
|
||||||
|
override operator fun Number.plus(other: DS<T, A>): DS<T, A> = other + this
|
||||||
|
override operator fun Number.minus(other: DS<T, A>): DS<T, A> = other - this
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DerivativeStructureRingExpression<T, A>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val function: DSRing<T, A>.() -> DS<T, A>,
|
||||||
|
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||||
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
|
DSRing(algebra, bufferFactory, 0, arguments).function().value
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
|
with(
|
||||||
|
DSRing(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
symbols.size,
|
||||||
|
arguments
|
||||||
|
)
|
||||||
|
) { function().derivative(symbols) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field over commons-math [DerivativeStructure].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DSField<T, A : ExtendedField<T>>(
|
||||||
|
algebra: A,
|
||||||
|
bufferFactory: MutableBufferFactory<T>,
|
||||||
|
order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : DSRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DS<T, A>> {
|
||||||
|
override fun number(value: Number): DS<T, A> = const(algebra.number(value))
|
||||||
|
|
||||||
|
override fun divide(left: DS<T, A>, right: DS<T, A>): DS<T, A> = left.transformDataBuffer { result ->
|
||||||
|
compiler.divide(left.data, 0, right.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.sin(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.cos(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun tan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.tan(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun asin(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.asin(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun acos(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.acos(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun atan(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.atan(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.sinh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.cosh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun tanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.tanh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun asinh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.asinh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun acosh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.acosh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun atanh(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.atanh(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: DS<T, A>, pow: Number): DS<T, A> = when (pow) {
|
||||||
|
is Int -> arg.transformDataBuffer { result ->
|
||||||
|
compiler.pow(arg.data, 0, pow, result, 0)
|
||||||
|
}
|
||||||
|
else -> arg.transformDataBuffer { result ->
|
||||||
|
compiler.pow(arg.data, 0, pow.toDouble(), result, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sqrt(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.sqrt(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun power(arg: DS<T, A>, pow: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.pow(arg.data, 0, pow.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun exp(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.exp(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun ln(arg: DS<T, A>): DS<T, A> = arg.transformDataBuffer { result ->
|
||||||
|
compiler.ln(arg.data, 0, result, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DerivativeStructureFieldExpression<T, A : ExtendedField<T>>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val function: DSField<T, A>.() -> DS<T, A>,
|
||||||
|
) : DifferentiableExpression<T> {
|
||||||
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
|
DSField(algebra, bufferFactory, 0, arguments).function().value
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
|
DSField(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
symbols.size,
|
||||||
|
arguments,
|
||||||
|
).run { function().derivative(symbols) }
|
||||||
|
}
|
||||||
|
}
|
@ -1,157 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.expressions
|
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.Ring
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
import space.kscience.kmath.structures.asBuffer
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Class representing both the value and the differentials of a function.
|
|
||||||
*
|
|
||||||
* This class is the workhorse of the differentiation package.
|
|
||||||
*
|
|
||||||
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper [Doubly Recursive
|
|
||||||
* Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
|
|
||||||
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
|
|
||||||
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
|
|
||||||
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
|
|
||||||
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
|
|
||||||
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
|
|
||||||
*
|
|
||||||
* Derived from
|
|
||||||
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public open class DerivativeStructure<T, A : Ring<T>> @PublishedApi internal constructor(
|
|
||||||
private val derivativeAlgebra: DerivativeStructureAlgebra<T, A>,
|
|
||||||
@PublishedApi internal val data: Buffer<T>,
|
|
||||||
) {
|
|
||||||
|
|
||||||
public val compiler: DSCompiler<T, A> get() = derivativeAlgebra.compiler
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The number of free parameters.
|
|
||||||
*/
|
|
||||||
public val freeParameters: Int get() = compiler.freeParameters
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The derivation order.
|
|
||||||
*/
|
|
||||||
public val order: Int get() = compiler.order
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The value part of the derivative structure.
|
|
||||||
*
|
|
||||||
* @see getPartialDerivative
|
|
||||||
*/
|
|
||||||
public val value: T get() = data[0]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get a partial derivative.
|
|
||||||
*
|
|
||||||
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
|
|
||||||
* @return partial derivative.
|
|
||||||
* @see value
|
|
||||||
*/
|
|
||||||
public fun getPartialDerivative(vararg orders: Int): T = data[compiler.getPartialDerivativeIndex(*orders)]
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Test for the equality of two derivative structures.
|
|
||||||
*
|
|
||||||
* Derivative structures are considered equal if they have the same number
|
|
||||||
* of free parameters, the same derivation order, and the same derivatives.
|
|
||||||
*
|
|
||||||
* @return `true` if two derivative structures are equal.
|
|
||||||
*/
|
|
||||||
public override fun equals(other: Any?): Boolean {
|
|
||||||
if (this === other) return true
|
|
||||||
|
|
||||||
if (other is DerivativeStructure<*, *>) {
|
|
||||||
return ((freeParameters == other.freeParameters) &&
|
|
||||||
(order == other.order) &&
|
|
||||||
data == other.data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
public override fun hashCode(): Int =
|
|
||||||
227 + 229 * freeParameters + 233 * order + 239 * data.hashCode()
|
|
||||||
|
|
||||||
public companion object {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build an instance representing a variable.
|
|
||||||
*
|
|
||||||
* Instances built using this constructor are considered to be the free variables with respect to which
|
|
||||||
* differentials are computed. As such, their differential with respect to themselves is +1.
|
|
||||||
*/
|
|
||||||
public fun <T, A : Ring<T>> variable(
|
|
||||||
derivativeAlgebra: DerivativeStructureAlgebra<T, A>,
|
|
||||||
index: Int,
|
|
||||||
value: T,
|
|
||||||
): DerivativeStructure<T, A> {
|
|
||||||
val compiler = derivativeAlgebra.compiler
|
|
||||||
require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" }
|
|
||||||
return DerivativeStructure(derivativeAlgebra, derivativeAlgebra.bufferForVariable(index, value))
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build an instance from all its derivatives.
|
|
||||||
*
|
|
||||||
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
|
|
||||||
*/
|
|
||||||
public fun <T, A : Ring<T>> ofDerivatives(
|
|
||||||
derivativeAlgebra: DerivativeStructureAlgebra<T, A>,
|
|
||||||
vararg derivatives: T,
|
|
||||||
): DerivativeStructure<T, A> {
|
|
||||||
val compiler = derivativeAlgebra.compiler
|
|
||||||
require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" }
|
|
||||||
val data = derivatives.asBuffer()
|
|
||||||
|
|
||||||
return DerivativeStructure(
|
|
||||||
derivativeAlgebra,
|
|
||||||
data
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
private fun <T, A : Ring<T>> DerivativeStructureAlgebra<T, A>.bufferForVariable(index: Int, value: T): Buffer<T> {
|
|
||||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
|
||||||
buffer[0] = value
|
|
||||||
if (compiler.order > 0) {
|
|
||||||
// the derivative of the variable with respect to itself is 1.
|
|
||||||
|
|
||||||
val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply {
|
|
||||||
set(index, 1)
|
|
||||||
})
|
|
||||||
|
|
||||||
buffer[indexOfDerivative] = algebra.one
|
|
||||||
}
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A class implementing both [DerivativeStructure] and [Symbol].
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureSymbol<T, A : Ring<T>> internal constructor(
|
|
||||||
derivativeAlgebra: DerivativeStructureAlgebra<T, A>,
|
|
||||||
index: Int,
|
|
||||||
symbol: Symbol,
|
|
||||||
value: T,
|
|
||||||
) : Symbol by symbol, DerivativeStructure<T, A>(
|
|
||||||
derivativeAlgebra, derivativeAlgebra.bufferForVariable(index, value)
|
|
||||||
) {
|
|
||||||
override fun toString(): String = symbol.toString()
|
|
||||||
override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol
|
|
||||||
override fun hashCode(): Int = symbol.hashCode()
|
|
||||||
}
|
|
@ -1,349 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.expressions
|
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.*
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
|
||||||
import space.kscience.kmath.structures.MutableBufferFactory
|
|
||||||
import kotlin.math.max
|
|
||||||
import kotlin.math.min
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public abstract class DerivativeStructureAlgebra<T, A : Ring<T>>(
|
|
||||||
public val algebra: A,
|
|
||||||
public val bufferFactory: MutableBufferFactory<T>,
|
|
||||||
public val order: Int,
|
|
||||||
bindings: Map<Symbol, T>,
|
|
||||||
) : ExpressionAlgebra<T, DerivativeStructure<T, A>> {
|
|
||||||
|
|
||||||
public val numberOfVariables: Int = bindings.size
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the compiler for number of free parameters and order.
|
|
||||||
*
|
|
||||||
* @return cached rules set.
|
|
||||||
*/
|
|
||||||
@PublishedApi
|
|
||||||
internal val compiler: DSCompiler<T, A> by lazy {
|
|
||||||
// get the cached compilers
|
|
||||||
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
|
||||||
|
|
||||||
// we need to create more compilers
|
|
||||||
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
|
||||||
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
|
||||||
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
|
||||||
|
|
||||||
if (cache != null) {
|
|
||||||
// preserve the already created compilers
|
|
||||||
for (i in cache.indices) {
|
|
||||||
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create the array in increasing diagonal order
|
|
||||||
for (diag in 0..numberOfVariables + order) {
|
|
||||||
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
|
||||||
val p: Int = diag - o
|
|
||||||
if (newCache[p][o] == null) {
|
|
||||||
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
|
||||||
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
|
||||||
|
|
||||||
newCache[p][o] = DSCompiler(
|
|
||||||
algebra,
|
|
||||||
bufferFactory,
|
|
||||||
p,
|
|
||||||
o,
|
|
||||||
valueCompiler,
|
|
||||||
derivativeCompiler,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return@lazy newCache[numberOfVariables][order]!!
|
|
||||||
}
|
|
||||||
|
|
||||||
private val variables: Map<Symbol, DerivativeStructureSymbol<T, A>> =
|
|
||||||
bindings.entries.mapIndexed { index, (key, value) ->
|
|
||||||
key to DerivativeStructureSymbol(
|
|
||||||
this,
|
|
||||||
index,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
}.toMap()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public override fun const(value: T): DerivativeStructure<T, A> {
|
|
||||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
|
||||||
buffer[0] = value
|
|
||||||
|
|
||||||
return DerivativeStructure(
|
|
||||||
this,
|
|
||||||
buffer
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol<T, A>? = variables[StringSymbol(value)]
|
|
||||||
|
|
||||||
override fun bindSymbol(value: String): DerivativeStructureSymbol<T, A> =
|
|
||||||
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
|
|
||||||
|
|
||||||
public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol<T, A>? = variables[symbol.identity]
|
|
||||||
|
|
||||||
public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol<T, A> =
|
|
||||||
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
|
|
||||||
|
|
||||||
public fun DerivativeStructure<T, A>.derivative(symbols: List<Symbol>): T {
|
|
||||||
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
|
||||||
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
|
|
||||||
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DerivativeStructure<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A ring over [DerivativeStructure].
|
|
||||||
*
|
|
||||||
* @property order The derivation order.
|
|
||||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public open class DerivativeStructureRing<T, A>(
|
|
||||||
algebra: A,
|
|
||||||
bufferFactory: MutableBufferFactory<T>,
|
|
||||||
order: Int,
|
|
||||||
bindings: Map<Symbol, T>,
|
|
||||||
) : DerivativeStructureAlgebra<T, A>(algebra, bufferFactory, order, bindings),
|
|
||||||
Ring<DerivativeStructure<T, A>>, ScaleOperations<DerivativeStructure<T, A>>,
|
|
||||||
NumericAlgebra<DerivativeStructure<T, A>>,
|
|
||||||
NumbersAddOps<DerivativeStructure<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol<T, A>? =
|
|
||||||
super<DerivativeStructureAlgebra>.bindSymbolOrNull(value)
|
|
||||||
|
|
||||||
override fun DerivativeStructure<T, A>.unaryMinus(): DerivativeStructure<T, A> {
|
|
||||||
val newData = algebra { data.map(bufferFactory) { -it } }
|
|
||||||
return DerivativeStructure(this@DerivativeStructureRing, newData)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a copy of given [Buffer] and modify it according to [block]
|
|
||||||
*/
|
|
||||||
protected inline fun DerivativeStructure<T, A>.transformDataBuffer(block: DSCompiler<T, A>.(MutableBuffer<T>) -> Unit): DerivativeStructure<T, A> {
|
|
||||||
val newData = bufferFactory(compiler.size) { data[it] }
|
|
||||||
compiler.block(newData)
|
|
||||||
return DerivativeStructure(this@DerivativeStructureRing, newData)
|
|
||||||
}
|
|
||||||
|
|
||||||
protected fun DerivativeStructure<T, A>.mapData(block: (T) -> T): DerivativeStructure<T, A> {
|
|
||||||
val newData: Buffer<T> = data.map(bufferFactory, block)
|
|
||||||
return DerivativeStructure(this@DerivativeStructureRing, newData)
|
|
||||||
}
|
|
||||||
|
|
||||||
protected fun DerivativeStructure<T, A>.mapDataIndexed(block: (Int, T) -> T): DerivativeStructure<T, A> {
|
|
||||||
val newData: Buffer<T> = data.mapIndexed(bufferFactory, block)
|
|
||||||
return DerivativeStructure(this@DerivativeStructureRing, newData)
|
|
||||||
}
|
|
||||||
|
|
||||||
override val zero: DerivativeStructure<T, A> by lazy {
|
|
||||||
const(algebra.zero)
|
|
||||||
}
|
|
||||||
|
|
||||||
override val one: DerivativeStructure<T, A> by lazy {
|
|
||||||
const(algebra.one)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
|
||||||
|
|
||||||
override fun add(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
left.compiler.checkCompatibility(right.compiler)
|
|
||||||
return left.transformDataBuffer { result ->
|
|
||||||
add(left.data, 0, right.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun scale(a: DerivativeStructure<T, A>, value: Double): DerivativeStructure<T, A> = algebra {
|
|
||||||
a.mapData { it.times(value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(
|
|
||||||
left: DerivativeStructure<T, A>,
|
|
||||||
right: DerivativeStructure<T, A>,
|
|
||||||
): DerivativeStructure<T, A> {
|
|
||||||
left.compiler.checkCompatibility(right.compiler)
|
|
||||||
return left.transformDataBuffer { result ->
|
|
||||||
multiply(left.data, 0, right.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DerivativeStructure<T, A>.minus(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
compiler.checkCompatibility(arg.compiler)
|
|
||||||
return transformDataBuffer { result ->
|
|
||||||
subtract(data, 0, arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun DerivativeStructure<T, A>.plus(other: Number): DerivativeStructure<T, A> = algebra {
|
|
||||||
transformDataBuffer {
|
|
||||||
it[0] += number(other)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun DerivativeStructure<T, A>.minus(other: Number): DerivativeStructure<T, A> =
|
|
||||||
this + (-other.toDouble())
|
|
||||||
|
|
||||||
override operator fun Number.plus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other + this
|
|
||||||
override operator fun Number.minus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other - this
|
|
||||||
}
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureRingExpression<T, A>(
|
|
||||||
public val algebra: A,
|
|
||||||
public val bufferFactory: MutableBufferFactory<T>,
|
|
||||||
public val function: DerivativeStructureRing<T, A>.() -> DerivativeStructure<T, A>,
|
|
||||||
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
|
||||||
DerivativeStructureRing(algebra, bufferFactory, 0, arguments).function().value
|
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
|
||||||
with(
|
|
||||||
DerivativeStructureRing(
|
|
||||||
algebra,
|
|
||||||
bufferFactory,
|
|
||||||
symbols.size,
|
|
||||||
arguments
|
|
||||||
)
|
|
||||||
) { function().derivative(symbols) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A field over commons-math [DerivativeStructure].
|
|
||||||
*
|
|
||||||
* @property order The derivation order.
|
|
||||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureField<T, A : ExtendedField<T>>(
|
|
||||||
algebra: A,
|
|
||||||
bufferFactory: MutableBufferFactory<T>,
|
|
||||||
order: Int,
|
|
||||||
bindings: Map<Symbol, T>,
|
|
||||||
) : DerivativeStructureRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DerivativeStructure<T, A>> {
|
|
||||||
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
|
||||||
|
|
||||||
override fun divide(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
left.compiler.checkCompatibility(right.compiler)
|
|
||||||
return left.transformDataBuffer { result ->
|
|
||||||
left.compiler.divide(left.data, 0, right.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
sin(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun cos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
cos(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun tan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
tan(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun asin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
asin(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun acos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
acos(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun atan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
atan(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
sinh(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun cosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
cosh(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun tanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
tanh(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun asinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
asinh(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun acosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
acosh(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun atanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
atanh(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun power(arg: DerivativeStructure<T, A>, pow: Number): DerivativeStructure<T, A> = when (pow) {
|
|
||||||
is Int -> arg.transformDataBuffer { result ->
|
|
||||||
pow(arg.data, 0, pow, result, 0)
|
|
||||||
}
|
|
||||||
else -> arg.transformDataBuffer { result ->
|
|
||||||
pow(arg.data, 0, pow.toDouble(), result, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sqrt(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
sqrt(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun power(arg: DerivativeStructure<T, A>, pow: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
|
||||||
arg.compiler.checkCompatibility(pow.compiler)
|
|
||||||
return arg.transformDataBuffer { result ->
|
|
||||||
pow(arg.data, 0, pow.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun exp(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
exp(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun ln(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> = arg.transformDataBuffer { result ->
|
|
||||||
ln(arg.data, 0, result, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class DerivativeStructureFieldExpression<T, A : ExtendedField<T>>(
|
|
||||||
public val algebra: A,
|
|
||||||
public val bufferFactory: MutableBufferFactory<T>,
|
|
||||||
public val function: DerivativeStructureField<T, A>.() -> DerivativeStructure<T, A>,
|
|
||||||
) : DifferentiableExpression<T> {
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
|
||||||
DerivativeStructureField(algebra, bufferFactory, 0, arguments).function().value
|
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
|
||||||
with(
|
|
||||||
DerivativeStructureField(
|
|
||||||
algebra,
|
|
||||||
bufferFactory,
|
|
||||||
symbols.size,
|
|
||||||
arguments,
|
|
||||||
)
|
|
||||||
) { function().derivative(symbols) }
|
|
||||||
}
|
|
||||||
}
|
|
@ -19,10 +19,10 @@ import kotlin.test.assertFails
|
|||||||
internal inline fun diff(
|
internal inline fun diff(
|
||||||
order: Int,
|
order: Int,
|
||||||
vararg parameters: Pair<Symbol, Double>,
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
block: DerivativeStructureField<Double, DoubleField>.() -> Unit,
|
block: DSField<Double, DoubleField>.() -> Unit,
|
||||||
) {
|
) {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
DerivativeStructureField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
|
DSField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AutoDiffTest {
|
internal class AutoDiffTest {
|
||||||
|
Loading…
Reference in New Issue
Block a user