Copy DerivativeStructure to multiplatform
This commit is contained in:
parent
ef747f642f
commit
7b1bdc21a4
@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
|
|||||||
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
||||||
break any moment. You can still use it, but be sure to fix the specific version.
|
break any moment. You can still use it, but be sure to fix the specific version.
|
||||||
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
||||||
with `@UnstableKmathAPI` or other stability warning annotations.
|
with `@UnstableKMathAPI` or other stability warning annotations.
|
||||||
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
||||||
versions, but not in patch versions. API is protected
|
versions, but not in patch versions. API is protected
|
||||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||||
|
2
docs/templates/README-TEMPLATE.md
vendored
2
docs/templates/README-TEMPLATE.md
vendored
@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
|
|||||||
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
|
||||||
break any moment. You can still use it, but be sure to fix the specific version.
|
break any moment. You can still use it, but be sure to fix the specific version.
|
||||||
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
|
||||||
with `@UnstableKmathAPI` or other stability warning annotations.
|
with `@UnstableKMathAPI` or other stability warning annotations.
|
||||||
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
|
||||||
versions, but not in patch versions. API is protected
|
versions, but not in patch versions. API is protected
|
||||||
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,186 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class representing both the value and the differentials of a function.
|
||||||
|
*
|
||||||
|
* This class is the workhorse of the differentiation package.
|
||||||
|
*
|
||||||
|
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper [Doubly Recursive
|
||||||
|
* Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
|
||||||
|
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
|
||||||
|
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
|
||||||
|
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
|
||||||
|
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
|
||||||
|
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
|
||||||
|
*
|
||||||
|
* Derived from
|
||||||
|
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public open class DerivativeStructure<T, A> internal constructor(
|
||||||
|
internal val derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||||
|
internal val compiler: DSCompiler<T, A>,
|
||||||
|
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||||
|
/**
|
||||||
|
* Combined array holding all values.
|
||||||
|
*/
|
||||||
|
internal var data: MutableBuffer<T> =
|
||||||
|
derivativeAlgebra.bufferFactory(compiler.size) { derivativeAlgebra.algebra.zero }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance with all values and derivatives set to 0.
|
||||||
|
*
|
||||||
|
* @param parameters number of free parameters.
|
||||||
|
* @param order derivation order.
|
||||||
|
*/
|
||||||
|
public constructor (
|
||||||
|
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||||
|
parameters: Int,
|
||||||
|
order: Int,
|
||||||
|
) : this(
|
||||||
|
derivativeAlgebra,
|
||||||
|
getCompiler<T, A>(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, parameters, order),
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance representing a constant value.
|
||||||
|
*
|
||||||
|
* @param parameters number of free parameters.
|
||||||
|
* @param order derivation order.
|
||||||
|
* @param value value of the constant.
|
||||||
|
* @see DerivativeStructure
|
||||||
|
*/
|
||||||
|
public constructor (
|
||||||
|
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||||
|
parameters: Int,
|
||||||
|
order: Int,
|
||||||
|
value: T,
|
||||||
|
) : this(
|
||||||
|
derivativeAlgebra,
|
||||||
|
parameters,
|
||||||
|
order,
|
||||||
|
) {
|
||||||
|
data[0] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance representing a variable.
|
||||||
|
*
|
||||||
|
* Instances built using this constructor are considered to be the free variables with respect to which
|
||||||
|
* differentials are computed. As such, their differential with respect to themselves is +1.
|
||||||
|
*
|
||||||
|
* @param parameters number of free parameters.
|
||||||
|
* @param order derivation order.
|
||||||
|
* @param index index of the variable (from 0 to `parameters - 1`).
|
||||||
|
* @param value value of the variable.
|
||||||
|
*/
|
||||||
|
public constructor (
|
||||||
|
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||||
|
parameters: Int,
|
||||||
|
order: Int,
|
||||||
|
index: Int,
|
||||||
|
value: T,
|
||||||
|
) : this(derivativeAlgebra, parameters, order, value) {
|
||||||
|
require(index < parameters) { "number is too large: $index >= $parameters" }
|
||||||
|
|
||||||
|
if (order > 0) {
|
||||||
|
// the derivative of the variable with respect to itself is 1.
|
||||||
|
data[getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, index, order).size] =
|
||||||
|
derivativeAlgebra.algebra.one
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build an instance from all its derivatives.
|
||||||
|
*
|
||||||
|
* @param parameters number of free parameters.
|
||||||
|
* @param order derivation order.
|
||||||
|
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
|
||||||
|
*/
|
||||||
|
public constructor (
|
||||||
|
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||||
|
parameters: Int,
|
||||||
|
order: Int,
|
||||||
|
vararg derivatives: T,
|
||||||
|
) : this(
|
||||||
|
derivativeAlgebra,
|
||||||
|
parameters,
|
||||||
|
order,
|
||||||
|
) {
|
||||||
|
require(derivatives.size == data.size) { "dimension mismatch: ${derivatives.size} and ${data.size}" }
|
||||||
|
data = derivativeAlgebra.bufferFactory(data.size) { derivatives[it] }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy constructor.
|
||||||
|
*
|
||||||
|
* @param ds instance to copy.
|
||||||
|
*/
|
||||||
|
internal constructor(ds: DerivativeStructure<T, A>) : this(ds.derivativeAlgebra, ds.compiler) {
|
||||||
|
this.data = ds.data.copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of free parameters.
|
||||||
|
*/
|
||||||
|
public val freeParameters: Int
|
||||||
|
get() = compiler.freeParameters
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The derivation order.
|
||||||
|
*/
|
||||||
|
public val order: Int
|
||||||
|
get() = compiler.order
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The value part of the derivative structure.
|
||||||
|
*
|
||||||
|
* @see getPartialDerivative
|
||||||
|
*/
|
||||||
|
public val value: T
|
||||||
|
get() = data[0]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a partial derivative.
|
||||||
|
*
|
||||||
|
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
|
||||||
|
* @return partial derivative.
|
||||||
|
* @see value
|
||||||
|
*/
|
||||||
|
public fun getPartialDerivative(vararg orders: Int): T = data[compiler.getPartialDerivativeIndex(*orders)]
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test for the equality of two derivative structures.
|
||||||
|
*
|
||||||
|
* Derivative structures are considered equal if they have the same number
|
||||||
|
* of free parameters, the same derivation order, and the same derivatives.
|
||||||
|
*
|
||||||
|
* @return `true` if two derivative structures are equal.
|
||||||
|
*/
|
||||||
|
public override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
|
||||||
|
if (other is DerivativeStructure<*, *>) {
|
||||||
|
return ((freeParameters == other.freeParameters) &&
|
||||||
|
(order == other.order) &&
|
||||||
|
data == other.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun hashCode(): Int =
|
||||||
|
227 + 229 * freeParameters + 233 * order + 239 * data.hashCode()
|
||||||
|
}
|
@ -0,0 +1,332 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.*
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
import space.kscience.kmath.structures.indices
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class implementing both [DerivativeStructure] and [Symbol].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DerivativeStructureSymbol<T, A>(
|
||||||
|
derivativeAlgebra: DerivativeStructureRing<T, A>,
|
||||||
|
size: Int,
|
||||||
|
order: Int,
|
||||||
|
index: Int,
|
||||||
|
symbol: Symbol,
|
||||||
|
value: T,
|
||||||
|
) : Symbol by symbol, DerivativeStructure<T, A>(
|
||||||
|
derivativeAlgebra,
|
||||||
|
size,
|
||||||
|
order,
|
||||||
|
index,
|
||||||
|
value
|
||||||
|
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||||
|
override fun toString(): String = symbol.toString()
|
||||||
|
override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol
|
||||||
|
override fun hashCode(): Int = symbol.hashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A ring over [DerivativeStructure].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public open class DerivativeStructureRing<T, A>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : Ring<DerivativeStructure<T, A>>, ScaleOperations<DerivativeStructure<T, A>>,
|
||||||
|
NumericAlgebra<DerivativeStructure<T, A>>,
|
||||||
|
ExpressionAlgebra<T, DerivativeStructure<T, A>>,
|
||||||
|
NumbersAddOps<DerivativeStructure<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
|
||||||
|
public val numberOfVariables: Int = bindings.size
|
||||||
|
|
||||||
|
override val zero: DerivativeStructure<T, A> by lazy {
|
||||||
|
DerivativeStructure(
|
||||||
|
this,
|
||||||
|
numberOfVariables,
|
||||||
|
order,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val one: DerivativeStructure<T, A> by lazy {
|
||||||
|
DerivativeStructure(
|
||||||
|
this,
|
||||||
|
numberOfVariables,
|
||||||
|
order,
|
||||||
|
algebra.one,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
||||||
|
|
||||||
|
private val variables: Map<Symbol, DerivativeStructureSymbol<T, A>> =
|
||||||
|
bindings.entries.mapIndexed { index, (key, value) ->
|
||||||
|
key to DerivativeStructureSymbol(
|
||||||
|
this,
|
||||||
|
numberOfVariables,
|
||||||
|
order,
|
||||||
|
index,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
}.toMap()
|
||||||
|
|
||||||
|
public override fun const(value: T): DerivativeStructure<T, A> =
|
||||||
|
DerivativeStructure(this, numberOfVariables, order, value)
|
||||||
|
|
||||||
|
override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol<T, A>? = variables[StringSymbol(value)]
|
||||||
|
|
||||||
|
override fun bindSymbol(value: String): DerivativeStructureSymbol<T, A> =
|
||||||
|
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
|
||||||
|
|
||||||
|
public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol<T, A>? = variables[symbol.identity]
|
||||||
|
|
||||||
|
public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol<T, A> =
|
||||||
|
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
|
||||||
|
|
||||||
|
public fun DerivativeStructure<T, A>.derivative(symbols: List<Symbol>): T {
|
||||||
|
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
|
||||||
|
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun DerivativeStructure<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
|
||||||
|
|
||||||
|
override fun DerivativeStructure<T, A>.unaryMinus(): DerivativeStructure<T, A> {
|
||||||
|
val ds = DerivativeStructure(this@DerivativeStructureRing, compiler)
|
||||||
|
for (i in ds.data.indices) {
|
||||||
|
ds.data[i] = algebra { -data[i] }
|
||||||
|
}
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun add(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
left.compiler.checkCompatibility(right.compiler)
|
||||||
|
val ds = DerivativeStructure(left)
|
||||||
|
left.compiler.add(left.data, 0, right.data, 0, ds.data, 0)
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun scale(a: DerivativeStructure<T, A>, value: Double): DerivativeStructure<T, A> {
|
||||||
|
val ds = DerivativeStructure(a)
|
||||||
|
for (i in ds.data.indices) {
|
||||||
|
ds.data[i] = algebra { ds.data[i].times(value) }
|
||||||
|
}
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(
|
||||||
|
left: DerivativeStructure<T, A>,
|
||||||
|
right: DerivativeStructure<T, A>
|
||||||
|
): DerivativeStructure<T, A> {
|
||||||
|
left.compiler.checkCompatibility(right.compiler)
|
||||||
|
val result = DerivativeStructure(this, left.compiler)
|
||||||
|
left.compiler.multiply(left.data, 0, right.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DerivativeStructure<T, A>.minus(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
compiler.checkCompatibility(arg.compiler)
|
||||||
|
val ds = DerivativeStructure(this)
|
||||||
|
compiler.subtract(data, 0, arg.data, 0, ds.data, 0)
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun DerivativeStructure<T, A>.plus(other: Number): DerivativeStructure<T, A> {
|
||||||
|
val ds = DerivativeStructure(this)
|
||||||
|
ds.data[0] = algebra { ds.data[0] + number(other) }
|
||||||
|
return ds
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun DerivativeStructure<T, A>.minus(other: Number): DerivativeStructure<T, A> =
|
||||||
|
this + -other.toDouble()
|
||||||
|
|
||||||
|
override operator fun Number.plus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other + this
|
||||||
|
override operator fun Number.minus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other - this
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DerivativeStructureRingExpression<T, A>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val function: DerivativeStructureRing<T, A>.() -> DerivativeStructure<T, A>,
|
||||||
|
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||||
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
|
DerivativeStructureRing(algebra, bufferFactory, 0, arguments).function().value
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
|
with(
|
||||||
|
DerivativeStructureRing(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
symbols.size,
|
||||||
|
arguments
|
||||||
|
)
|
||||||
|
) { function().derivative(symbols) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field over commons-math [DerivativeStructure].
|
||||||
|
*
|
||||||
|
* @property order The derivation order.
|
||||||
|
* @param bindings The map of bindings values. All bindings are considered free parameters.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DerivativeStructureField<T, A : ExtendedField<T>>(
|
||||||
|
algebra: A,
|
||||||
|
bufferFactory: MutableBufferFactory<T>,
|
||||||
|
order: Int,
|
||||||
|
bindings: Map<Symbol, T>,
|
||||||
|
) : DerivativeStructureRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DerivativeStructure<T, A>> {
|
||||||
|
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
|
||||||
|
|
||||||
|
override fun divide(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
left.compiler.checkCompatibility(right.compiler)
|
||||||
|
val result = DerivativeStructure(this, left.compiler)
|
||||||
|
left.compiler.divide(left.data, 0, right.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.sin(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.cos(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun tan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.tan(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun asin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.asin(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun acos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.acos(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun atan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.atan(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.sinh(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.cosh(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun tanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.tanh(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun asinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.asinh(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun acosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.acosh(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun atanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.atanh(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: DerivativeStructure<T, A>, pow: Number): DerivativeStructure<T, A> = when (pow) {
|
||||||
|
is Int -> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.pow(arg.data, 0, pow, result.data, 0)
|
||||||
|
result
|
||||||
|
}
|
||||||
|
else -> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.pow(arg.data, 0, pow.toDouble(), result.data, 0)
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sqrt(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.sqrt(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun power(arg: DerivativeStructure<T, A>, pow: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
arg.compiler.checkCompatibility(pow.compiler)
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.pow(arg.data, 0, pow.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun exp(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.exp(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun ln(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
|
||||||
|
val result = DerivativeStructure(this, arg.compiler)
|
||||||
|
arg.compiler.ln(arg.data, 0, result.data, 0)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class DerivativeStructureFieldExpression<T, A : ExtendedField<T>>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
|
public val function: DerivativeStructureField<T, A>.() -> DerivativeStructure<T, A>,
|
||||||
|
) : DifferentiableExpression<T> {
|
||||||
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
|
DerivativeStructureField(algebra, bufferFactory, 0, arguments).function().value
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
||||||
|
with(
|
||||||
|
DerivativeStructureField(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
symbols.size,
|
||||||
|
arguments,
|
||||||
|
)
|
||||||
|
) { function().derivative(symbols) }
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,59 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
@file:OptIn(UnstableKMathAPI::class)
|
||||||
|
|
||||||
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
|
import kotlin.contracts.InvocationKind
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertFails
|
||||||
|
|
||||||
|
internal inline fun diff(
|
||||||
|
order: Int,
|
||||||
|
vararg parameters: Pair<Symbol, Double>,
|
||||||
|
block: DerivativeStructureField<Double, DoubleField>.() -> Unit,
|
||||||
|
) {
|
||||||
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
DerivativeStructureField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AutoDiffTest {
|
||||||
|
private val x by symbol
|
||||||
|
private val y by symbol
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun derivativeStructureFieldTest() {
|
||||||
|
diff(2, x to 1.0, y to 1.0) {
|
||||||
|
val x = bindSymbol(x)//by binding()
|
||||||
|
val y = bindSymbol("y")
|
||||||
|
val z = x * (-sin(x * y) + y) + 2.0
|
||||||
|
println(z.derivative(x))
|
||||||
|
println(z.derivative(y, x))
|
||||||
|
assertEquals(z.derivative(x, y), z.derivative(y, x))
|
||||||
|
// check improper order cause failure
|
||||||
|
assertFails { z.derivative(x, x, y) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun autoDifTest() {
|
||||||
|
val f = DerivativeStructureFieldExpression(DoubleField, ::DoubleBuffer) {
|
||||||
|
val x by binding
|
||||||
|
val y by binding
|
||||||
|
x.pow(2) + 2 * x * y + y.pow(2) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(10.0, f(x to 1.0, y to 2.0))
|
||||||
|
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
|
||||||
|
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
|
||||||
|
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
|
||||||
|
}
|
||||||
|
}
|
@ -48,7 +48,8 @@ internal object InternalUtils {
|
|||||||
cache.copyInto(
|
cache.copyInto(
|
||||||
logFactorials,
|
logFactorials,
|
||||||
BEGIN_LOG_FACTORIALS,
|
BEGIN_LOG_FACTORIALS,
|
||||||
BEGIN_LOG_FACTORIALS, endCopy
|
BEGIN_LOG_FACTORIALS,
|
||||||
|
endCopy,
|
||||||
)
|
)
|
||||||
} else
|
} else
|
||||||
// All values to be computed
|
// All values to be computed
|
||||||
|
Loading…
Reference in New Issue
Block a user