Copy DerivativeStructure to multiplatform
This commit is contained in:
parent
ef747f642f
commit
7b1bdc21a4
@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
|
||||
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
||||
break any moment. You can still use it, but be sure to fix the specific version.
|
||||
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
||||
with `@UnstableKmathAPI` or other stability warning annotations.
|
||||
with `@UnstableKMathAPI` or other stability warning annotations.
|
||||
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
||||
versions, but not in patch versions. API is protected
|
||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
|
||||
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
||||
break any moment. You can still use it, but be sure to fix the specific version.
|
||||
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
||||
with `@UnstableKmathAPI` or other stability warning annotations.
|
||||
with `@UnstableKMathAPI` or other stability warning annotations.
|
||||
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
||||
versions, but not in patch versions. API is protected
|
||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,186 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.expressions
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
|
||||
/**
|
||||
* Class representing both the value and the differentials of a function.
|
||||
*
|
||||
* This class is the workhorse of the differentiation package.
|
||||
*
|
||||
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper [Doubly Recursive
|
||||
* Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
|
||||
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
|
||||
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
|
||||
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
|
||||
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
|
||||
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
|
||||
*
|
||||
* Derived from
|
||||
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public open class DerivativeStructure<T, A> internal constructor(
|
||||
internal val derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||
internal val compiler: DSCompiler<T, A>,
|
||||
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||
/**
|
||||
* Combined array holding all values.
|
||||
*/
|
||||
internal var data: MutableBuffer<T> =
|
||||
derivativeAlgebra.bufferFactory(compiler.size) { derivativeAlgebra.algebra.zero }
|
||||
|
||||
/**
|
||||
* Build an instance with all values and derivatives set to 0.
|
||||
*
|
||||
* @param parameters number of free parameters.
|
||||
* @param order derivation order.
|
||||
*/
|
||||
public constructor (
|
||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||
parameters: Int,
|
||||
order: Int,
|
||||
) : this(
|
||||
derivativeAlgebra,
|
||||
getCompiler<T, A>(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, parameters, order),
|
||||
)
|
||||
|
||||
/**
|
||||
* Build an instance representing a constant value.
|
||||
*
|
||||
* @param parameters number of free parameters.
|
||||
* @param order derivation order.
|
||||
* @param value value of the constant.
|
||||
* @see DerivativeStructure
|
||||
*/
|
||||
public constructor (
|
||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||
parameters: Int,
|
||||
order: Int,
|
||||
value: T,
|
||||
) : this(
|
||||
derivativeAlgebra,
|
||||
parameters,
|
||||
order,
|
||||
) {
|
||||
data[0] = value
|
||||
}
|
||||
|
||||
/**
|
||||
* Build an instance representing a variable.
|
||||
*
|
||||
* Instances built using this constructor are considered to be the free variables with respect to which
|
||||
* differentials are computed. As such, their differential with respect to themselves is +1.
|
||||
*
|
||||
* @param parameters number of free parameters.
|
||||
* @param order derivation order.
|
||||
* @param index index of the variable (from 0 to `parameters - 1`).
|
||||
* @param value value of the variable.
|
||||
*/
|
||||
public constructor (
|
||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||
parameters: Int,
|
||||
order: Int,
|
||||
index: Int,
|
||||
value: T,
|
||||
) : this(derivativeAlgebra, parameters, order, value) {
|
||||
require(index < parameters) { "number is too large: $index >= $parameters" }
|
||||
|
||||
if (order > 0) {
|
||||
// the derivative of the variable with respect to itself is 1.
|
||||
data[getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, index, order).size] =
|
||||
derivativeAlgebra.algebra.one
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build an instance from all its derivatives.
|
||||
*
|
||||
* @param parameters number of free parameters.
|
||||
* @param order derivation order.
|
||||
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
|
||||
*/
|
||||
public constructor (
|
||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||
parameters: Int,
|
||||
order: Int,
|
||||
vararg derivatives: T,
|
||||
) : this(
|
||||
derivativeAlgebra,
|
||||
parameters,
|
||||
order,
|
||||
) {
|
||||
require(derivatives.size == data.size) { "dimension mismatch: ${derivatives.size} and ${data.size}" }
|
||||
data = derivativeAlgebra.bufferFactory(data.size) { derivatives[it] }
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy constructor.
|
||||
*
|
||||
* @param ds instance to copy.
|
||||
*/
|
||||
internal constructor(ds: DerivativeStructure<T, A>) : this(ds.derivativeAlgebra, ds.compiler) {
|
||||
this.data = ds.data.copy()
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of free parameters.
|
||||
*/
|
||||
public val freeParameters: Int
|
||||
get() = compiler.freeParameters
|
||||
|
||||
/**
|
||||
* The derivation order.
|
||||
*/
|
||||
public val order: Int
|
||||
get() = compiler.order
|
||||
|
||||
/**
|
||||
* The value part of the derivative structure.
|
||||
*
|
||||
* @see getPartialDerivative
|
||||
*/
|
||||
public val value: T
|
||||
get() = data[0]
|
||||
|
||||
/**
|
||||
* Get a partial derivative.
|
||||
*
|
||||
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
|
||||
* @return partial derivative.
|
||||
* @see value
|
||||
*/
|
||||
public fun getPartialDerivative(vararg orders: Int): T = data[compiler.getPartialDerivativeIndex(*orders)]
|
||||
|
||||
|
||||
/**
|
||||
* Test for the equality of two derivative structures.
|
||||
*
|
||||
* Derivative structures are considered equal if they have the same number
|
||||
* of free parameters, the same derivation order, and the same derivatives.
|
||||
*
|
||||
* @return `true` if two derivative structures are equal.
|
||||
*/
|
||||
public override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
|
||||
if (other is DerivativeStructure<*, *>) {
|
||||
return ((freeParameters == other.freeParameters) &&
|
||||
(order == other.order) &&
|
||||
data == other.data)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int =
|
||||
227 + 229 * freeParameters + 233 * order + 239 * data.hashCode()
|
||||
}
|
@ -0,0 +1,332 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.expressions
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.MutableBufferFactory
|
||||
import space.kscience.kmath.structures.indices
|
||||
|
||||
/**
|
||||
* A class implementing both [DerivativeStructure] and [Symbol].
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class DerivativeStructureSymbol<T, A>(
|
||||
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||
size: Int,
|
||||
order: Int,
|
||||
index: Int,
|
||||
symbol: Symbol,
|
||||
value: T,
|
||||
) : Symbol by symbol, DerivativeStructure<T, A>(
|
||||
derivativeAlgebra,
|
||||
size,
|
||||
order,
|
||||
index,
|
||||
value
|
||||
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||
override fun toString(): String = symbol.toString()
|
||||
override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol
|
||||
override fun hashCode(): Int = symbol.hashCode()
|
||||
}
|
||||
|
||||
/**
|
||||
* A ring over [DerivativeStructure].
|
||||
*
|
||||
* @property order The derivation order.
|
||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public open class DerivativeStructureRing<T, A>(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : Ring<DerivativeStructure<T, A>>, ScaleOperations<DerivativeStructure<T, A>>,
|
||||
NumericAlgebra<DerivativeStructure<T, A>>,
|
||||
ExpressionAlgebra<T, DerivativeStructure<T, A>>,
|
||||
NumbersAddOps<DerivativeStructure<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||
public val numberOfVariables: Int = bindings.size
|
||||
|
||||
override val zero: DerivativeStructure<T, A> by lazy {
|
||||
DerivativeStructure(
|
||||
this,
|
||||
numberOfVariables,
|
||||
order,
|
||||
)
|
||||
}
|
||||
|
||||
override val one: DerivativeStructure<T, A> by lazy {
|
||||
DerivativeStructure(
|
||||
this,
|
||||
numberOfVariables,
|
||||
order,
|
||||
algebra.one,
|
||||
)
|
||||
}
|
||||
|
||||
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
||||
|
||||
private val variables: Map<Symbol, DerivativeStructureSymbol<T, A>> =
|
||||
bindings.entries.mapIndexed { index, (key, value) ->
|
||||
key to DerivativeStructureSymbol(
|
||||
this,
|
||||
numberOfVariables,
|
||||
order,
|
||||
index,
|
||||
key,
|
||||
value,
|
||||
)
|
||||
}.toMap()
|
||||
|
||||
public override fun const(value: T): DerivativeStructure<T, A> =
|
||||
DerivativeStructure(this, numberOfVariables, order, value)
|
||||
|
||||
override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol<T, A>? = variables[StringSymbol(value)]
|
||||
|
||||
override fun bindSymbol(value: String): DerivativeStructureSymbol<T, A> =
|
||||
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
|
||||
|
||||
public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol<T, A>? = variables[symbol.identity]
|
||||
|
||||
public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol<T, A> =
|
||||
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
|
||||
|
||||
public fun DerivativeStructure<T, A>.derivative(symbols: List<Symbol>): T {
|
||||
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
|
||||
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
public fun DerivativeStructure<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
|
||||
|
||||
override fun DerivativeStructure<T, A>.unaryMinus(): DerivativeStructure<T, A> {
|
||||
val ds = DerivativeStructure(this@DerivativeStructureRing, compiler)
|
||||
for (i in ds.data.indices) {
|
||||
ds.data[i] = algebra { -data[i] }
|
||||
}
|
||||
return ds
|
||||
}
|
||||
|
||||
override fun add(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
left.compiler.checkCompatibility(right.compiler)
|
||||
val ds = DerivativeStructure(left)
|
||||
left.compiler.add(left.data, 0, right.data, 0, ds.data, 0)
|
||||
return ds
|
||||
}
|
||||
|
||||
override fun scale(a: DerivativeStructure<T, A>, value: Double): DerivativeStructure<T, A> {
|
||||
val ds = DerivativeStructure(a)
|
||||
for (i in ds.data.indices) {
|
||||
ds.data[i] = algebra { ds.data[i].times(value) }
|
||||
}
|
||||
return ds
|
||||
}
|
||||
|
||||
override fun multiply(
|
||||
left: DerivativeStructure<T, A>,
|
||||
right: DerivativeStructure<T, A>
|
||||
): DerivativeStructure<T, A> {
|
||||
left.compiler.checkCompatibility(right.compiler)
|
||||
val result = DerivativeStructure(this, left.compiler)
|
||||
left.compiler.multiply(left.data, 0, right.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun DerivativeStructure<T, A>.minus(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
compiler.checkCompatibility(arg.compiler)
|
||||
val ds = DerivativeStructure(this)
|
||||
compiler.subtract(data, 0, arg.data, 0, ds.data, 0)
|
||||
return ds
|
||||
}
|
||||
|
||||
override operator fun DerivativeStructure<T, A>.plus(other: Number): DerivativeStructure<T, A> {
|
||||
val ds = DerivativeStructure(this)
|
||||
ds.data[0] = algebra { ds.data[0] + number(other) }
|
||||
return ds
|
||||
}
|
||||
|
||||
override operator fun DerivativeStructure<T, A>.minus(other: Number): DerivativeStructure<T, A> =
|
||||
this + -other.toDouble()
|
||||
|
||||
override operator fun Number.plus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other + this
|
||||
override operator fun Number.minus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other - this
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public class DerivativeStructureRingExpression<T, A>(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val function: DerivativeStructureRing<T, A>.() -> DerivativeStructure<T, A>,
|
||||
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||
DerivativeStructureRing(algebra, bufferFactory, 0, arguments).function().value
|
||||
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||
with(
|
||||
DerivativeStructureRing(
|
||||
algebra,
|
||||
bufferFactory,
|
||||
symbols.size,
|
||||
arguments
|
||||
)
|
||||
) { function().derivative(symbols) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A field over commons-math [DerivativeStructure].
|
||||
*
|
||||
* @property order The derivation order.
|
||||
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class DerivativeStructureField<T, A : ExtendedField<T>>(
|
||||
algebra: A,
|
||||
bufferFactory: MutableBufferFactory<T>,
|
||||
order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : DerivativeStructureRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DerivativeStructure<T, A>> {
|
||||
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
||||
|
||||
override fun divide(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
left.compiler.checkCompatibility(right.compiler)
|
||||
val result = DerivativeStructure(this, left.compiler)
|
||||
left.compiler.divide(left.data, 0, right.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun sin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.sin(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun cos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.cos(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun tan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.tan(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun asin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.asin(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun acos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.acos(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun atan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.atan(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun sinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.sinh(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun cosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.cosh(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun tanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.tanh(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun asinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.asinh(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun acosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.acosh(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun atanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.atanh(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun power(arg: DerivativeStructure<T, A>, pow: Number): DerivativeStructure<T, A> = when (pow) {
|
||||
is Int -> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.pow(arg.data, 0, pow, result.data, 0)
|
||||
result
|
||||
}
|
||||
else -> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.pow(arg.data, 0, pow.toDouble(), result.data, 0)
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
override fun sqrt(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.sqrt(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
public fun power(arg: DerivativeStructure<T, A>, pow: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
arg.compiler.checkCompatibility(pow.compiler)
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.pow(arg.data, 0, pow.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun exp(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.exp(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
|
||||
override fun ln(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||
val result = DerivativeStructure(this, arg.compiler)
|
||||
arg.compiler.ln(arg.data, 0, result.data, 0)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public class DerivativeStructureFieldExpression<T, A : ExtendedField<T>>(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val function: DerivativeStructureField<T, A>.() -> DerivativeStructure<T, A>,
|
||||
) : DifferentiableExpression<T> {
|
||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||
DerivativeStructureField(algebra, bufferFactory, 0, arguments).function().value
|
||||
|
||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||
with(
|
||||
DerivativeStructureField(
|
||||
algebra,
|
||||
bufferFactory,
|
||||
symbols.size,
|
||||
arguments,
|
||||
)
|
||||
) { function().derivative(symbols) }
|
||||
}
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
@file:OptIn(UnstableKMathAPI::class)
|
||||
|
||||
package space.kscience.kmath.expressions
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFails
|
||||
|
||||
internal inline fun diff(
|
||||
order: Int,
|
||||
vararg parameters: Pair<Symbol, Double>,
|
||||
block: DerivativeStructureField<Double, DoubleField>.() -> Unit,
|
||||
) {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
DerivativeStructureField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
|
||||
}
|
||||
|
||||
internal class AutoDiffTest {
|
||||
private val x by symbol
|
||||
private val y by symbol
|
||||
|
||||
@Test
|
||||
fun derivativeStructureFieldTest() {
|
||||
diff(2, x to 1.0, y to 1.0) {
|
||||
val x = bindSymbol(x)//by binding()
|
||||
val y = bindSymbol("y")
|
||||
val z = x * (-sin(x * y) + y) + 2.0
|
||||
println(z.derivative(x))
|
||||
println(z.derivative(y, x))
|
||||
assertEquals(z.derivative(x, y), z.derivative(y, x))
|
||||
// check improper order cause failure
|
||||
assertFails { z.derivative(x, x, y) }
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun autoDifTest() {
|
||||
val f = DerivativeStructureFieldExpression(DoubleField, ::DoubleBuffer) {
|
||||
val x by binding
|
||||
val y by binding
|
||||
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||
}
|
||||
|
||||
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
|
||||
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
|
||||
}
|
||||
}
|
@ -48,7 +48,8 @@ internal object InternalUtils {
|
||||
cache.copyInto(
|
||||
logFactorials,
|
||||
BEGIN_LOG_FACTORIALS,
|
||||
BEGIN_LOG_FACTORIALS, endCopy
|
||||
BEGIN_LOG_FACTORIALS,
|
||||
endCopy,
|
||||
)
|
||||
} else
|
||||
// All values to be computed
|
||||
|
Loading…
Reference in New Issue
Block a user