Copy DerivativeStructure to multiplatform

This commit is contained in:
Iaroslav Postovalov 2021-08-12 16:37:53 +03:00
parent ef747f642f
commit 7b1bdc21a4
7 changed files with 2122 additions and 3 deletions

View File

@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
break any moment. You can still use it, but be sure to fix the specific version.
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
with `@UnstableKmathAPI` or other stability warning annotations.
with `@UnstableKMathAPI` or other stability warning annotations.
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
versions, but not in patch versions. API is protected
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.

View File

@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels:
* **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could
break any moment. You can still use it, but be sure to fix the specific version.
* **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked
with `@UnstableKmathAPI` or other stability warning annotations.
with `@UnstableKMathAPI` or other stability warning annotations.
* **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor
versions, but not in patch versions. API is protected
with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool.

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,186 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.expressions
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.NumericAlgebra
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.structures.MutableBuffer
/**
* Class representing both the value and the differentials of a function.
*
* This class is the workhorse of the differentiation package.
*
* This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper [Doubly Recursive
* Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf),
* Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used
* throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's
* derivative structures hold all partial derivatives up to any specified order, with respect to any number of free
* parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free
* parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters.
*
* Derived from
* [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java).
*/
@UnstableKMathAPI
public open class DerivativeStructure<T, A> internal constructor(
internal val derivativeAlgebra: DerivativeStructureRing<T, A>,
internal val compiler: DSCompiler<T, A>,
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
/**
* Combined array holding all values.
*/
internal var data: MutableBuffer<T> =
derivativeAlgebra.bufferFactory(compiler.size) { derivativeAlgebra.algebra.zero }
/**
* Build an instance with all values and derivatives set to 0.
*
* @param parameters number of free parameters.
* @param order derivation order.
*/
public constructor (
derivativeAlgebra: DerivativeStructureRing<T, A>,
parameters: Int,
order: Int,
) : this(
derivativeAlgebra,
getCompiler<T, A>(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, parameters, order),
)
/**
* Build an instance representing a constant value.
*
* @param parameters number of free parameters.
* @param order derivation order.
* @param value value of the constant.
* @see DerivativeStructure
*/
public constructor (
derivativeAlgebra: DerivativeStructureRing<T, A>,
parameters: Int,
order: Int,
value: T,
) : this(
derivativeAlgebra,
parameters,
order,
) {
data[0] = value
}
/**
* Build an instance representing a variable.
*
* Instances built using this constructor are considered to be the free variables with respect to which
* differentials are computed. As such, their differential with respect to themselves is +1.
*
* @param parameters number of free parameters.
* @param order derivation order.
* @param index index of the variable (from 0 to `parameters - 1`).
* @param value value of the variable.
*/
public constructor (
derivativeAlgebra: DerivativeStructureRing<T, A>,
parameters: Int,
order: Int,
index: Int,
value: T,
) : this(derivativeAlgebra, parameters, order, value) {
require(index < parameters) { "number is too large: $index >= $parameters" }
if (order > 0) {
// the derivative of the variable with respect to itself is 1.
data[getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, index, order).size] =
derivativeAlgebra.algebra.one
}
}
/**
* Build an instance from all its derivatives.
*
* @param parameters number of free parameters.
* @param order derivation order.
* @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex].
*/
public constructor (
derivativeAlgebra: DerivativeStructureRing<T, A>,
parameters: Int,
order: Int,
vararg derivatives: T,
) : this(
derivativeAlgebra,
parameters,
order,
) {
require(derivatives.size == data.size) { "dimension mismatch: ${derivatives.size} and ${data.size}" }
data = derivativeAlgebra.bufferFactory(data.size) { derivatives[it] }
}
/**
* Copy constructor.
*
* @param ds instance to copy.
*/
internal constructor(ds: DerivativeStructure<T, A>) : this(ds.derivativeAlgebra, ds.compiler) {
this.data = ds.data.copy()
}
/**
* The number of free parameters.
*/
public val freeParameters: Int
get() = compiler.freeParameters
/**
* The derivation order.
*/
public val order: Int
get() = compiler.order
/**
* The value part of the derivative structure.
*
* @see getPartialDerivative
*/
public val value: T
get() = data[0]
/**
* Get a partial derivative.
*
* @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned).
* @return partial derivative.
* @see value
*/
public fun getPartialDerivative(vararg orders: Int): T = data[compiler.getPartialDerivativeIndex(*orders)]
/**
* Test for the equality of two derivative structures.
*
* Derivative structures are considered equal if they have the same number
* of free parameters, the same derivation order, and the same derivatives.
*
* @return `true` if two derivative structures are equal.
*/
public override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other is DerivativeStructure<*, *>) {
return ((freeParameters == other.freeParameters) &&
(order == other.order) &&
data == other.data)
}
return false
}
public override fun hashCode(): Int =
227 + 229 * freeParameters + 233 * order + 239 * data.hashCode()
}

View File

@ -0,0 +1,332 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.expressions
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.MutableBufferFactory
import space.kscience.kmath.structures.indices
/**
* A class implementing both [DerivativeStructure] and [Symbol].
*/
@UnstableKMathAPI
public class DerivativeStructureSymbol<T, A>(
derivativeAlgebra: DerivativeStructureRing<T, A>,
size: Int,
order: Int,
index: Int,
symbol: Symbol,
value: T,
) : Symbol by symbol, DerivativeStructure<T, A>(
derivativeAlgebra,
size,
order,
index,
value
) where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
override fun toString(): String = symbol.toString()
override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol
override fun hashCode(): Int = symbol.hashCode()
}
/**
* A ring over [DerivativeStructure].
*
* @property order The derivation order.
* @param bindings The map of bindings values. All bindings are considered free parameters.
*/
@UnstableKMathAPI
public open class DerivativeStructureRing<T, A>(
public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val order: Int,
bindings: Map<Symbol, T>,
) : Ring<DerivativeStructure<T, A>>, ScaleOperations<DerivativeStructure<T, A>>,
NumericAlgebra<DerivativeStructure<T, A>>,
ExpressionAlgebra<T, DerivativeStructure<T, A>>,
NumbersAddOps<DerivativeStructure<T, A>> where A : Ring<T>, A : NumericAlgebra<T>, A : ScaleOperations<T> {
public val numberOfVariables: Int = bindings.size
override val zero: DerivativeStructure<T, A> by lazy {
DerivativeStructure(
this,
numberOfVariables,
order,
)
}
override val one: DerivativeStructure<T, A> by lazy {
DerivativeStructure(
this,
numberOfVariables,
order,
algebra.one,
)
}
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
private val variables: Map<Symbol, DerivativeStructureSymbol<T, A>> =
bindings.entries.mapIndexed { index, (key, value) ->
key to DerivativeStructureSymbol(
this,
numberOfVariables,
order,
index,
key,
value,
)
}.toMap()
public override fun const(value: T): DerivativeStructure<T, A> =
DerivativeStructure(this, numberOfVariables, order, value)
override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol<T, A>? = variables[StringSymbol(value)]
override fun bindSymbol(value: String): DerivativeStructureSymbol<T, A> =
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol<T, A>? = variables[symbol.identity]
public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol<T, A> =
bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this")
public fun DerivativeStructure<T, A>.derivative(symbols: List<Symbol>): T {
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
val ordersCount = symbols.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray())
}
public fun DerivativeStructure<T, A>.derivative(vararg symbols: Symbol): T = derivative(symbols.toList())
override fun DerivativeStructure<T, A>.unaryMinus(): DerivativeStructure<T, A> {
val ds = DerivativeStructure(this@DerivativeStructureRing, compiler)
for (i in ds.data.indices) {
ds.data[i] = algebra { -data[i] }
}
return ds
}
override fun add(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
left.compiler.checkCompatibility(right.compiler)
val ds = DerivativeStructure(left)
left.compiler.add(left.data, 0, right.data, 0, ds.data, 0)
return ds
}
override fun scale(a: DerivativeStructure<T, A>, value: Double): DerivativeStructure<T, A> {
val ds = DerivativeStructure(a)
for (i in ds.data.indices) {
ds.data[i] = algebra { ds.data[i].times(value) }
}
return ds
}
override fun multiply(
left: DerivativeStructure<T, A>,
right: DerivativeStructure<T, A>
): DerivativeStructure<T, A> {
left.compiler.checkCompatibility(right.compiler)
val result = DerivativeStructure(this, left.compiler)
left.compiler.multiply(left.data, 0, right.data, 0, result.data, 0)
return result
}
override fun DerivativeStructure<T, A>.minus(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
compiler.checkCompatibility(arg.compiler)
val ds = DerivativeStructure(this)
compiler.subtract(data, 0, arg.data, 0, ds.data, 0)
return ds
}
override operator fun DerivativeStructure<T, A>.plus(other: Number): DerivativeStructure<T, A> {
val ds = DerivativeStructure(this)
ds.data[0] = algebra { ds.data[0] + number(other) }
return ds
}
override operator fun DerivativeStructure<T, A>.minus(other: Number): DerivativeStructure<T, A> =
this + -other.toDouble()
override operator fun Number.plus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other + this
override operator fun Number.minus(other: DerivativeStructure<T, A>): DerivativeStructure<T, A> = other - this
}
@UnstableKMathAPI
public class DerivativeStructureRingExpression<T, A>(
public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val function: DerivativeStructureRing<T, A>.() -> DerivativeStructure<T, A>,
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
override operator fun invoke(arguments: Map<Symbol, T>): T =
DerivativeStructureRing(algebra, bufferFactory, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
with(
DerivativeStructureRing(
algebra,
bufferFactory,
symbols.size,
arguments
)
) { function().derivative(symbols) }
}
}
/**
* A field over commons-math [DerivativeStructure].
*
* @property order The derivation order.
* @param bindings The map of bindings values. All bindings are considered free parameters.
*/
@UnstableKMathAPI
public class DerivativeStructureField<T, A : ExtendedField<T>>(
algebra: A,
bufferFactory: MutableBufferFactory<T>,
order: Int,
bindings: Map<Symbol, T>,
) : DerivativeStructureRing<T, A>(algebra, bufferFactory, order, bindings), ExtendedField<DerivativeStructure<T, A>> {
override fun number(value: Number): DerivativeStructure<T, A> = const(algebra.number(value))
override fun divide(left: DerivativeStructure<T, A>, right: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
left.compiler.checkCompatibility(right.compiler)
val result = DerivativeStructure(this, left.compiler)
left.compiler.divide(left.data, 0, right.data, 0, result.data, 0)
return result
}
override fun sin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.sin(arg.data, 0, result.data, 0)
return result
}
override fun cos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.cos(arg.data, 0, result.data, 0)
return result
}
override fun tan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.tan(arg.data, 0, result.data, 0)
return result
}
override fun asin(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.asin(arg.data, 0, result.data, 0)
return result
}
override fun acos(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.acos(arg.data, 0, result.data, 0)
return result
}
override fun atan(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.atan(arg.data, 0, result.data, 0)
return result
}
override fun sinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.sinh(arg.data, 0, result.data, 0)
return result
}
override fun cosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.cosh(arg.data, 0, result.data, 0)
return result
}
override fun tanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.tanh(arg.data, 0, result.data, 0)
return result
}
override fun asinh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.asinh(arg.data, 0, result.data, 0)
return result
}
override fun acosh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.acosh(arg.data, 0, result.data, 0)
return result
}
override fun atanh(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.atanh(arg.data, 0, result.data, 0)
return result
}
override fun power(arg: DerivativeStructure<T, A>, pow: Number): DerivativeStructure<T, A> = when (pow) {
is Int -> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.pow(arg.data, 0, pow, result.data, 0)
result
}
else -> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.pow(arg.data, 0, pow.toDouble(), result.data, 0)
result
}
}
override fun sqrt(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.sqrt(arg.data, 0, result.data, 0)
return result
}
public fun power(arg: DerivativeStructure<T, A>, pow: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
arg.compiler.checkCompatibility(pow.compiler)
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.pow(arg.data, 0, pow.data, 0, result.data, 0)
return result
}
override fun exp(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.exp(arg.data, 0, result.data, 0)
return result
}
override fun ln(arg: DerivativeStructure<T, A>): DerivativeStructure<T, A> {
val result = DerivativeStructure(this, arg.compiler)
arg.compiler.ln(arg.data, 0, result.data, 0)
return result
}
}
@UnstableKMathAPI
public class DerivativeStructureFieldExpression<T, A : ExtendedField<T>>(
public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val function: DerivativeStructureField<T, A>.() -> DerivativeStructure<T, A>,
) : DifferentiableExpression<T> {
override operator fun invoke(arguments: Map<Symbol, T>): T =
DerivativeStructureField(algebra, bufferFactory, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
with(
DerivativeStructureField(
algebra,
bufferFactory,
symbols.size,
arguments,
)
) { function().derivative(symbols) }
}
}

View File

@ -0,0 +1,59 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
@file:OptIn(UnstableKMathAPI::class)
package space.kscience.kmath.expressions
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFails
internal inline fun diff(
order: Int,
vararg parameters: Pair<Symbol, Double>,
block: DerivativeStructureField<Double, DoubleField>.() -> Unit,
) {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
DerivativeStructureField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block()
}
internal class AutoDiffTest {
private val x by symbol
private val y by symbol
@Test
fun derivativeStructureFieldTest() {
diff(2, x to 1.0, y to 1.0) {
val x = bindSymbol(x)//by binding()
val y = bindSymbol("y")
val z = x * (-sin(x * y) + y) + 2.0
println(z.derivative(x))
println(z.derivative(y, x))
assertEquals(z.derivative(x, y), z.derivative(y, x))
// check improper order cause failure
assertFails { z.derivative(x, x, y) }
}
}
@Test
fun autoDifTest() {
val f = DerivativeStructureFieldExpression(DoubleField, ::DoubleBuffer) {
val x by binding
val y by binding
x.pow(2) + 2 * x * y + y.pow(2) + 1
}
assertEquals(10.0, f(x to 1.0, y to 2.0))
assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0))
assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0))
assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0))
}
}

View File

@ -48,7 +48,8 @@ internal object InternalUtils {
cache.copyInto(
logFactorials,
BEGIN_LOG_FACTORIALS,
BEGIN_LOG_FACTORIALS, endCopy
BEGIN_LOG_FACTORIALS,
endCopy,
)
} else
// All values to be computed