diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt index d9fc46b47..506fbd001 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt @@ -44,9 +44,29 @@ public interface DS> { * @see value */ @UnstableKMathAPI -public fun > DS.getPartialDerivative(vararg orders: Int): T = +private fun > DS.getPartialDerivative(vararg orders: Int): T = data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)] +/** + * Provide a partial derivative with given symbols. On symbol could me mentioned multiple times + */ +@UnstableKMathAPI +public fun > DS.derivative(symbols: List): T { + require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" } + val ordersCount: Map = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray()) +} + +/** + * Provide a partial derivative with given symbols. On symbol could me mentioned multiple times + */ +@UnstableKMathAPI +public fun > DS.derivative(vararg symbols: Symbol): T { + require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" } + val ordersCount: Map = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray()) +} + /** * The value part of the derivative structure. * @@ -61,9 +81,67 @@ public abstract class DSAlgebra>( public val bufferFactory: MutableBufferFactory, public val order: Int, bindings: Map, -) : ExpressionAlgebra> { +) : ExpressionAlgebra>, SymbolIndexer { + + /** + * Get the compiler for number of free parameters and order. + * + * @return cached rules set. + */ + @PublishedApi + internal val compiler: DSCompiler by lazy { + // get the cached compilers + val cache: Array?>>? = null + + // we need to create more compilers + val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0) + val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) + val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } + + if (cache != null) { + // preserve the already created compilers + for (i in cache.indices) { + cache[i].copyInto(newCache[i], endIndex = cache[i].size) + } + } + + // create the array in increasing diagonal order + for (diag in 0..numberOfVariables + order) { + for (o in max(0, diag - numberOfVariables)..min(order, diag)) { + val p: Int = diag - o + if (newCache[p][o] == null) { + val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! + val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! + + newCache[p][o] = DSCompiler( + algebra, + bufferFactory, + p, + o, + valueCompiler, + derivativeCompiler, + ) + } + } + } + + return@lazy newCache[numberOfVariables][order]!! + } + + private val variables: Map by lazy { + bindings.entries.mapIndexed { index, (key, value) -> + key to DSSymbol( + index, + key, + value, + ) + }.toMap() + } + override val symbols: List = bindings.map { it.key } + + public val numberOfVariables: Int get() = symbols.size + - @OptIn(UnstableKMathAPI::class) private fun bufferForVariable(index: Int, value: T): Buffer { val buffer = bufferFactory(compiler.size) { algebra.zero } buffer[0] = value @@ -80,7 +158,7 @@ public abstract class DSAlgebra>( } @UnstableKMathAPI - protected inner class DSImpl internal constructor( + private inner class DSImpl( override val data: Buffer, ) : DS { override val derivativeAlgebra: DSAlgebra get() = this@DSAlgebra @@ -130,63 +208,6 @@ public abstract class DSAlgebra>( override val data: Buffer = bufferForVariable(index, value) } - - public val numberOfVariables: Int = bindings.size - - /** - * Get the compiler for number of free parameters and order. - * - * @return cached rules set. - */ - @PublishedApi - internal val compiler: DSCompiler by lazy { - // get the cached compilers - val cache: Array?>>? = null - - // we need to create more compilers - val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0) - val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) - val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } - - if (cache != null) { - // preserve the already created compilers - for (i in cache.indices) { - cache[i].copyInto(newCache[i], endIndex = cache[i].size) - } - } - - // create the array in increasing diagonal order - for (diag in 0..numberOfVariables + order) { - for (o in max(0, diag - numberOfVariables)..min(order, diag)) { - val p: Int = diag - o - if (newCache[p][o] == null) { - val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! - val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! - - newCache[p][o] = DSCompiler( - algebra, - bufferFactory, - p, - o, - valueCompiler, - derivativeCompiler, - ) - } - } - } - - return@lazy newCache[numberOfVariables][order]!! - } - - private val variables: Map = bindings.entries.mapIndexed { index, (key, value) -> - key to DSSymbol( - index, - key, - value, - ) - }.toMap() - - public override fun const(value: T): DS { val buffer = bufferFactory(compiler.size) { algebra.zero } buffer[0] = value diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt index e0050cf03..b5b2988a3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt @@ -52,20 +52,20 @@ internal fun MutableBuffer.fill(element: T, fromIndex: Int = 0, toIndex: * * @property freeParameters Number of free parameters. * @property order Derivation order. - * @see DerivativeStructure + * @see DS */ -class DSCompiler> internal constructor( - val algebra: A, - val bufferFactory: MutableBufferFactory, - val freeParameters: Int, - val order: Int, +public class DSCompiler> internal constructor( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val freeParameters: Int, + public val order: Int, valueCompiler: DSCompiler?, derivativeCompiler: DSCompiler?, ) { /** * Number of partial derivatives (including the single 0 order derivative element). */ - val sizes: Array by lazy { + public val sizes: Array by lazy { compileSizes( freeParameters, order, @@ -76,7 +76,7 @@ class DSCompiler> internal constructor( /** * Indirection array for partial derivatives. */ - val derivativesIndirection: Array by lazy { + internal val derivativesIndirection: Array by lazy { compileDerivativesIndirection( freeParameters, order, valueCompiler, derivativeCompiler, @@ -86,7 +86,7 @@ class DSCompiler> internal constructor( /** * Indirection array of the lower derivative elements. */ - val lowerIndirection: IntArray by lazy { + internal val lowerIndirection: IntArray by lazy { compileLowerIndirection( freeParameters, order, valueCompiler, derivativeCompiler, @@ -96,7 +96,7 @@ class DSCompiler> internal constructor( /** * Indirection arrays for multiplication. */ - val multIndirection: Array> by lazy { + internal val multIndirection: Array> by lazy { compileMultiplicationIndirection( freeParameters, order, valueCompiler, derivativeCompiler, lowerIndirection, @@ -106,7 +106,7 @@ class DSCompiler> internal constructor( /** * Indirection arrays for function composition. */ - val compositionIndirection: Array> by lazy { + internal val compositionIndirection: Array> by lazy { compileCompositionIndirection( freeParameters, order, valueCompiler, derivativeCompiler, @@ -120,7 +120,7 @@ class DSCompiler> internal constructor( * This number includes the single 0 order derivative element, which is * guaranteed to be stored in the first element of the array. */ - val size: Int get() = sizes[freeParameters][order] + public val size: Int get() = sizes[freeParameters][order] /** * Get the index of a partial derivative in the array. @@ -147,7 +147,7 @@ class DSCompiler> internal constructor( * @return index of the partial derivative. * @see getPartialDerivativeOrders */ - fun getPartialDerivativeIndex(vararg orders: Int): Int { + public fun getPartialDerivativeIndex(vararg orders: Int): Int { // safety check require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" } return getPartialDerivativeIndex(freeParameters, order, sizes, *orders) @@ -162,7 +162,7 @@ class DSCompiler> internal constructor( * @return orders derivation orders with respect to each parameter * @see getPartialDerivativeIndex */ - fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index] + public fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index] } /**