altavir/diff #494
@ -44,9 +44,29 @@ public interface DS<T, A : Ring<T>> {
|
||||
* @see value
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
|
||||
private fun <T, A : Ring<T>> DS<T, A>.getPartialDerivative(vararg orders: Int): T =
|
||||
data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)]
|
||||
|
||||
/**
|
||||
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : Ring<T>> DS<T, A>.derivative(symbols: List<Symbol>): T {
|
||||
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide a partial derivative with given symbols. On symbol could me mentioned multiple times
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : Ring<T>> DS<T, A>.derivative(vararg symbols: Symbol): T {
|
||||
require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" }
|
||||
val ordersCount: Map<String, Int> = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size }
|
||||
return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray())
|
||||
}
|
||||
|
||||
/**
|
||||
* The value part of the derivative structure.
|
||||
*
|
||||
@ -61,9 +81,67 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : ExpressionAlgebra<T, DS<T, A>> {
|
||||
) : ExpressionAlgebra<T, DS<T, A>>, SymbolIndexer {
|
||||
|
||||
/**
|
||||
* Get the compiler for number of free parameters and order.
|
||||
*
|
||||
* @return cached rules set.
|
||||
*/
|
||||
@PublishedApi
|
||||
internal val compiler: DSCompiler<T, A> by lazy {
|
||||
// get the cached compilers
|
||||
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
||||
|
||||
// we need to create more compilers
|
||||
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
||||
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
||||
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
||||
|
||||
if (cache != null) {
|
||||
// preserve the already created compilers
|
||||
for (i in cache.indices) {
|
||||
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
||||
}
|
||||
}
|
||||
|
||||
// create the array in increasing diagonal order
|
||||
for (diag in 0..numberOfVariables + order) {
|
||||
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
||||
val p: Int = diag - o
|
||||
if (newCache[p][o] == null) {
|
||||
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
||||
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
||||
|
||||
newCache[p][o] = DSCompiler(
|
||||
algebra,
|
||||
bufferFactory,
|
||||
p,
|
||||
o,
|
||||
valueCompiler,
|
||||
derivativeCompiler,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return@lazy newCache[numberOfVariables][order]!!
|
||||
}
|
||||
|
||||
private val variables: Map<Symbol, DSSymbol> by lazy {
|
||||
bindings.entries.mapIndexed { index, (key, value) ->
|
||||
key to DSSymbol(
|
||||
index,
|
||||
key,
|
||||
value,
|
||||
)
|
||||
}.toMap()
|
||||
}
|
||||
override val symbols: List<Symbol> = bindings.map { it.key }
|
||||
|
||||
public val numberOfVariables: Int get() = symbols.size
|
||||
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
private fun bufferForVariable(index: Int, value: T): Buffer<T> {
|
||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||
buffer[0] = value
|
||||
@ -80,7 +158,7 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
protected inner class DSImpl internal constructor(
|
||||
private inner class DSImpl(
|
||||
override val data: Buffer<T>,
|
||||
) : DS<T, A> {
|
||||
override val derivativeAlgebra: DSAlgebra<T, A> get() = this@DSAlgebra
|
||||
@ -130,63 +208,6 @@ public abstract class DSAlgebra<T, A : Ring<T>>(
|
||||
override val data: Buffer<T> = bufferForVariable(index, value)
|
||||
}
|
||||
|
||||
|
||||
public val numberOfVariables: Int = bindings.size
|
||||
|
||||
/**
|
||||
* Get the compiler for number of free parameters and order.
|
||||
*
|
||||
* @return cached rules set.
|
||||
*/
|
||||
@PublishedApi
|
||||
internal val compiler: DSCompiler<T, A> by lazy {
|
||||
// get the cached compilers
|
||||
val cache: Array<Array<DSCompiler<T, A>?>>? = null
|
||||
|
||||
// we need to create more compilers
|
||||
val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0)
|
||||
val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size)
|
||||
val newCache: Array<Array<DSCompiler<T, A>?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) }
|
||||
|
||||
if (cache != null) {
|
||||
// preserve the already created compilers
|
||||
for (i in cache.indices) {
|
||||
cache[i].copyInto(newCache[i], endIndex = cache[i].size)
|
||||
}
|
||||
}
|
||||
|
||||
// create the array in increasing diagonal order
|
||||
for (diag in 0..numberOfVariables + order) {
|
||||
for (o in max(0, diag - numberOfVariables)..min(order, diag)) {
|
||||
val p: Int = diag - o
|
||||
if (newCache[p][o] == null) {
|
||||
val valueCompiler: DSCompiler<T, A>? = if (p == 0) null else newCache[p - 1][o]!!
|
||||
val derivativeCompiler: DSCompiler<T, A>? = if (o == 0) null else newCache[p][o - 1]!!
|
||||
|
||||
newCache[p][o] = DSCompiler(
|
||||
algebra,
|
||||
bufferFactory,
|
||||
p,
|
||||
o,
|
||||
valueCompiler,
|
||||
derivativeCompiler,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return@lazy newCache[numberOfVariables][order]!!
|
||||
}
|
||||
|
||||
private val variables: Map<Symbol, DSSymbol> = bindings.entries.mapIndexed { index, (key, value) ->
|
||||
key to DSSymbol(
|
||||
index,
|
||||
key,
|
||||
value,
|
||||
)
|
||||
}.toMap()
|
||||
|
||||
|
||||
public override fun const(value: T): DS<T, A> {
|
||||
val buffer = bufferFactory(compiler.size) { algebra.zero }
|
||||
buffer[0] = value
|
||||
|
@ -52,20 +52,20 @@ internal fun <T> MutableBuffer<T>.fill(element: T, fromIndex: Int = 0, toIndex:
|
||||
*
|
||||
* @property freeParameters Number of free parameters.
|
||||
* @property order Derivation order.
|
||||
* @see DerivativeStructure
|
||||
* @see DS
|
||||
*/
|
||||
class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
val algebra: A,
|
||||
val bufferFactory: MutableBufferFactory<T>,
|
||||
val freeParameters: Int,
|
||||
val order: Int,
|
||||
public class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
public val algebra: A,
|
||||
public val bufferFactory: MutableBufferFactory<T>,
|
||||
public val freeParameters: Int,
|
||||
public val order: Int,
|
||||
valueCompiler: DSCompiler<T, A>?,
|
||||
derivativeCompiler: DSCompiler<T, A>?,
|
||||
) {
|
||||
/**
|
||||
* Number of partial derivatives (including the single 0 order derivative element).
|
||||
*/
|
||||
val sizes: Array<IntArray> by lazy {
|
||||
public val sizes: Array<IntArray> by lazy {
|
||||
compileSizes(
|
||||
freeParameters,
|
||||
order,
|
||||
@ -76,7 +76,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
/**
|
||||
* Indirection array for partial derivatives.
|
||||
*/
|
||||
val derivativesIndirection: Array<IntArray> by lazy {
|
||||
internal val derivativesIndirection: Array<IntArray> by lazy {
|
||||
compileDerivativesIndirection(
|
||||
freeParameters, order,
|
||||
valueCompiler, derivativeCompiler,
|
||||
@ -86,7 +86,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
/**
|
||||
* Indirection array of the lower derivative elements.
|
||||
*/
|
||||
val lowerIndirection: IntArray by lazy {
|
||||
internal val lowerIndirection: IntArray by lazy {
|
||||
compileLowerIndirection(
|
||||
freeParameters, order,
|
||||
valueCompiler, derivativeCompiler,
|
||||
@ -96,7 +96,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
/**
|
||||
* Indirection arrays for multiplication.
|
||||
*/
|
||||
val multIndirection: Array<Array<IntArray>> by lazy {
|
||||
internal val multIndirection: Array<Array<IntArray>> by lazy {
|
||||
compileMultiplicationIndirection(
|
||||
freeParameters, order,
|
||||
valueCompiler, derivativeCompiler, lowerIndirection,
|
||||
@ -106,7 +106,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
/**
|
||||
* Indirection arrays for function composition.
|
||||
*/
|
||||
val compositionIndirection: Array<Array<IntArray>> by lazy {
|
||||
internal val compositionIndirection: Array<Array<IntArray>> by lazy {
|
||||
compileCompositionIndirection(
|
||||
freeParameters, order,
|
||||
valueCompiler, derivativeCompiler,
|
||||
@ -120,7 +120,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
* This number includes the single 0 order derivative element, which is
|
||||
* guaranteed to be stored in the first element of the array.
|
||||
*/
|
||||
val size: Int get() = sizes[freeParameters][order]
|
||||
public val size: Int get() = sizes[freeParameters][order]
|
||||
|
||||
/**
|
||||
* Get the index of a partial derivative in the array.
|
||||
@ -147,7 +147,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
* @return index of the partial derivative.
|
||||
* @see getPartialDerivativeOrders
|
||||
*/
|
||||
fun getPartialDerivativeIndex(vararg orders: Int): Int {
|
||||
public fun getPartialDerivativeIndex(vararg orders: Int): Int {
|
||||
// safety check
|
||||
require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" }
|
||||
return getPartialDerivativeIndex(freeParameters, order, sizes, *orders)
|
||||
@ -162,7 +162,7 @@ class DSCompiler<T, out A : Algebra<T>> internal constructor(
|
||||
* @return orders derivation orders with respect to each parameter
|
||||
* @see getPartialDerivativeIndex
|
||||
*/
|
||||
fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
|
||||
public fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index]
|
||||
}
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user