Grand derivative refactoring. Phase 3

This commit is contained in:
Alexander Nozik 2022-07-15 17:20:00 +03:00
parent f5fe53a9f2
commit 846a6d2620
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
2 changed files with 96 additions and 75 deletions

View File

@ -44,9 +44,29 @@ public interface DS<T, A : Ring<T>> {
* @see value
*/
@UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
private fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
/**
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
*/
@UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.derivative(symbols: List<Symbol>): T {
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
}
/**
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
*/
@UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.derivative(vararg symbols: Symbol): T {
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
}
/**
* The value part of the derivative structure.
*
@ -61,9 +81,67 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
public val bufferFactory: MutableBufferFactory<T>,
public val order: Int,
bindings: Map<Symbol, T>,
) : ExpressionAlgebra<T, DS<T, A>> {
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
/**
* Get the compiler for number of free parameters and order.
*
* @return cached rules set.
*/
@PublishedApi
internal val compiler: DSCompiler<T, A> by lazy {
// get the cached compilers
val cache: Array<Array<DSCompiler<T, A>?>>? = null
// we need to create more compilers
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
if (cache != null) {
// preserve the already created compilers
for (i in cache.indices) {
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
}
}
// create the array in increasing diagonal order
for (diag in 0..numberOfVariables + order) {
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
val p: Int = diag - o
if (newCache[p][o] == null) {
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
newCache[p][o] = DSCompiler(
algebra,
bufferFactory,
p,
o,
valueCompiler,
derivativeCompiler,
)
}
}
}
return@lazy newCache[numberOfVariables][order]!!
}
private val variables: Map<Symbol, DSSymbol> by lazy {
bindings.entries.mapIndexed { index, (key, value) ->
key to DSSymbol(
index,
key,
value,
)
}.toMap()
}
override val symbols: List<Symbol> = bindings.map { it.key }
public val numberOfVariables: Int get() = symbols.size
@OptIn(UnstableKMathAPI::class)
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
val buffer = bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value
@ -80,7 +158,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
}
@UnstableKMathAPI
protected inner class DSImpl internal constructor(
private inner class DSImpl(
override val data: Buffer<T>,
) : DS<T, A> {
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
@ -130,63 +208,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
override val data: Buffer<T> = bufferForVariable(index, value)
}
public val numberOfVariables: Int = bindings.size
/**
* Get the compiler for number of free parameters and order.
*
* @return cached rules set.
*/
@PublishedApi
internal val compiler: DSCompiler<T, A> by lazy {
// get the cached compilers
val cache: Array<Array<DSCompiler<T, A>?>>? = null
// we need to create more compilers
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
if (cache != null) {
// preserve the already created compilers
for (i in cache.indices) {
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
}
}
// create the array in increasing diagonal order
for (diag in 0..numberOfVariables + order) {
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
val p: Int = diag - o
if (newCache[p][o] == null) {
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
newCache[p][o] = DSCompiler(
algebra,
bufferFactory,
p,
o,
valueCompiler,
derivativeCompiler,
)
}
}
}
return@lazy newCache[numberOfVariables][order]!!
}
private val variables: Map<Symbol, DSSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
key to DSSymbol(
index,
key,
value,
)
}.toMap()
public override fun const(value: T): DS<T, A> {
val buffer = bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value

View File

@ -52,20 +52,20 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
*
* @property freeParameters Number of free parameters.
* @property order Derivation order.
* @see DerivativeStructure
* @see DS
*/
class DSCompiler<T, out A : Algebra<T>> internal constructor(
val algebra: A,
val bufferFactory: MutableBufferFactory<T>,
val freeParameters: Int,
val order: Int,
public class DSCompiler<T, out A : Algebra<T>> internal constructor(
public val algebra: A,
public val bufferFactory: MutableBufferFactory<T>,
public val freeParameters: Int,
public val order: Int,
valueCompiler: DSCompiler<T, A>?,
derivativeCompiler: DSCompiler<T, A>?,
) {
/**
* Number of partial derivatives (including the single 0 order derivative element).
*/
val sizes: Array<IntArray> by lazy {
public val sizes: Array<IntArray> by lazy {
compileSizes(
freeParameters,
order,
@ -76,7 +76,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/**
* Indirection array for partial derivatives.
*/
val derivativesIndirection: Array<IntArray> by lazy {
internal val derivativesIndirection: Array<IntArray> by lazy {
compileDerivativesIndirection(
freeParameters, order,
valueCompiler, derivativeCompiler,
@ -86,7 +86,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/**
* Indirection array of the lower derivative elements.
*/
val lowerIndirection: IntArray by lazy {
internal val lowerIndirection: IntArray by lazy {
compileLowerIndirection(
freeParameters, order,
valueCompiler, derivativeCompiler,
@ -96,7 +96,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/**
* Indirection arrays for multiplication.
*/
val multIndirection: Array<Array<IntArray>> by lazy {
internal val multIndirection: Array<Array<IntArray>> by lazy {
compileMultiplicationIndirection(
freeParameters, order,
valueCompiler, derivativeCompiler, lowerIndirection,
@ -106,7 +106,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/**
* Indirection arrays for function composition.
*/
val compositionIndirection: Array<Array<IntArray>> by lazy {
internal val compositionIndirection: Array<Array<IntArray>> by lazy {
compileCompositionIndirection(
freeParameters, order,
valueCompiler, derivativeCompiler,
@ -120,7 +120,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
* This number includes the single 0 order derivative element, which is
* guaranteed to be stored in the first element of the array.
*/
val size: Int get() = sizes[freeParameters][order]
public val size: Int get() = sizes[freeParameters][order]
/**
* Get the index of a partial derivative in the array.
@ -147,7 +147,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
* @return index of the partial derivative.
* @see getPartialDerivativeOrders
*/
fun getPartialDerivativeIndex(vararg orders: Int): Int {
public fun getPartialDerivativeIndex(vararg orders: Int): Int {
// safety check
require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" }
return getPartialDerivativeIndex(freeParameters, order, sizes, *orders)
@ -162,7 +162,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
* @return orders derivation orders with respect to each parameter
* @see getPartialDerivativeIndex
*/
fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
public fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
}
/**