Grand derivative refactoring. Phase 3

This commit is contained in:
Alexander Nozik 2022-07-15 17:20:00 +03:00
parent f5fe53a9f2
commit 846a6d2620
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
2 changed files with 96 additions and 75 deletions

View File

@ -44,9 +44,29 @@ public interface DS<T, A : Ring<T>> {
* @see value * @see value
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T = private fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)] data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
/**
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
*/
@UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.derivative(symbols: List<Symbol>): T {
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
}
/**
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
*/
@UnstableKMathAPI
public fun <T, A : Ring<T>> DS<T, A>.derivative(vararg symbols: Symbol): T {
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
}
/** /**
* The value part of the derivative structure. * The value part of the derivative structure.
* *
@ -61,9 +81,67 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
public val bufferFactory: MutableBufferFactory<T>, public val bufferFactory: MutableBufferFactory<T>,
public val order: Int, public val order: Int,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : ExpressionAlgebra<T, DS<T, A>> { ) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
/**
* Get the compiler for number of free parameters and order.
*
* @return cached rules set.
*/
@PublishedApi
internal val compiler: DSCompiler<T, A> by lazy {
// get the cached compilers
val cache: Array<Array<DSCompiler<T, A>?>>? = null
// we need to create more compilers
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
if (cache != null) {
// preserve the already created compilers
for (i in cache.indices) {
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
}
}
// create the array in increasing diagonal order
for (diag in 0..numberOfVariables + order) {
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
val p: Int = diag - o
if (newCache[p][o] == null) {
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
newCache[p][o] = DSCompiler(
algebra,
bufferFactory,
p,
o,
valueCompiler,
derivativeCompiler,
)
}
}
}
return@lazy newCache[numberOfVariables][order]!!
}
private val variables: Map<Symbol, DSSymbol> by lazy {
bindings.entries.mapIndexed { index, (key, value) ->
key to DSSymbol(
index,
key,
value,
)
}.toMap()
}
override val symbols: List<Symbol> = bindings.map { it.key }
public val numberOfVariables: Int get() = symbols.size
@OptIn(UnstableKMathAPI::class)
private fun bufferForVariable(index: Int, value: T): Buffer<T> { private fun bufferForVariable(index: Int, value: T): Buffer<T> {
val buffer = bufferFactory(compiler.size) { algebra.zero } val buffer = bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value buffer[0] = value
@ -80,7 +158,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
} }
@UnstableKMathAPI @UnstableKMathAPI
protected inner class DSImpl internal constructor( private inner class DSImpl(
override val data: Buffer<T>, override val data: Buffer<T>,
) : DS<T, A> { ) : DS<T, A> {
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
@ -130,63 +208,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
override val data: Buffer<T> = bufferForVariable(index, value) override val data: Buffer<T> = bufferForVariable(index, value)
} }
public val numberOfVariables: Int = bindings.size
/**
* Get the compiler for number of free parameters and order.
*
* @return cached rules set.
*/
@PublishedApi
internal val compiler: DSCompiler<T, A> by lazy {
// get the cached compilers
val cache: Array<Array<DSCompiler<T, A>?>>? = null
// we need to create more compilers
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
if (cache != null) {
// preserve the already created compilers
for (i in cache.indices) {
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
}
}
// create the array in increasing diagonal order
for (diag in 0..numberOfVariables + order) {
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
val p: Int = diag - o
if (newCache[p][o] == null) {
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
newCache[p][o] = DSCompiler(
algebra,
bufferFactory,
p,
o,
valueCompiler,
derivativeCompiler,
)
}
}
}
return@lazy newCache[numberOfVariables][order]!!
}
private val variables: Map<Symbol, DSSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
key to DSSymbol(
index,
key,
value,
)
}.toMap()
public override fun const(value: T): DS<T, A> { public override fun const(value: T): DS<T, A> {
val buffer = bufferFactory(compiler.size) { algebra.zero } val buffer = bufferFactory(compiler.size) { algebra.zero }
buffer[0] = value buffer[0] = value

View File

@ -52,20 +52,20 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
* *
* @property freeParameters Number of free parameters. * @property freeParameters Number of free parameters.
* @property order Derivation order. * @property order Derivation order.
* @see DerivativeStructure * @see DS
*/ */
class DSCompiler<T, out A : Algebra<T>> internal constructor( public class DSCompiler<T, out A : Algebra<T>> internal constructor(
val algebra: A, public val algebra: A,
val bufferFactory: MutableBufferFactory<T>, public val bufferFactory: MutableBufferFactory<T>,
val freeParameters: Int, public val freeParameters: Int,
val order: Int, public val order: Int,
valueCompiler: DSCompiler<T, A>?, valueCompiler: DSCompiler<T, A>?,
derivativeCompiler: DSCompiler<T, A>?, derivativeCompiler: DSCompiler<T, A>?,
) { ) {
/** /**
* Number of partial derivatives (including the single 0 order derivative element). * Number of partial derivatives (including the single 0 order derivative element).
*/ */
val sizes: Array<IntArray> by lazy { public val sizes: Array<IntArray> by lazy {
compileSizes( compileSizes(
freeParameters, freeParameters,
order, order,
@ -76,7 +76,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/** /**
* Indirection array for partial derivatives. * Indirection array for partial derivatives.
*/ */
val derivativesIndirection: Array<IntArray> by lazy { internal val derivativesIndirection: Array<IntArray> by lazy {
compileDerivativesIndirection( compileDerivativesIndirection(
freeParameters, order, freeParameters, order,
valueCompiler, derivativeCompiler, valueCompiler, derivativeCompiler,
@ -86,7 +86,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/** /**
* Indirection array of the lower derivative elements. * Indirection array of the lower derivative elements.
*/ */
val lowerIndirection: IntArray by lazy { internal val lowerIndirection: IntArray by lazy {
compileLowerIndirection( compileLowerIndirection(
freeParameters, order, freeParameters, order,
valueCompiler, derivativeCompiler, valueCompiler, derivativeCompiler,
@ -96,7 +96,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/** /**
* Indirection arrays for multiplication. * Indirection arrays for multiplication.
*/ */
val multIndirection: Array<Array<IntArray>> by lazy { internal val multIndirection: Array<Array<IntArray>> by lazy {
compileMultiplicationIndirection( compileMultiplicationIndirection(
freeParameters, order, freeParameters, order,
valueCompiler, derivativeCompiler, lowerIndirection, valueCompiler, derivativeCompiler, lowerIndirection,
@ -106,7 +106,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
/** /**
* Indirection arrays for function composition. * Indirection arrays for function composition.
*/ */
val compositionIndirection: Array<Array<IntArray>> by lazy { internal val compositionIndirection: Array<Array<IntArray>> by lazy {
compileCompositionIndirection( compileCompositionIndirection(
freeParameters, order, freeParameters, order,
valueCompiler, derivativeCompiler, valueCompiler, derivativeCompiler,
@ -120,7 +120,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
* This number includes the single 0 order derivative element, which is * This number includes the single 0 order derivative element, which is
* guaranteed to be stored in the first element of the array. * guaranteed to be stored in the first element of the array.
*/ */
val size: Int get() = sizes[freeParameters][order] public val size: Int get() = sizes[freeParameters][order]
/** /**
* Get the index of a partial derivative in the array. * Get the index of a partial derivative in the array.
@ -147,7 +147,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
* @return index of the partial derivative. * @return index of the partial derivative.
* @see getPartialDerivativeOrders * @see getPartialDerivativeOrders
*/ */
fun getPartialDerivativeIndex(vararg orders: Int): Int { public fun getPartialDerivativeIndex(vararg orders: Int): Int {
// safety check // safety check
require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" } require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" }
return getPartialDerivativeIndex(freeParameters, order, sizes, *orders) return getPartialDerivativeIndex(freeParameters, order, sizes, *orders)
@ -162,7 +162,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
* @return orders derivation orders with respect to each parameter * @return orders derivation orders with respect to each parameter
* @see getPartialDerivativeIndex * @see getPartialDerivativeIndex
*/ */
fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index] public fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
} }
/** /**