forked from kscience/kmath
Grand derivative refactoring. Phase 3
This commit is contained in:
parent
f5fe53a9f2
commit
846a6d2620
@ -44,9 +44,29 @@ public interface DS<T, A : Ring<T>> {
|
|||||||
* @see value
|
* @see value
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
|
private fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
|
||||||
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
|
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, A : Ring<T>> DS<T, A>.derivative(symbols: List<Symbol>): T {
|
||||||
|
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||||
|
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, A : Ring<T>> DS<T, A>.derivative(vararg symbols: Symbol): T {
|
||||||
|
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||||
|
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||||
|
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The value part of the derivative structure.
|
* The value part of the derivative structure.
|
||||||
*
|
*
|
||||||
@ -61,9 +81,67 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
public val bufferFactory: MutableBufferFactory<T>,
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
public val order: Int,
|
public val order: Int,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : ExpressionAlgebra<T, DS<T, A>> {
|
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the compiler for number of free parameters and order.
|
||||||
|
*
|
||||||
|
* @return cached rules set.
|
||||||
|
*/
|
||||||
|
@PublishedApi
|
||||||
|
internal val compiler: DSCompiler<T, A> by lazy {
|
||||||
|
// get the cached compilers
|
||||||
|
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
||||||
|
|
||||||
|
// we need to create more compilers
|
||||||
|
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
||||||
|
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
||||||
|
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
||||||
|
|
||||||
|
if (cache != null) {
|
||||||
|
// preserve the already created compilers
|
||||||
|
for (i in cache.indices) {
|
||||||
|
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the array in increasing diagonal order
|
||||||
|
for (diag in 0..numberOfVariables + order) {
|
||||||
|
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
||||||
|
val p: Int = diag - o
|
||||||
|
if (newCache[p][o] == null) {
|
||||||
|
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
||||||
|
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
||||||
|
|
||||||
|
newCache[p][o] = DSCompiler(
|
||||||
|
algebra,
|
||||||
|
bufferFactory,
|
||||||
|
p,
|
||||||
|
o,
|
||||||
|
valueCompiler,
|
||||||
|
derivativeCompiler,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return@lazy newCache[numberOfVariables][order]!!
|
||||||
|
}
|
||||||
|
|
||||||
|
private val variables: Map<Symbol, DSSymbol> by lazy {
|
||||||
|
bindings.entries.mapIndexed { index, (key, value) ->
|
||||||
|
key to DSSymbol(
|
||||||
|
index,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
}.toMap()
|
||||||
|
}
|
||||||
|
override val symbols: List<Symbol> = bindings.map { it.key }
|
||||||
|
|
||||||
|
public val numberOfVariables: Int get() = symbols.size
|
||||||
|
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
||||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||||
buffer[0] = value
|
buffer[0] = value
|
||||||
@ -80,7 +158,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
protected inner class DSImpl internal constructor(
|
private inner class DSImpl(
|
||||||
override val data: Buffer<T>,
|
override val data: Buffer<T>,
|
||||||
) : DS<T, A> {
|
) : DS<T, A> {
|
||||||
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||||
@ -130,63 +208,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
|||||||
override val data: Buffer<T> = bufferForVariable(index, value)
|
override val data: Buffer<T> = bufferForVariable(index, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public val numberOfVariables: Int = bindings.size
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the compiler for number of free parameters and order.
|
|
||||||
*
|
|
||||||
* @return cached rules set.
|
|
||||||
*/
|
|
||||||
@PublishedApi
|
|
||||||
internal val compiler: DSCompiler<T, A> by lazy {
|
|
||||||
// get the cached compilers
|
|
||||||
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
|
||||||
|
|
||||||
// we need to create more compilers
|
|
||||||
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
|
||||||
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
|
||||||
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
|
||||||
|
|
||||||
if (cache != null) {
|
|
||||||
// preserve the already created compilers
|
|
||||||
for (i in cache.indices) {
|
|
||||||
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// create the array in increasing diagonal order
|
|
||||||
for (diag in 0..numberOfVariables + order) {
|
|
||||||
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
|
||||||
val p: Int = diag - o
|
|
||||||
if (newCache[p][o] == null) {
|
|
||||||
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
|
||||||
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
|
||||||
|
|
||||||
newCache[p][o] = DSCompiler(
|
|
||||||
algebra,
|
|
||||||
bufferFactory,
|
|
||||||
p,
|
|
||||||
o,
|
|
||||||
valueCompiler,
|
|
||||||
derivativeCompiler,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return@lazy newCache[numberOfVariables][order]!!
|
|
||||||
}
|
|
||||||
|
|
||||||
private val variables: Map<Symbol, DSSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
|
|
||||||
key to DSSymbol(
|
|
||||||
index,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
}.toMap()
|
|
||||||
|
|
||||||
|
|
||||||
public override fun const(value: T): DS<T, A> {
|
public override fun const(value: T): DS<T, A> {
|
||||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||||
buffer[0] = value
|
buffer[0] = value
|
||||||
|
@ -52,20 +52,20 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
|
|||||||
*
|
*
|
||||||
* @property freeParameters Number of free parameters.
|
* @property freeParameters Number of free parameters.
|
||||||
* @property order Derivation order.
|
* @property order Derivation order.
|
||||||
* @see DerivativeStructure
|
* @see DS
|
||||||
*/
|
*/
|
||||||
class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
public class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||||
val algebra: A,
|
public val algebra: A,
|
||||||
val bufferFactory: MutableBufferFactory<T>,
|
public val bufferFactory: MutableBufferFactory<T>,
|
||||||
val freeParameters: Int,
|
public val freeParameters: Int,
|
||||||
val order: Int,
|
public val order: Int,
|
||||||
valueCompiler: DSCompiler<T, A>?,
|
valueCompiler: DSCompiler<T, A>?,
|
||||||
derivativeCompiler: DSCompiler<T, A>?,
|
derivativeCompiler: DSCompiler<T, A>?,
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Number of partial derivatives (including the single 0 order derivative element).
|
* Number of partial derivatives (including the single 0 order derivative element).
|
||||||
*/
|
*/
|
||||||
val sizes: Array<IntArray> by lazy {
|
public val sizes: Array<IntArray> by lazy {
|
||||||
compileSizes(
|
compileSizes(
|
||||||
freeParameters,
|
freeParameters,
|
||||||
order,
|
order,
|
||||||
@ -76,7 +76,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection array for partial derivatives.
|
* Indirection array for partial derivatives.
|
||||||
*/
|
*/
|
||||||
val derivativesIndirection: Array<IntArray> by lazy {
|
internal val derivativesIndirection: Array<IntArray> by lazy {
|
||||||
compileDerivativesIndirection(
|
compileDerivativesIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler,
|
valueCompiler, derivativeCompiler,
|
||||||
@ -86,7 +86,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection array of the lower derivative elements.
|
* Indirection array of the lower derivative elements.
|
||||||
*/
|
*/
|
||||||
val lowerIndirection: IntArray by lazy {
|
internal val lowerIndirection: IntArray by lazy {
|
||||||
compileLowerIndirection(
|
compileLowerIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler,
|
valueCompiler, derivativeCompiler,
|
||||||
@ -96,7 +96,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection arrays for multiplication.
|
* Indirection arrays for multiplication.
|
||||||
*/
|
*/
|
||||||
val multIndirection: Array<Array<IntArray>> by lazy {
|
internal val multIndirection: Array<Array<IntArray>> by lazy {
|
||||||
compileMultiplicationIndirection(
|
compileMultiplicationIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler, lowerIndirection,
|
valueCompiler, derivativeCompiler, lowerIndirection,
|
||||||
@ -106,7 +106,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
/**
|
/**
|
||||||
* Indirection arrays for function composition.
|
* Indirection arrays for function composition.
|
||||||
*/
|
*/
|
||||||
val compositionIndirection: Array<Array<IntArray>> by lazy {
|
internal val compositionIndirection: Array<Array<IntArray>> by lazy {
|
||||||
compileCompositionIndirection(
|
compileCompositionIndirection(
|
||||||
freeParameters, order,
|
freeParameters, order,
|
||||||
valueCompiler, derivativeCompiler,
|
valueCompiler, derivativeCompiler,
|
||||||
@ -120,7 +120,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
* This number includes the single 0 order derivative element, which is
|
* This number includes the single 0 order derivative element, which is
|
||||||
* guaranteed to be stored in the first element of the array.
|
* guaranteed to be stored in the first element of the array.
|
||||||
*/
|
*/
|
||||||
val size: Int get() = sizes[freeParameters][order]
|
public val size: Int get() = sizes[freeParameters][order]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the index of a partial derivative in the array.
|
* Get the index of a partial derivative in the array.
|
||||||
@ -147,7 +147,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
* @return index of the partial derivative.
|
* @return index of the partial derivative.
|
||||||
* @see getPartialDerivativeOrders
|
* @see getPartialDerivativeOrders
|
||||||
*/
|
*/
|
||||||
fun getPartialDerivativeIndex(vararg orders: Int): Int {
|
public fun getPartialDerivativeIndex(vararg orders: Int): Int {
|
||||||
// safety check
|
// safety check
|
||||||
require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" }
|
require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" }
|
||||||
return getPartialDerivativeIndex(freeParameters, order, sizes, *orders)
|
return getPartialDerivativeIndex(freeParameters, order, sizes, *orders)
|
||||||
@ -162,7 +162,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
|||||||
* @return orders derivation orders with respect to each parameter
|
* @return orders derivation orders with respect to each parameter
|
||||||
* @see getPartialDerivativeIndex
|
* @see getPartialDerivativeIndex
|
||||||
*/
|
*/
|
||||||
fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
|
public fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Loading…
Reference in New Issue
Block a user