From 7b1bdc21a4ba5eb5d36d32afeffde7055c26f88b Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 12 Aug 2021 16:37:53 +0300 Subject: [PATCH 1/6] Copy DerivativeStructure to multiplatform --- README.md | 2 +- docs/templates/README-TEMPLATE.md | 2 +- .../kscience/kmath/expressions/DSCompiler.kt | 1541 +++++++++++++++++ .../kmath/expressions/DerivativeStructure.kt | 186 ++ .../DerivativeStructureExpression.kt | 332 ++++ .../DerivativeStructureExpressionTest.kt | 59 + .../kscience/kmath/internal/InternalUtils.kt | 3 +- 7 files changed, 2122 insertions(+), 3 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt create mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt create mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt create mode 100644 kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt diff --git a/README.md b/README.md index 92260716e..aeedfefbb 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels: * **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could break any moment. You can still use it, but be sure to fix the specific version. * **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked - with `@UnstableKmathAPI` or other stability warning annotations. + with `@UnstableKMathAPI` or other stability warning annotations. * **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor versions, but not in patch versions. API is protected with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. diff --git a/docs/templates/README-TEMPLATE.md b/docs/templates/README-TEMPLATE.md index b0c418697..2e64a3e09 100644 --- a/docs/templates/README-TEMPLATE.md +++ b/docs/templates/README-TEMPLATE.md @@ -44,7 +44,7 @@ module definitions below. The module stability could have the following levels: * **PROTOTYPE**. On this level there are no compatibility guarantees. All methods and classes form those modules could break any moment. You can still use it, but be sure to fix the specific version. * **EXPERIMENTAL**. The general API is decided, but some changes could be made. Volatile API is marked - with `@UnstableKmathAPI` or other stability warning annotations. + with `@UnstableKMathAPI` or other stability warning annotations. * **DEVELOPMENT**. API breaking generally follows semantic versioning ideology. There could be changes in minor versions, but not in patch versions. API is protected with [binary-compatibility-validator](https://github.com/Kotlin/binary-compatibility-validator) tool. diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt new file mode 100644 index 000000000..bb88ce52c --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt @@ -0,0 +1,1541 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.expressions + +import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.structures.MutableBufferFactory +import kotlin.math.max +import kotlin.math.min + +internal fun MutableBuffer.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) { + for (i in fromIndex until toIndex) this[i] = element +} + +/** + * Class holding "compiled" computation rules for derivative structures. + * + * This class implements the computation rules described in Dan Kalman's paper + * [Doubly Recursive Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf), + * Mathematics Magazine, vol. 75, no. 3, June 2002. However, to avoid performances bottlenecks, the recursive rules are + * "compiled" once in an unfolded form. This class does this recursion unrolling and stores the computation rules as + * simple loops with pre-computed indirection arrays. + * + * This class maps all derivative computation into single dimension arrays that hold the value and partial derivatives. + * The class does not hold these arrays, which remains under the responsibility of the caller. For each combination of + * number of free parameters and derivation order, only one compiler is necessary, and this compiler will be used to + * perform computations on all arrays provided to it, which can represent hundreds or thousands of different parameters + * kept together with all their partial derivatives. + * + * The arrays on which compilers operate contain only the partial derivatives together with the 0th + * derivative, i.e., the value. The partial derivatives are stored in a compiler-specific order, which can be retrieved + * using methods [getPartialDerivativeIndex] and [getPartialDerivativeOrders]. The value is guaranteed to be stored as + * the first element (i.e., the [getPartialDerivativeIndex] method returns 0 when called with 0 for all derivation + * orders and [getPartialDerivativeOrders] returns an array filled with 0 when called with 0 as the index). + * + * Note that the ordering changes with number of parameters and derivation order. For example given 2 parameters x and + * y, df/dy is stored at index 2 when derivation order is set to 1 (in this case the array has three elements: f, + * df/dx and df/dy). If derivation order is set to 2, then df/dy will be stored at index 3 (in this case the array has + * six elements: f, df/dx, df/dxdx, df/dy, df/dxdy and df/dydy). + * + * Given this structure, users can perform some simple operations like adding, subtracting or multiplying constants and + * negating the elements by themselves, knowing if they want to mutate their array or create a new array. These simple + * operations are not provided by the compiler. The compiler provides only the more complex operations between several + * arrays. + * + * Derived from + * [Commons Math's `DSCompiler`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DSCompiler.java). + * + * @property freeParameters Number of free parameters. + * @property order Derivation order. + * @see DerivativeStructure + */ +internal class DSCompiler> internal constructor( + val algebra: A, + val bufferFactory: MutableBufferFactory, + val freeParameters: Int, + val order: Int, + valueCompiler: DSCompiler?, + derivativeCompiler: DSCompiler?, +) { + /** + * Number of partial derivatives (including the single 0 order derivative element). + */ + val sizes: Array by lazy { + compileSizes( + freeParameters, + order, + valueCompiler, + ) + } + + /** + * Indirection array for partial derivatives. + */ + val derivativesIndirection: Array by lazy { + compileDerivativesIndirection( + freeParameters, order, + valueCompiler, derivativeCompiler, + ) + } + + /** + * Indirection array of the lower derivative elements. + */ + val lowerIndirection: IntArray by lazy { + compileLowerIndirection( + freeParameters, order, + valueCompiler, derivativeCompiler, + ) + } + + /** + * Indirection arrays for multiplication. + */ + val multIndirection: Array> by lazy { + compileMultiplicationIndirection( + freeParameters, order, + valueCompiler, derivativeCompiler, lowerIndirection, + ) + } + + /** + * Indirection arrays for function composition. + */ + val compositionIndirection: Array> by lazy { + compileCompositionIndirection( + freeParameters, order, + valueCompiler, derivativeCompiler, + sizes, derivativesIndirection, + ) + } + + /** + * Get the array size required for holding partial derivatives' data. + * + * This number includes the single 0 order derivative element, which is + * guaranteed to be stored in the first element of the array. + */ + val size: Int + get() = sizes[freeParameters][order] + + /** + * Get the index of a partial derivative in the array. + * + * If all orders are set to 0, then the 0th order derivative is returned, which is the value of the + * function. + * + * The indices of derivatives are between 0 and [size] − 1. Their specific order is fixed for a given compiler, but + * otherwise not publicly specified. There are however some simple cases which have guaranteed indices: + * + * * the index of 0th order derivative is always 0 + * * if there is only 1 [freeParameters], then the + * derivatives are sorted in increasing derivation order (i.e., f at index 0, df/dp + * at index 1, d2f/dp2 at index 2 … + * dkf/dpk at index k), + * * if the [order] is 1, then the derivatives + * are sorted in increasing free parameter order (i.e., f at index 0, df/dx1 + * at index 1, df/dx2 at index 2 … df/dxk at index k), + * * all other cases are not publicly specified. + * + * This method is the inverse of method [getPartialDerivativeOrders]. + * + * @param orders derivation orders with respect to each parameter. + * @return index of the partial derivative. + * @see getPartialDerivativeOrders + */ + fun getPartialDerivativeIndex(vararg orders: Int): Int { + // safety check + require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" } + return getPartialDerivativeIndex(freeParameters, order, sizes, *orders) + } + + /** + * Get the derivation orders for a specific index in the array. + * + * This method is the inverse of [getPartialDerivativeIndex]. + * + * @param index of the partial derivative + * @return orders derivation orders with respect to each parameter + * @see getPartialDerivativeIndex + */ + fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index] +} + +/** + * Compute natural logarithm of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for logarithm the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.ln( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : ExponentialOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + function[0] = ln(operand[operandOffset]) + + if (order > 0) { + val inv = one / operand[operandOffset] + var xk = inv + for (i in 1..order) { + function[i] = xk + xk *= (-i * inv) + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute integer power of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param n power to apply. + * @param result array where result must be stored (for power the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.pow( + operand: Buffer, + operandOffset: Int, + n: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : PowerOperations = algebra { + if (n == 0) { + // special case, x^0 = 1 for all x + result[resultOffset] = one + result.fill(zero, resultOffset + 1, resultOffset + size) + return + } + + // create the power function value and derivatives + // [x^n, nx^(n-1), n(n-1)x^(n-2), ... ] + val function = bufferFactory(1 + order) { zero } + + if (n > 0) { + // strictly positive power + val maxOrder: Int = min(order, n) + var xk = operand[operandOffset] pow n - maxOrder + for (i in maxOrder downTo 1) { + function[i] = xk + xk *= operand[operandOffset] + } + function[0] = xk + } else { + // strictly negative power + val inv = one / operand[operandOffset] + var xk = inv pow -n + + for (i in 0..order) { + function[i] = xk + xk *= inv + } + } + + var coefficient = number(n) + + for (i in 1..order) { + function[i] = function[i] * coefficient + coefficient *= (n - i).toDouble() + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute exponential of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for exponential the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.exp( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Ring, A : ScaleOperations, A : ExponentialOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + function.fill(exp(operand[operandOffset])) + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute square root of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for nth root the result array *cannot* be the input + * array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.sqrt( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : PowerOperations = algebra { + // create the function value and derivatives + // [x^(1/n), (1/n)x^((1/n)-1), (1-n)/n^2x^((1/n)-2), ... ] + val function = bufferFactory(1 + order) { zero } + function[0] = sqrt(operand[operandOffset]) + var xk: T = 0.5 * one / function[0] + val xReciprocal = one / operand[operandOffset] + + for (i in 1..order) { + function[i] = xk + xk *= xReciprocal * (0.5 - i) + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute cosine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for cosine the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.cos( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int, +) where A : Ring, A : TrigonometricOperations, A : ScaleOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + function[0] = cos(operand[operandOffset]) + + if (order > 0) { + function[1] = -sin(operand[operandOffset]) + for (i in 2..order) { + function[i] = -function[i - 2] + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute power of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param p power to apply. + * @param result array where result must be stored (for power the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.pow( + operand: Buffer, + operandOffset: Int, + p: Double, + result: MutableBuffer, + resultOffset: Int +) where A : Ring, A : NumericAlgebra, A : PowerOperations, A : ScaleOperations = algebra { + // create the function value and derivatives + // [x^p, px^(p-1), p(p-1)x^(p-2), ... ] + val function = bufferFactory(1 + order) { zero } + var xk = operand[operandOffset] pow p - order + + for (i in order downTo 1) { + function[i] = xk + xk *= operand[operandOffset] + } + + function[0] = xk + var coefficient = p + + for (i in 1..order) { + function[i] = function[i] * coefficient + coefficient *= p - i + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute tangent of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for tangent the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.tan( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Ring, A : TrigonometricOperations, A : ScaleOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val t = tan(operand[operandOffset]) + function[0] = t + + if (order > 0) { + + // the nth order derivative of tan has the form: + // dn(tan(x)/dxn = P_n(tan(x)) + // where P_n(t) is a degree n+1 polynomial with same parity as n+1 + // P_0(t) = t, P_1(t) = 1 + t^2, P_2(t) = 2 t (1 + t^2) ... + // the general recurrence relation for P_n is: + // P_n(x) = (1+t^2) P_(n-1)'(t) + // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array + val p = bufferFactory(order + 2) { zero } + p[1] = one + val t2 = t * t + for (n in 1..order) { + + // update and evaluate polynomial P_n(t) + var v = one + p[n + 1] = n * p[n] + var k = n + 1 + while (k >= 0) { + v = v * t2 + p[k] + if (k > 2) { + p[k - 2] = (k - 1) * p[k - 1] + (k - 3) * p[k - 3] + } else if (k == 2) { + p[0] = p[1] + } + k -= 2 + } + if (n and 0x1 == 0) { + v *= t + } + function[n] = v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute power of a derivative structure. + * + * @param x array holding the base. + * @param xOffset offset of the base in its array. + * @param y array holding the exponent. + * @param yOffset offset of the exponent in its array. + * @param result array where result must be stored (for power the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.pow( + x: Buffer, + xOffset: Int, + y: Buffer, + yOffset: Int, + result: MutableBuffer, + resultOffset: Int, +) where A : Field, A : ExponentialOperations = algebra { + val logX = bufferFactory(size) { zero } + ln(x, xOffset, logX, 0) + val yLogX = bufferFactory(size) { zero } + multiply(logX, 0, y, yOffset, yLogX, 0) + exp(yLogX, 0, result, resultOffset) +} + +/** + * Compute sine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for sine the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.sin( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Ring, A : ScaleOperations, A : TrigonometricOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + function[0] = sin(operand[operandOffset]) + if (order > 0) { + function[1] = cos(operand[operandOffset]) + for (i in 2..order) { + function[i] = -function[i - 2] + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute arc cosine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for arc cosine the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.acos( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : TrigonometricOperations, A : PowerOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val x = operand[operandOffset] + function[0] = acos(x) + if (order > 0) { + // the nth order derivative of acos has the form: + // dn(acos(x)/dxn = P_n(x) / [1 - x^2]^((2n-1)/2) + // where P_n(x) is a degree n-1 polynomial with same parity as n-1 + // P_1(x) = -1, P_2(x) = -x, P_3(x) = -2x^2 - 1 ... + // the general recurrence relation for P_n is: + // P_n(x) = (1-x^2) P_(n-1)'(x) + (2n-3) x P_(n-1)(x) + // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array + val p = bufferFactory(order) { zero } + p[0] = -one + val x2 = x * x + val f = one / (one - x2) + var coeff = sqrt(f) + function[1] = coeff * p[0] + + for (n in 2..order) { + // update and evaluate polynomial P_n(x) + var v = zero + p[n - 1] = (n - 1) * p[n - 2] + var k = n - 1 + + while (k >= 0) { + v = v * x2 + p[k] + if (k > 2) { + p[k - 2] = (k - 1) * p[k - 1] + (2 * n - k) * p[k - 3] + } else if (k == 2) { + p[0] = p[1] + } + k -= 2 + } + + if (n and 0x1 == 0) { + v *= x + } + + coeff *= f + function[n] = coeff * v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute arc sine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for arc sine the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.asin( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : TrigonometricOperations, A : PowerOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val x = operand[operandOffset] + function[0] = asin(x) + if (order > 0) { + // the nth order derivative of asin has the form: + // dn(asin(x)/dxn = P_n(x) / [1 - x^2]^((2n-1)/2) + // where P_n(x) is a degree n-1 polynomial with same parity as n-1 + // P_1(x) = 1, P_2(x) = x, P_3(x) = 2x^2 + 1 ... + // the general recurrence relation for P_n is: + // P_n(x) = (1-x^2) P_(n-1)'(x) + (2n-3) x P_(n-1)(x) + // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array + val p = bufferFactory(order) { zero } + p[0] = one + val x2 = x * x + val f = one / (one - x2) + var coeff = sqrt(f) + function[1] = coeff * p[0] + for (n in 2..order) { + + // update and evaluate polynomial P_n(x) + var v = zero + p[n - 1] = (n - 1) * p[n - 2] + var k = n - 1 + while (k >= 0) { + v = v * x2 + p[k] + if (k > 2) { + p[k - 2] = (k - 1) * p[k - 1] + (2 * n - k) * p[k - 3] + } else if (k == 2) { + p[0] = p[1] + } + k -= 2 + } + if (n and 0x1 == 0) { + v *= x + } + coeff *= f + function[n] = coeff * v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute arc tangent of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for arc tangent the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.atan( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : TrigonometricOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val x = operand[operandOffset] + function[0] = atan(x) + + if (order > 0) { + // the nth order derivative of atan has the form: + // dn(atan(x)/dxn = Q_n(x) / (1 + x^2)^n + // where Q_n(x) is a degree n-1 polynomial with same parity as n-1 + // Q_1(x) = 1, Q_2(x) = -2x, Q_3(x) = 6x^2 - 2 ... + // the general recurrence relation for Q_n is: + // Q_n(x) = (1+x^2) Q_(n-1)'(x) - 2(n-1) x Q_(n-1)(x) + // as per polynomial parity, we can store coefficients of both Q_(n-1) and Q_n in the same array + val q = bufferFactory(order) { zero } + q[0] = one + val x2 = x * x + val f = one / (one + x2) + var coeff = f + function[1] = coeff * q[0] + for (n in 2..order) { + + // update and evaluate polynomial Q_n(x) + var v = zero + q[n - 1] = -n * q[n - 2] + var k = n - 1 + while (k >= 0) { + v = v * x2 + q[k] + if (k > 2) { + q[k - 2] = (k - 1) * q[k - 1] + (k - 1 - 2 * n) * q[k - 3] + } else if (k == 2) { + q[0] = q[1] + } + k -= 2 + } + if (n and 0x1 == 0) { + v *= x + } + coeff *= f + function[n] = coeff * v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute hyperbolic cosine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for hyperbolic cosine the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.cosh( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Ring, A : ScaleOperations, A : ExponentialOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + function[0] = cosh(operand[operandOffset]) + + if (order > 0) { + function[1] = sinh(operand[operandOffset]) + for (i in 2..order) { + function[i] = function[i - 2] + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute hyperbolic tangent of a derivative structure. + * + * @param operand array holding the operand + * @param operandOffset offset of the operand in its array + * @param result array where result must be stored (for hyperbolic tangent the result array *cannot* be the input + * array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.tanh( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : ExponentialOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val t = tanh(operand[operandOffset]) + function[0] = t + if (order > 0) { + + // the nth order derivative of tanh has the form: + // dn(tanh(x)/dxn = P_n(tanh(x)) + // where P_n(t) is a degree n+1 polynomial with same parity as n+1 + // P_0(t) = t, P_1(t) = 1 - t^2, P_2(t) = -2 t (1 - t^2) ... + // the general recurrence relation for P_n is: + // P_n(x) = (1-t^2) P_(n-1)'(t) + // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array + val p = bufferFactory(order + 2) { zero } + p[1] = one + val t2 = t * t + for (n in 1..order) { + + // update and evaluate polynomial P_n(t) + var v = zero + p[n + 1] = -n * p[n] + var k = n + 1 + while (k >= 0) { + v = v * t2 + p[k] + if (k > 2) { + p[k - 2] = (k - 1) * p[k - 1] - (k - 3) * p[k - 3] + } else if (k == 2) { + p[0] = p[1] + } + k -= 2 + } + if (n and 0x1 == 0) { + v *= t + } + function[n] = v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute inverse hyperbolic cosine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for inverse hyperbolic cosine the result array *cannot* be the input + * array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.acosh( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : ExponentialOperations, A : PowerOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val x = operand[operandOffset] + function[0] = acosh(x) + + if (order > 0) { + // the nth order derivative of acosh has the form: + // dn(acosh(x)/dxn = P_n(x) / [x^2 - 1]^((2n-1)/2) + // where P_n(x) is a degree n-1 polynomial with same parity as n-1 + // P_1(x) = 1, P_2(x) = -x, P_3(x) = 2x^2 + 1 ... + // the general recurrence relation for P_n is: + // P_n(x) = (x^2-1) P_(n-1)'(x) - (2n-3) x P_(n-1)(x) + // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array + val p = bufferFactory(order) { zero } + p[0] = one + val x2 = x * x + val f = one / (x2 - one) + var coeff = sqrt(f) + function[1] = coeff * p[0] + for (n in 2..order) { + + // update and evaluate polynomial P_n(x) + var v = zero + p[n - 1] = (1 - n) * p[n - 2] + var k = n - 1 + while (k >= 0) { + v = v * x2 + p[k] + if (k > 2) { + p[k - 2] = (1 - k) * p[k - 1] + (k - 2 * n) * p[k - 3] + } else if (k == 2) { + p[0] = -p[1] + } + k -= 2 + } + if (n and 0x1 == 0) { + v *= x + } + + coeff *= f + function[n] = coeff * v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Compute composition of a derivative structure by a function. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param f array of value and derivatives of the function at the current point (i.e. at `operand[operandOffset]`). + * @param result array where result must be stored (for composition the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.compose( + operand: Buffer, + operandOffset: Int, + f: Buffer, + result: MutableBuffer, + resultOffset: Int, +) where A : Ring, A : ScaleOperations = algebra { + for (i in compositionIndirection.indices) { + val mappingI = compositionIndirection[i] + var r = zero + for (j in mappingI.indices) { + val mappingIJ = mappingI[j] + var product = mappingIJ[0] * f[mappingIJ[1]] + for (k in 2 until mappingIJ.size) { + product *= operand[operandOffset + mappingIJ[k]] + } + r += product + } + result[resultOffset + i] = r + } +} + +/** + * Compute hyperbolic sine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for hyperbolic sine the result array *cannot* be the input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.sinh( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : ExponentialOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + function[0] = sinh(operand[operandOffset]) + + if (order > 0) { + function[1] = cosh(operand[operandOffset]) + for (i in 2..order) { + function[i] = function[i - 2] + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Perform division of two derivative structures. + * + * @param lhs array holding left-hand side of division. + * @param lhsOffset offset of the left-hand side in its array. + * @param rhs array right-hand side of division. + * @param rhsOffset offset of the right-hand side in its array. + * @param result array where result must be stored (for division the result array *cannot* be one of the input arrays). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.divide( + lhs: Buffer, + lhsOffset: Int, + rhs: Buffer, + rhsOffset: Int, + result: MutableBuffer, + resultOffset: Int, +) where A : Field, A : PowerOperations, A : ScaleOperations = algebra { + val reciprocal = bufferFactory(size) { zero } + pow(rhs, lhsOffset, -1, reciprocal, 0) + multiply(lhs, lhsOffset, reciprocal, rhsOffset, result, resultOffset) +} + +/** + * Perform multiplication of two derivative structures. + * + * @param lhs array holding left-hand side of multiplication. + * @param lhsOffset offset of the left-hand side in its array. + * @param rhs array right-hand side of multiplication. + * @param rhsOffset offset of the right-hand side in its array. + * @param result array where result must be stored (for multiplication the result array *cannot* be one of the input + * arrays). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.multiply( + lhs: Buffer, + lhsOffset: Int, + rhs: Buffer, + rhsOffset: Int, + result: MutableBuffer, + resultOffset: Int, +) where A : Ring, A : ScaleOperations = algebra { + for (i in multIndirection.indices) { + val mappingI = multIndirection[i] + var r = zero + + for (j in mappingI.indices) { + r += mappingI[j][0] * lhs[lhsOffset + mappingI[j][1]] * rhs[rhsOffset + mappingI[j][2]] + } + + result[resultOffset + i] = r + } +} + +/** + * Perform subtraction of two derivative structures. + * + * @param lhs array holding left-hand side of subtraction. + * @param lhsOffset offset of the left-hand side in its array. + * @param rhs array right-hand side of subtraction. + * @param rhsOffset offset of the right-hand side in its array. + * @param result array where result must be stored (it may be one of the input arrays). + * @param resultOffset offset of the result in its array. + */ +internal fun > DSCompiler.subtract( + lhs: Buffer, + lhsOffset: Int, + rhs: Buffer, + rhsOffset: Int, + result: MutableBuffer, + resultOffset: Int, +) = algebra { + for (i in 0 until size) { + result[resultOffset + i] = lhs[lhsOffset + i] - rhs[rhsOffset + i] + } +} + +/** + * Compute inverse hyperbolic sine of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for inverse hyperbolic sine the result array *cannot* be the input + * array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.asinh( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int +) where A : Field, A : ExponentialOperations, A : PowerOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val x = operand[operandOffset] + function[0] = asinh(x) + if (order > 0) { + // the nth order derivative of asinh has the form: + // dn(asinh(x)/dxn = P_n(x) / [x^2 + 1]^((2n-1)/2) + // where P_n(x) is a degree n-1 polynomial with same parity as n-1 + // P_1(x) = 1, P_2(x) = -x, P_3(x) = 2x^2 - 1 ... + // the general recurrence relation for P_n is: + // P_n(x) = (x^2+1) P_(n-1)'(x) - (2n-3) x P_(n-1)(x) + // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array + val p = bufferFactory(order) { zero } + p[0] = one + val x2 = x * x + val f = one / (one + x2) + var coeff = sqrt(f) + function[1] = coeff * p[0] + for (n in 2..order) { + + // update and evaluate polynomial P_n(x) + var v = zero + p[n - 1] = (1 - n) * p[n - 2] + var k = n - 1 + while (k >= 0) { + v = v * x2 + p[k] + if (k > 2) { + p[k - 2] = (k - 1) * p[k - 1] + (k - 2 * n) * p[k - 3] + } else if (k == 2) { + p[0] = p[1] + } + k -= 2 + } + if (n and 0x1 == 0) { + v *= x + } + coeff *= f + function[n] = coeff * v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Perform addition of two derivative structures. + * + * @param lhs array holding left-hand side of addition. + * @param lhsOffset offset of the left-hand side in its array. + * @param rhs array right-hand side of addition. + * @param rhsOffset offset of the right-hand side in its array. + * @param result array where result must be stored (it may be one of the input arrays). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.add( + lhs: Buffer, + lhsOffset: Int, + rhs: Buffer, + rhsOffset: Int, + result: MutableBuffer, + resultOffset: Int, +) where A : Group = algebra { + for (i in 0 until size) { + result[resultOffset + i] = lhs[lhsOffset + i] + rhs[rhsOffset + i] + } +} + +/** + * Check rules set compatibility. + * + * @param compiler other compiler to check against instance. + */ +internal fun > DSCompiler.checkCompatibility(compiler: DSCompiler) { + require(freeParameters == compiler.freeParameters) { + "dimension mismatch: $freeParameters and ${compiler.freeParameters}" + } + require(order == compiler.order) { + "dimension mismatch: $order and ${compiler.order}" + } +} + +/** + * Compute inverse hyperbolic tangent of a derivative structure. + * + * @param operand array holding the operand. + * @param operandOffset offset of the operand in its array. + * @param result array where result must be stored (for inverse hyperbolic tangent the result array *cannot* be the + * input array). + * @param resultOffset offset of the result in its array. + */ +internal fun DSCompiler.atanh( + operand: Buffer, + operandOffset: Int, + result: MutableBuffer, + resultOffset: Int, +) where A : Field, A : ExponentialOperations = algebra { + // create the function value and derivatives + val function = bufferFactory(1 + order) { zero } + val x = operand[operandOffset] + function[0] = atanh(x) + + if (order > 0) { + // the nth order derivative of atanh has the form: + // dn(atanh(x)/dxn = Q_n(x) / (1 - x^2)^n + // where Q_n(x) is a degree n-1 polynomial with same parity as n-1 + // Q_1(x) = 1, Q_2(x) = 2x, Q_3(x) = 6x^2 + 2 ... + // the general recurrence relation for Q_n is: + // Q_n(x) = (1-x^2) Q_(n-1)'(x) + 2(n-1) x Q_(n-1)(x) + // as per polynomial parity, we can store coefficients of both Q_(n-1) and Q_n in the same array + val q = bufferFactory(order) { zero } + q[0] = one + val x2 = x * x + val f = one / (one - x2) + var coeff = f + function[1] = coeff * q[0] + for (n in 2..order) { + + // update and evaluate polynomial Q_n(x) + var v = zero + q[n - 1] = n * q[n - 2] + var k = n - 1 + while (k >= 0) { + v = v * x2 + q[k] + if (k > 2) { + q[k - 2] = (k - 1) * q[k - 1] + (2 * n - k + 1) * q[k - 3] + } else if (k == 2) { + q[0] = q[1] + } + k -= 2 + } + if (n and 0x1 == 0) { + v *= x + } + coeff *= f + function[n] = coeff * v + } + } + + // apply function composition + compose(operand, operandOffset, function, result, resultOffset) +} + +/** + * Get the compiler for number of free parameters and order. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @return cached rules set. + */ +internal fun > getCompiler( + algebra: A, + bufferFactory: MutableBufferFactory, + parameters: Int, + order: Int +): DSCompiler { + // get the cached compilers + val cache: Array?>>? = null + + // we need to create more compilers + val maxParameters: Int = max(parameters, cache?.size ?: 0) + val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) + val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } + + if (cache != null) { + // preserve the already created compilers + for (i in cache.indices) { + cache[i].copyInto(newCache[i], endIndex = cache[i].size) + } + } + + // create the array in increasing diagonal order + + // create the array in increasing diagonal order + for (diag in 0..parameters + order) { + for (o in max(0, diag - parameters)..min(order, diag)) { + val p: Int = diag - o + if (newCache[p][o] == null) { + val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! + val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! + + newCache[p][o] = DSCompiler( + algebra, + bufferFactory, + p, + o, + valueCompiler, + derivativeCompiler, + ) + } + } + } + + return newCache[parameters][order]!! +} + +/** + * Compile the sizes array. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param valueCompiler compiler for the value part. + * @return sizes array. + */ +private fun > compileSizes( + parameters: Int, order: Int, + valueCompiler: DSCompiler?, +): Array { + val sizes = Array(parameters + 1) { + IntArray(order + 1) + } + + if (parameters == 0) { + sizes[0].fill(1) + } else { + checkNotNull(valueCompiler) + valueCompiler.sizes.copyInto(sizes, endIndex = parameters) + sizes[parameters][0] = 1 + for (i in 0 until order) { + sizes[parameters][i + 1] = sizes[parameters][i] + sizes[parameters - 1][i + 1] + } + } + return sizes +} + +/** + * Compile the derivatives' indirection array. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param valueCompiler compiler for the value part. + * @param derivativeCompiler compiler for the derivative part. + * @return derivatives indirection array. + */ +private fun > compileDerivativesIndirection( + parameters: Int, + order: Int, + valueCompiler: DSCompiler?, + derivativeCompiler: DSCompiler?, +): Array { + if (parameters == 0 || order == 0) { + return Array(1) { IntArray(parameters) } + } + + val vSize: Int = valueCompiler!!.derivativesIndirection.size + val dSize: Int = derivativeCompiler!!.derivativesIndirection.size + val derivativesIndirection = Array(vSize + dSize) { IntArray(parameters) } + + // set up the indices for the value part + for (i in 0 until vSize) { + // copy the first indices, the last one remaining set to 0 + valueCompiler.derivativesIndirection[i].copyInto(derivativesIndirection[i], endIndex = parameters - 1) + } + + // set up the indices for the derivative part + for (i in 0 until dSize) { + // copy the indices + derivativeCompiler.derivativesIndirection[i].copyInto(derivativesIndirection[vSize], 0, 0, parameters) + + // increment the derivation order for the last parameter + derivativesIndirection[vSize + i][parameters - 1]++ + } + + return derivativesIndirection +} + +/** + * Compile the lower derivatives' indirection array. + * + * This indirection array contains the indices of all elements except derivatives for last derivation order. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param valueCompiler compiler for the value part. + * @param derivativeCompiler compiler for the derivative part. + * @return lower derivatives' indirection array. + */ +private fun > compileLowerIndirection( + parameters: Int, + order: Int, + valueCompiler: DSCompiler?, + derivativeCompiler: DSCompiler?, +): IntArray { + if (parameters == 0 || order <= 1) return intArrayOf(0) + checkNotNull(valueCompiler) + checkNotNull(derivativeCompiler) + + // this is an implementation of definition 6 in Dan Kalman's paper. + val vSize: Int = valueCompiler.lowerIndirection.size + val dSize: Int = derivativeCompiler.lowerIndirection.size + val lowerIndirection = IntArray(vSize + dSize) + valueCompiler.lowerIndirection.copyInto(lowerIndirection, endIndex = vSize) + for (i in 0 until dSize) { + lowerIndirection[vSize + i] = valueCompiler.size + derivativeCompiler.lowerIndirection[i] + } + return lowerIndirection +} + +/** + * Compile the multiplication indirection array. + * + * This indirection array contains the indices of all pairs of elements involved when computing a multiplication. This + * allows a straightforward loop-based multiplication (see [multiply]). + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param valueCompiler compiler for the value part. + * @param derivativeCompiler compiler for the derivative part. + * @param lowerIndirection lower derivatives' indirection array. + * @return multiplication indirection array. + */ +@Suppress("UNCHECKED_CAST") +private fun > compileMultiplicationIndirection( + parameters: Int, + order: Int, + valueCompiler: DSCompiler?, + derivativeCompiler: DSCompiler?, + lowerIndirection: IntArray, +): Array> { + if (parameters == 0 || order == 0) return arrayOf(arrayOf(intArrayOf(1, 0, 0))) + + // this is an implementation of definition 3 in Dan Kalman's paper. + val vSize = valueCompiler!!.multIndirection.size + val dSize = derivativeCompiler!!.multIndirection.size + val multIndirection: Array?> = arrayOfNulls(vSize + dSize) + valueCompiler.multIndirection.copyInto(multIndirection, endIndex = vSize) + + for (i in 0 until dSize) { + val dRow = derivativeCompiler.multIndirection[i] + val row: List = buildList(dRow.size * 2) { + for (j in dRow.indices) { + add(intArrayOf(dRow[j][0], lowerIndirection[dRow[j][1]], vSize + dRow[j][2])) + add(intArrayOf(dRow[j][0], vSize + dRow[j][1], lowerIndirection[dRow[j][2]])) + } + } + + // combine terms with similar derivation orders + val combined: List = buildList(row.size) { + for (j in row.indices) { + val termJ = row[j] + if (termJ[0] > 0) { + for (k in j + 1 until row.size) { + val termK = row[k] + + if (termJ[1] == termK[1] && termJ[2] == termK[2]) { + // combine termJ and termK + termJ[0] += termK[0] + // make sure we will skip termK later on in the outer loop + termK[0] = 0 + } + } + + add(termJ) + } + } + } + + multIndirection[vSize + i] = combined.toTypedArray() + } + + return multIndirection as Array> +} + +/** + * Compile the indirection array of function composition. + * + * This indirection array contains the indices of all sets of elements involved when computing a composition. This + * allows a straightforward loop-based composition (see [compose]). + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param valueCompiler compiler for the value part. + * @param derivativeCompiler compiler for the derivative part. + * @param sizes sizes array. + * @param derivativesIndirection derivatives indirection array. + * @return multiplication indirection array. + */ +@Suppress("UNCHECKED_CAST") +private fun > compileCompositionIndirection( + parameters: Int, + order: Int, + valueCompiler: DSCompiler?, + derivativeCompiler: DSCompiler?, + sizes: Array, + derivativesIndirection: Array, +): Array> { + if (parameters == 0 || order == 0) { + return arrayOf(arrayOf(intArrayOf(1, 0))) + } + + val vSize = valueCompiler!!.compositionIndirection.size + val dSize = derivativeCompiler!!.compositionIndirection.size + val compIndirection: Array?> = arrayOfNulls(vSize + dSize) + + // the composition rules from the value part can be reused as is + valueCompiler.compositionIndirection.copyInto(compIndirection, endIndex = vSize) + + // the composition rules for the derivative part are deduced by differentiation the rules from the + // underlying compiler once with respect to the parameter this compiler handles and the underlying one + // did not handle + + // the composition rules for the derivative part are deduced by differentiation the rules from the + // underlying compiler once with respect to the parameter this compiler handles and the underlying one did + // not handle + for (i in 0 until dSize) { + val row: List = buildList { + for (term in derivativeCompiler.compositionIndirection[i]) { + + // handle term p * f_k(g(x)) * g_l1(x) * g_l2(x) * ... * g_lp(x) + + // derive the first factor in the term: f_k with respect to new parameter + val derivedTermF = IntArray(term.size + 1) + derivedTermF[0] = term[0] // p + derivedTermF[1] = term[1] + 1 // f_(k+1) + val orders = IntArray(parameters) + orders[parameters - 1] = 1 + derivedTermF[term.size] = getPartialDerivativeIndex( + parameters, + order, + sizes, + *orders + ) // g_1 + + for (j in 2 until term.size) { + // convert the indices as the mapping for the current order is different from the mapping with one + // less order + derivedTermF[j] = convertIndex( + term[j], parameters, + derivativeCompiler.derivativesIndirection, + parameters, order, sizes + ) + } + + derivedTermF.sort(2, derivedTermF.size) + add(derivedTermF) + + // derive the various g_l + for (l in 2 until term.size) { + val derivedTermG = IntArray(term.size) + derivedTermG[0] = term[0] + derivedTermG[1] = term[1] + + for (j in 2 until term.size) { + // convert the indices as the mapping for the current order + // is different from the mapping with one less order + derivedTermG[j] = convertIndex( + term[j], + parameters, + derivativeCompiler.derivativesIndirection, + parameters, + order, + sizes, + ) + + if (j == l) { + // derive this term + derivativesIndirection[derivedTermG[j]].copyInto(orders, endIndex = parameters) + orders[parameters - 1]++ + + derivedTermG[j] = getPartialDerivativeIndex( + parameters, + order, + sizes, + *orders, + ) + } + } + + derivedTermG.sort(2, derivedTermG.size) + add(derivedTermG) + } + } + } + + // combine terms with similar derivation orders + val combined: List = buildList(row.size) { + for (j in row.indices) { + val termJ = row[j] + + if (termJ[0] > 0) { + (j + 1 until row.size).map { k -> row[k] }.forEach { termK -> + var equals = termJ.size == termK.size + var l = 1 + + while (equals && l < termJ.size) { + equals = equals and (termJ[l] == termK[l]) + ++l + } + + if (equals) { + // combine termJ and termK + termJ[0] += termK[0] + // make sure we will skip termK later on in the outer loop + termK[0] = 0 + } + } + + add(termJ) + } + } + } + + compIndirection[vSize + i] = combined.toTypedArray() + } + + return compIndirection as Array> +} + +/** + * Get the index of a partial derivative in an array. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param sizes sizes array. + * @param orders derivation orders with respect to each parameter (the length of this array must match the number of + * parameters). + * @return index of the partial derivative. + */ +private fun getPartialDerivativeIndex( + parameters: Int, + order: Int, + sizes: Array, + vararg orders: Int, +): Int { + + // the value is obtained by diving into the recursive Dan Kalman's structure + // this is theorem 2 of his paper, with recursion replaced by iteration + var index = 0 + var m = order + var ordersSum = 0 + + for (i in parameters - 1 downTo 0) { + // derivative order for current free parameter + var derivativeOrder = orders[i] + + // safety check + ordersSum += derivativeOrder + require(ordersSum <= order) { "number is too large: $ordersSum > $order" } + + while (derivativeOrder-- > 0) { + // as long as we differentiate according to current free parameter, + // we have to skip the value part and dive into the derivative part, + // so we add the size of the value part to the base index + index += sizes[i][m--] + } + } + + return index +} + +/** + * Convert an index from one (parameters, order) structure to another. + * + * @param index index of a partial derivative in source derivative structure. + * @param srcP number of free parameters in source derivative structure. + * @param srcDerivativesIndirection derivatives indirection array for the source derivative structure. + * @param destP number of free parameters in destination derivative structure. + * @param destO derivation order in destination derivative structure. + * @param destSizes sizes array for the destination derivative structure. + * @return index of the partial derivative with the *same* characteristics in destination derivative structure. + */ +private fun convertIndex( + index: Int, + srcP: Int, + srcDerivativesIndirection: Array, + destP: Int, + destO: Int, + destSizes: Array, +): Int { + val orders = IntArray(destP) + srcDerivativesIndirection[index].copyInto(orders, endIndex = min(srcP, destP)) + return getPartialDerivativeIndex(destP, destO, destSizes, *orders) +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt new file mode 100644 index 000000000..a1a6354f0 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt @@ -0,0 +1,186 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.expressions + +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.NumericAlgebra +import space.kscience.kmath.operations.Ring +import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.structures.MutableBuffer + +/** + * Class representing both the value and the differentials of a function. + * + * This class is the workhorse of the differentiation package. + * + * This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper [Doubly Recursive + * Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf), + * Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used + * throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's + * derivative structures hold all partial derivatives up to any specified order, with respect to any number of free + * parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free + * parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters. + * + * Derived from + * [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java). + */ +@UnstableKMathAPI +public open class DerivativeStructure internal constructor( + internal val derivativeAlgebra: DerivativeStructureRing, + internal val compiler: DSCompiler, +) where A : Ring, A : NumericAlgebra, A : ScaleOperations { + /** + * Combined array holding all values. + */ + internal var data: MutableBuffer = + derivativeAlgebra.bufferFactory(compiler.size) { derivativeAlgebra.algebra.zero } + + /** + * Build an instance with all values and derivatives set to 0. + * + * @param parameters number of free parameters. + * @param order derivation order. + */ + public constructor ( + derivativeAlgebra: DerivativeStructureRing, + parameters: Int, + order: Int, + ) : this( + derivativeAlgebra, + getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, parameters, order), + ) + + /** + * Build an instance representing a constant value. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param value value of the constant. + * @see DerivativeStructure + */ + public constructor ( + derivativeAlgebra: DerivativeStructureRing, + parameters: Int, + order: Int, + value: T, + ) : this( + derivativeAlgebra, + parameters, + order, + ) { + data[0] = value + } + + /** + * Build an instance representing a variable. + * + * Instances built using this constructor are considered to be the free variables with respect to which + * differentials are computed. As such, their differential with respect to themselves is +1. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param index index of the variable (from 0 to `parameters - 1`). + * @param value value of the variable. + */ + public constructor ( + derivativeAlgebra: DerivativeStructureRing, + parameters: Int, + order: Int, + index: Int, + value: T, + ) : this(derivativeAlgebra, parameters, order, value) { + require(index < parameters) { "number is too large: $index >= $parameters" } + + if (order > 0) { + // the derivative of the variable with respect to itself is 1. + data[getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, index, order).size] = + derivativeAlgebra.algebra.one + } + } + + /** + * Build an instance from all its derivatives. + * + * @param parameters number of free parameters. + * @param order derivation order. + * @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex]. + */ + public constructor ( + derivativeAlgebra: DerivativeStructureRing, + parameters: Int, + order: Int, + vararg derivatives: T, + ) : this( + derivativeAlgebra, + parameters, + order, + ) { + require(derivatives.size == data.size) { "dimension mismatch: ${derivatives.size} and ${data.size}" } + data = derivativeAlgebra.bufferFactory(data.size) { derivatives[it] } + } + + /** + * Copy constructor. + * + * @param ds instance to copy. + */ + internal constructor(ds: DerivativeStructure) : this(ds.derivativeAlgebra, ds.compiler) { + this.data = ds.data.copy() + } + + /** + * The number of free parameters. + */ + public val freeParameters: Int + get() = compiler.freeParameters + + /** + * The derivation order. + */ + public val order: Int + get() = compiler.order + + /** + * The value part of the derivative structure. + * + * @see getPartialDerivative + */ + public val value: T + get() = data[0] + + /** + * Get a partial derivative. + * + * @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned). + * @return partial derivative. + * @see value + */ + public fun getPartialDerivative(vararg orders: Int): T = data[compiler.getPartialDerivativeIndex(*orders)] + + + /** + * Test for the equality of two derivative structures. + * + * Derivative structures are considered equal if they have the same number + * of free parameters, the same derivation order, and the same derivatives. + * + * @return `true` if two derivative structures are equal. + */ + public override fun equals(other: Any?): Boolean { + if (this === other) return true + + if (other is DerivativeStructure<*, *>) { + return ((freeParameters == other.freeParameters) && + (order == other.order) && + data == other.data) + } + + return false + } + + public override fun hashCode(): Int = + 227 + 229 * freeParameters + 233 * order + 239 * data.hashCode() +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt new file mode 100644 index 000000000..f91fb55e8 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt @@ -0,0 +1,332 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.expressions + +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.MutableBufferFactory +import space.kscience.kmath.structures.indices + +/** + * A class implementing both [DerivativeStructure] and [Symbol]. + */ +@UnstableKMathAPI +public class DerivativeStructureSymbol( + derivativeAlgebra: DerivativeStructureRing, + size: Int, + order: Int, + index: Int, + symbol: Symbol, + value: T, +) : Symbol by symbol, DerivativeStructure( + derivativeAlgebra, + size, + order, + index, + value +) where A : Ring, A : NumericAlgebra, A : ScaleOperations { + override fun toString(): String = symbol.toString() + override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol + override fun hashCode(): Int = symbol.hashCode() +} + +/** + * A ring over [DerivativeStructure]. + * + * @property order The derivation order. + * @param bindings The map of bindings values. All bindings are considered free parameters. + */ +@UnstableKMathAPI +public open class DerivativeStructureRing( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val order: Int, + bindings: Map, +) : Ring>, ScaleOperations>, + NumericAlgebra>, + ExpressionAlgebra>, + NumbersAddOps> where A : Ring, A : NumericAlgebra, A : ScaleOperations { + public val numberOfVariables: Int = bindings.size + + override val zero: DerivativeStructure by lazy { + DerivativeStructure( + this, + numberOfVariables, + order, + ) + } + + override val one: DerivativeStructure by lazy { + DerivativeStructure( + this, + numberOfVariables, + order, + algebra.one, + ) + } + + override fun number(value: Number): DerivativeStructure = const(algebra.number(value)) + + private val variables: Map> = + bindings.entries.mapIndexed { index, (key, value) -> + key to DerivativeStructureSymbol( + this, + numberOfVariables, + order, + index, + key, + value, + ) + }.toMap() + + public override fun const(value: T): DerivativeStructure = + DerivativeStructure(this, numberOfVariables, order, value) + + override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[StringSymbol(value)] + + override fun bindSymbol(value: String): DerivativeStructureSymbol = + bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this") + + public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] + + public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol = + bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this") + + public fun DerivativeStructure.derivative(symbols: List): T { + require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" } + val ordersCount = symbols.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray()) + } + + public fun DerivativeStructure.derivative(vararg symbols: Symbol): T = derivative(symbols.toList()) + + override fun DerivativeStructure.unaryMinus(): DerivativeStructure { + val ds = DerivativeStructure(this@DerivativeStructureRing, compiler) + for (i in ds.data.indices) { + ds.data[i] = algebra { -data[i] } + } + return ds + } + + override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure { + left.compiler.checkCompatibility(right.compiler) + val ds = DerivativeStructure(left) + left.compiler.add(left.data, 0, right.data, 0, ds.data, 0) + return ds + } + + override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure { + val ds = DerivativeStructure(a) + for (i in ds.data.indices) { + ds.data[i] = algebra { ds.data[i].times(value) } + } + return ds + } + + override fun multiply( + left: DerivativeStructure, + right: DerivativeStructure + ): DerivativeStructure { + left.compiler.checkCompatibility(right.compiler) + val result = DerivativeStructure(this, left.compiler) + left.compiler.multiply(left.data, 0, right.data, 0, result.data, 0) + return result + } + + override fun DerivativeStructure.minus(arg: DerivativeStructure): DerivativeStructure { + compiler.checkCompatibility(arg.compiler) + val ds = DerivativeStructure(this) + compiler.subtract(data, 0, arg.data, 0, ds.data, 0) + return ds + } + + override operator fun DerivativeStructure.plus(other: Number): DerivativeStructure { + val ds = DerivativeStructure(this) + ds.data[0] = algebra { ds.data[0] + number(other) } + return ds + } + + override operator fun DerivativeStructure.minus(other: Number): DerivativeStructure = + this + -other.toDouble() + + override operator fun Number.plus(other: DerivativeStructure): DerivativeStructure = other + this + override operator fun Number.minus(other: DerivativeStructure): DerivativeStructure = other - this +} + +@UnstableKMathAPI +public class DerivativeStructureRingExpression( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val function: DerivativeStructureRing.() -> DerivativeStructure, +) : DifferentiableExpression where A : Ring, A : ScaleOperations, A : NumericAlgebra { + override operator fun invoke(arguments: Map): T = + DerivativeStructureRing(algebra, bufferFactory, 0, arguments).function().value + + override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + with( + DerivativeStructureRing( + algebra, + bufferFactory, + symbols.size, + arguments + ) + ) { function().derivative(symbols) } + } +} + +/** + * A field over commons-math [DerivativeStructure]. + * + * @property order The derivation order. + * @param bindings The map of bindings values. All bindings are considered free parameters. + */ +@UnstableKMathAPI +public class DerivativeStructureField>( + algebra: A, + bufferFactory: MutableBufferFactory, + order: Int, + bindings: Map, +) : DerivativeStructureRing(algebra, bufferFactory, order, bindings), ExtendedField> { + override fun number(value: Number): DerivativeStructure = const(algebra.number(value)) + + override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure { + left.compiler.checkCompatibility(right.compiler) + val result = DerivativeStructure(this, left.compiler) + left.compiler.divide(left.data, 0, right.data, 0, result.data, 0) + return result + } + + override fun sin(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.sin(arg.data, 0, result.data, 0) + return result + } + + override fun cos(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.cos(arg.data, 0, result.data, 0) + return result + } + + override fun tan(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.tan(arg.data, 0, result.data, 0) + return result + } + + override fun asin(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.asin(arg.data, 0, result.data, 0) + return result + } + + override fun acos(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.acos(arg.data, 0, result.data, 0) + return result + } + + override fun atan(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.atan(arg.data, 0, result.data, 0) + return result + } + + override fun sinh(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.sinh(arg.data, 0, result.data, 0) + return result + } + + override fun cosh(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.cosh(arg.data, 0, result.data, 0) + return result + } + + override fun tanh(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.tanh(arg.data, 0, result.data, 0) + return result + } + + override fun asinh(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.asinh(arg.data, 0, result.data, 0) + return result + } + + override fun acosh(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.acosh(arg.data, 0, result.data, 0) + return result + } + + override fun atanh(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.atanh(arg.data, 0, result.data, 0) + return result + } + + override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { + is Int -> { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.pow(arg.data, 0, pow, result.data, 0) + result + } + else -> { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.pow(arg.data, 0, pow.toDouble(), result.data, 0) + result + } + } + + override fun sqrt(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.sqrt(arg.data, 0, result.data, 0) + return result + } + + public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure { + arg.compiler.checkCompatibility(pow.compiler) + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.pow(arg.data, 0, pow.data, 0, result.data, 0) + return result + } + + override fun exp(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.exp(arg.data, 0, result.data, 0) + return result + } + + override fun ln(arg: DerivativeStructure): DerivativeStructure { + val result = DerivativeStructure(this, arg.compiler) + arg.compiler.ln(arg.data, 0, result.data, 0) + return result + } +} + +@UnstableKMathAPI +public class DerivativeStructureFieldExpression>( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val function: DerivativeStructureField.() -> DerivativeStructure, +) : DifferentiableExpression { + override operator fun invoke(arguments: Map): T = + DerivativeStructureField(algebra, bufferFactory, 0, arguments).function().value + + override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + with( + DerivativeStructureField( + algebra, + bufferFactory, + symbols.size, + arguments, + ) + ) { function().derivative(symbols) } + } +} diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt new file mode 100644 index 000000000..429fe310b --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt @@ -0,0 +1,59 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +@file:OptIn(UnstableKMathAPI::class) + +package space.kscience.kmath.expressions + +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.structures.DoubleBuffer +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFails + +internal inline fun diff( + order: Int, + vararg parameters: Pair, + block: DerivativeStructureField.() -> Unit, +) { + contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } + DerivativeStructureField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block() +} + +internal class AutoDiffTest { + private val x by symbol + private val y by symbol + + @Test + fun derivativeStructureFieldTest() { + diff(2, x to 1.0, y to 1.0) { + val x = bindSymbol(x)//by binding() + val y = bindSymbol("y") + val z = x * (-sin(x * y) + y) + 2.0 + println(z.derivative(x)) + println(z.derivative(y, x)) + assertEquals(z.derivative(x, y), z.derivative(y, x)) + // check improper order cause failure + assertFails { z.derivative(x, x, y) } + } + } + + @Test + fun autoDifTest() { + val f = DerivativeStructureFieldExpression(DoubleField, ::DoubleBuffer) { + val x by binding + val y by binding + x.pow(2) + 2 * x * y + y.pow(2) + 1 + } + + assertEquals(10.0, f(x to 1.0, y to 2.0)) + assertEquals(6.0, f.derivative(x)(x to 1.0, y to 2.0)) + assertEquals(2.0, f.derivative(x, x)(x to 1.234, y to -2.0)) + assertEquals(2.0, f.derivative(x, y)(x to 1.0, y to 2.0)) + } +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt index 3997a77b3..71fd15fe6 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt @@ -48,7 +48,8 @@ internal object InternalUtils { cache.copyInto( logFactorials, BEGIN_LOG_FACTORIALS, - BEGIN_LOG_FACTORIALS, endCopy + BEGIN_LOG_FACTORIALS, + endCopy, ) } else // All values to be computed From 5846f42141a210bef8ba873030f9566ebc14c5e9 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 15 Jul 2022 15:21:49 +0300 Subject: [PATCH 2/6] Grand derivative refactoring. Phase 1 --- .../kscience/kmath/expressions/DSCompiler.kt | 90 +- .../kmath/expressions/DerivativeStructure.kt | 191 ++-- .../DerivativeStructureExpression.kt | 327 +++--- .../kscience/kmath/linear/LinearSpace.kt | 2 +- .../space/kscience/kmath/misc/annotations.kt | 2 +- .../space/kscience/kmath/nd/BufferND.kt | 2 +- .../space/kscience/kmath/nd/StructureND.kt | 4 +- .../kmath/operations/DoubleBufferOps.kt | 4 +- .../kmath/operations/bufferOperation.kt | 26 +- .../space/kscience/kmath/structures/Buffer.kt | 8 +- .../space/kscience/kmath/ejml/_generated.kt | 1003 ----------------- .../histogram/UniformHistogramGroupND.kt | 6 +- .../kmath/multik/MultikDoubleAlgebra.kt | 7 + .../space/kscience/kmath/stat/Sampler.kt | 2 +- 14 files changed, 311 insertions(+), 1363 deletions(-) delete mode 100644 kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt index bb88ce52c..e0050cf03 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt @@ -5,11 +5,11 @@ package space.kscience.kmath.expressions + import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBufferFactory -import kotlin.math.max import kotlin.math.min internal fun MutableBuffer.fill(element: T, fromIndex: Int = 0, toIndex: Int = size) { @@ -54,7 +54,7 @@ internal fun MutableBuffer.fill(element: T, fromIndex: Int = 0, toIndex: * @property order Derivation order. * @see DerivativeStructure */ -internal class DSCompiler> internal constructor( +class DSCompiler> internal constructor( val algebra: A, val bufferFactory: MutableBufferFactory, val freeParameters: Int, @@ -120,8 +120,7 @@ internal class DSCompiler> internal constructor( * This number includes the single 0 order derivative element, which is * guaranteed to be stored in the first element of the array. */ - val size: Int - get() = sizes[freeParameters][order] + val size: Int get() = sizes[freeParameters][order] /** * Get the index of a partial derivative in the array. @@ -178,7 +177,7 @@ internal fun DSCompiler.ln( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : ExponentialOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -211,7 +210,7 @@ internal fun DSCompiler.pow( operandOffset: Int, n: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : PowerOperations = algebra { if (n == 0) { // special case, x^0 = 1 for all x @@ -267,7 +266,7 @@ internal fun DSCompiler.exp( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Ring, A : ScaleOperations, A : ExponentialOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -290,7 +289,7 @@ internal fun DSCompiler.sqrt( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : PowerOperations = algebra { // create the function value and derivatives // [x^(1/n), (1/n)x^((1/n)-1), (1-n)/n^2x^((1/n)-2), ... ] @@ -351,7 +350,7 @@ internal fun DSCompiler.pow( operandOffset: Int, p: Double, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Ring, A : NumericAlgebra, A : PowerOperations, A : ScaleOperations = algebra { // create the function value and derivatives // [x^p, px^(p-1), p(p-1)x^(p-2), ... ] @@ -387,7 +386,7 @@ internal fun DSCompiler.tan( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Ring, A : TrigonometricOperations, A : ScaleOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -469,7 +468,7 @@ internal fun DSCompiler.sin( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Ring, A : ScaleOperations, A : TrigonometricOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -497,7 +496,7 @@ internal fun DSCompiler.acos( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : TrigonometricOperations, A : PowerOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -559,7 +558,7 @@ internal fun DSCompiler.asin( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : TrigonometricOperations, A : PowerOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -618,7 +617,7 @@ internal fun DSCompiler.atan( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : TrigonometricOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -678,7 +677,7 @@ internal fun DSCompiler.cosh( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Ring, A : ScaleOperations, A : ExponentialOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -708,7 +707,7 @@ internal fun DSCompiler.tanh( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : ExponentialOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -765,7 +764,7 @@ internal fun DSCompiler.acosh( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : ExponentialOperations, A : PowerOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -857,7 +856,7 @@ internal fun DSCompiler.sinh( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : ExponentialOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -964,7 +963,7 @@ internal fun DSCompiler.asinh( operand: Buffer, operandOffset: Int, result: MutableBuffer, - resultOffset: Int + resultOffset: Int, ) where A : Field, A : ExponentialOperations, A : PowerOperations = algebra { // create the function value and derivatives val function = bufferFactory(1 + order) { zero } @@ -1109,59 +1108,6 @@ internal fun DSCompiler.atanh( compose(operand, operandOffset, function, result, resultOffset) } -/** - * Get the compiler for number of free parameters and order. - * - * @param parameters number of free parameters. - * @param order derivation order. - * @return cached rules set. - */ -internal fun > getCompiler( - algebra: A, - bufferFactory: MutableBufferFactory, - parameters: Int, - order: Int -): DSCompiler { - // get the cached compilers - val cache: Array?>>? = null - - // we need to create more compilers - val maxParameters: Int = max(parameters, cache?.size ?: 0) - val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) - val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } - - if (cache != null) { - // preserve the already created compilers - for (i in cache.indices) { - cache[i].copyInto(newCache[i], endIndex = cache[i].size) - } - } - - // create the array in increasing diagonal order - - // create the array in increasing diagonal order - for (diag in 0..parameters + order) { - for (o in max(0, diag - parameters)..min(order, diag)) { - val p: Int = diag - o - if (newCache[p][o] == null) { - val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! - val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! - - newCache[p][o] = DSCompiler( - algebra, - bufferFactory, - p, - o, - valueCompiler, - derivativeCompiler, - ) - } - } - } - - return newCache[parameters][order]!! -} - /** * Compile the sizes array. * diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt index a1a6354f0..01c045cdb 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt @@ -6,10 +6,9 @@ package space.kscience.kmath.expressions import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.Ring -import space.kscience.kmath.operations.ScaleOperations -import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.asBuffer /** * Class representing both the value and the differentials of a function. @@ -28,128 +27,29 @@ import space.kscience.kmath.structures.MutableBuffer * [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java). */ @UnstableKMathAPI -public open class DerivativeStructure internal constructor( - internal val derivativeAlgebra: DerivativeStructureRing, - internal val compiler: DSCompiler, -) where A : Ring, A : NumericAlgebra, A : ScaleOperations { - /** - * Combined array holding all values. - */ - internal var data: MutableBuffer = - derivativeAlgebra.bufferFactory(compiler.size) { derivativeAlgebra.algebra.zero } +public open class DerivativeStructure> @PublishedApi internal constructor( + private val derivativeAlgebra: DerivativeStructureAlgebra, + @PublishedApi internal val data: Buffer, +) { - /** - * Build an instance with all values and derivatives set to 0. - * - * @param parameters number of free parameters. - * @param order derivation order. - */ - public constructor ( - derivativeAlgebra: DerivativeStructureRing, - parameters: Int, - order: Int, - ) : this( - derivativeAlgebra, - getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, parameters, order), - ) - - /** - * Build an instance representing a constant value. - * - * @param parameters number of free parameters. - * @param order derivation order. - * @param value value of the constant. - * @see DerivativeStructure - */ - public constructor ( - derivativeAlgebra: DerivativeStructureRing, - parameters: Int, - order: Int, - value: T, - ) : this( - derivativeAlgebra, - parameters, - order, - ) { - data[0] = value - } - - /** - * Build an instance representing a variable. - * - * Instances built using this constructor are considered to be the free variables with respect to which - * differentials are computed. As such, their differential with respect to themselves is +1. - * - * @param parameters number of free parameters. - * @param order derivation order. - * @param index index of the variable (from 0 to `parameters - 1`). - * @param value value of the variable. - */ - public constructor ( - derivativeAlgebra: DerivativeStructureRing, - parameters: Int, - order: Int, - index: Int, - value: T, - ) : this(derivativeAlgebra, parameters, order, value) { - require(index < parameters) { "number is too large: $index >= $parameters" } - - if (order > 0) { - // the derivative of the variable with respect to itself is 1. - data[getCompiler(derivativeAlgebra.algebra, derivativeAlgebra.bufferFactory, index, order).size] = - derivativeAlgebra.algebra.one - } - } - - /** - * Build an instance from all its derivatives. - * - * @param parameters number of free parameters. - * @param order derivation order. - * @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex]. - */ - public constructor ( - derivativeAlgebra: DerivativeStructureRing, - parameters: Int, - order: Int, - vararg derivatives: T, - ) : this( - derivativeAlgebra, - parameters, - order, - ) { - require(derivatives.size == data.size) { "dimension mismatch: ${derivatives.size} and ${data.size}" } - data = derivativeAlgebra.bufferFactory(data.size) { derivatives[it] } - } - - /** - * Copy constructor. - * - * @param ds instance to copy. - */ - internal constructor(ds: DerivativeStructure) : this(ds.derivativeAlgebra, ds.compiler) { - this.data = ds.data.copy() - } + public val compiler: DSCompiler get() = derivativeAlgebra.compiler /** * The number of free parameters. */ - public val freeParameters: Int - get() = compiler.freeParameters + public val freeParameters: Int get() = compiler.freeParameters /** * The derivation order. */ - public val order: Int - get() = compiler.order + public val order: Int get() = compiler.order /** * The value part of the derivative structure. * * @see getPartialDerivative */ - public val value: T - get() = data[0] + public val value: T get() = data[0] /** * Get a partial derivative. @@ -183,4 +83,75 @@ public open class DerivativeStructure internal constructor( public override fun hashCode(): Int = 227 + 229 * freeParameters + 233 * order + 239 * data.hashCode() + + public companion object { + + /** + * Build an instance representing a variable. + * + * Instances built using this constructor are considered to be the free variables with respect to which + * differentials are computed. As such, their differential with respect to themselves is +1. + */ + public fun > variable( + derivativeAlgebra: DerivativeStructureAlgebra, + index: Int, + value: T, + ): DerivativeStructure { + val compiler = derivativeAlgebra.compiler + require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" } + return DerivativeStructure(derivativeAlgebra, derivativeAlgebra.bufferForVariable(index, value)) + } + + /** + * Build an instance from all its derivatives. + * + * @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex]. + */ + public fun > ofDerivatives( + derivativeAlgebra: DerivativeStructureAlgebra, + vararg derivatives: T, + ): DerivativeStructure { + val compiler = derivativeAlgebra.compiler + require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" } + val data = derivatives.asBuffer() + + return DerivativeStructure( + derivativeAlgebra, + data + ) + } + } +} + +@OptIn(UnstableKMathAPI::class) +private fun > DerivativeStructureAlgebra.bufferForVariable(index: Int, value: T): Buffer { + val buffer = bufferFactory(compiler.size) { algebra.zero } + buffer[0] = value + if (compiler.order > 0) { + // the derivative of the variable with respect to itself is 1. + + val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply { + set(index, 1) + }) + + buffer[indexOfDerivative] = algebra.one + } + return buffer +} + +/** + * A class implementing both [DerivativeStructure] and [Symbol]. + */ +@UnstableKMathAPI +public class DerivativeStructureSymbol> internal constructor( + derivativeAlgebra: DerivativeStructureAlgebra, + index: Int, + symbol: Symbol, + value: T, +) : Symbol by symbol, DerivativeStructure( + derivativeAlgebra, derivativeAlgebra.bufferForVariable(index, value) +) { + override fun toString(): String = symbol.toString() + override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol + override fun hashCode(): Int = symbol.hashCode() } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt index f91fb55e8..638057921 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt @@ -7,83 +7,89 @@ package space.kscience.kmath.expressions import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBufferFactory -import space.kscience.kmath.structures.indices +import kotlin.math.max +import kotlin.math.min -/** - * A class implementing both [DerivativeStructure] and [Symbol]. - */ @UnstableKMathAPI -public class DerivativeStructureSymbol( - derivativeAlgebra: DerivativeStructureRing, - size: Int, - order: Int, - index: Int, - symbol: Symbol, - value: T, -) : Symbol by symbol, DerivativeStructure( - derivativeAlgebra, - size, - order, - index, - value -) where A : Ring, A : NumericAlgebra, A : ScaleOperations { - override fun toString(): String = symbol.toString() - override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol - override fun hashCode(): Int = symbol.hashCode() -} - -/** - * A ring over [DerivativeStructure]. - * - * @property order The derivation order. - * @param bindings The map of bindings values. All bindings are considered free parameters. - */ -@UnstableKMathAPI -public open class DerivativeStructureRing( +public abstract class DerivativeStructureAlgebra>( public val algebra: A, public val bufferFactory: MutableBufferFactory, public val order: Int, bindings: Map, -) : Ring>, ScaleOperations>, - NumericAlgebra>, - ExpressionAlgebra>, - NumbersAddOps> where A : Ring, A : NumericAlgebra, A : ScaleOperations { +) : ExpressionAlgebra> { + public val numberOfVariables: Int = bindings.size - override val zero: DerivativeStructure by lazy { - DerivativeStructure( - this, - numberOfVariables, - order, - ) - } - override val one: DerivativeStructure by lazy { - DerivativeStructure( - this, - numberOfVariables, - order, - algebra.one, - ) - } + /** + * Get the compiler for number of free parameters and order. + * + * @return cached rules set. + */ + @PublishedApi + internal val compiler: DSCompiler by lazy { + // get the cached compilers + val cache: Array?>>? = null - override fun number(value: Number): DerivativeStructure = const(algebra.number(value)) + // we need to create more compilers + val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0) + val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) + val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } + + if (cache != null) { + // preserve the already created compilers + for (i in cache.indices) { + cache[i].copyInto(newCache[i], endIndex = cache[i].size) + } + } + + // create the array in increasing diagonal order + for (diag in 0..numberOfVariables + order) { + for (o in max(0, diag - numberOfVariables)..min(order, diag)) { + val p: Int = diag - o + if (newCache[p][o] == null) { + val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! + val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! + + newCache[p][o] = DSCompiler( + algebra, + bufferFactory, + p, + o, + valueCompiler, + derivativeCompiler, + ) + } + } + } + + return@lazy newCache[numberOfVariables][order]!! + } private val variables: Map> = bindings.entries.mapIndexed { index, (key, value) -> key to DerivativeStructureSymbol( this, - numberOfVariables, - order, index, key, value, ) }.toMap() - public override fun const(value: T): DerivativeStructure = - DerivativeStructure(this, numberOfVariables, order, value) + + + public override fun const(value: T): DerivativeStructure { + val buffer = bufferFactory(compiler.size) { algebra.zero } + buffer[0] = value + + return DerivativeStructure( + this, + buffer + ) + } override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[StringSymbol(value)] @@ -103,54 +109,99 @@ public open class DerivativeStructureRing( public fun DerivativeStructure.derivative(vararg symbols: Symbol): T = derivative(symbols.toList()) +} + + +/** + * A ring over [DerivativeStructure]. + * + * @property order The derivation order. + * @param bindings The map of bindings values. All bindings are considered free parameters. + */ +@UnstableKMathAPI +public open class DerivativeStructureRing( + algebra: A, + bufferFactory: MutableBufferFactory, + order: Int, + bindings: Map, +) : DerivativeStructureAlgebra(algebra, bufferFactory, order, bindings), + Ring>, ScaleOperations>, + NumericAlgebra>, + NumbersAddOps> where A : Ring, A : NumericAlgebra, A : ScaleOperations { + + override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = + super.bindSymbolOrNull(value) + override fun DerivativeStructure.unaryMinus(): DerivativeStructure { - val ds = DerivativeStructure(this@DerivativeStructureRing, compiler) - for (i in ds.data.indices) { - ds.data[i] = algebra { -data[i] } - } - return ds + val newData = algebra { data.map(bufferFactory) { -it } } + return DerivativeStructure(this@DerivativeStructureRing, newData) } + /** + * Create a copy of given [Buffer] and modify it according to [block] + */ + protected inline fun DerivativeStructure.transformDataBuffer(block: DSCompiler.(MutableBuffer) -> Unit): DerivativeStructure { + val newData = bufferFactory(compiler.size) { data[it] } + compiler.block(newData) + return DerivativeStructure(this@DerivativeStructureRing, newData) + } + + protected fun DerivativeStructure.mapData(block: (T) -> T): DerivativeStructure { + val newData: Buffer = data.map(bufferFactory, block) + return DerivativeStructure(this@DerivativeStructureRing, newData) + } + + protected fun DerivativeStructure.mapDataIndexed(block: (Int, T) -> T): DerivativeStructure { + val newData: Buffer = data.mapIndexed(bufferFactory, block) + return DerivativeStructure(this@DerivativeStructureRing, newData) + } + + override val zero: DerivativeStructure by lazy { + const(algebra.zero) + } + + override val one: DerivativeStructure by lazy { + const(algebra.one) + } + + override fun number(value: Number): DerivativeStructure = const(algebra.number(value)) + override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure { left.compiler.checkCompatibility(right.compiler) - val ds = DerivativeStructure(left) - left.compiler.add(left.data, 0, right.data, 0, ds.data, 0) - return ds + return left.transformDataBuffer { result -> + add(left.data, 0, right.data, 0, result, 0) + } } - override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure { - val ds = DerivativeStructure(a) - for (i in ds.data.indices) { - ds.data[i] = algebra { ds.data[i].times(value) } - } - return ds + override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = algebra { + a.mapData { it.times(value) } } override fun multiply( left: DerivativeStructure, - right: DerivativeStructure + right: DerivativeStructure, ): DerivativeStructure { left.compiler.checkCompatibility(right.compiler) - val result = DerivativeStructure(this, left.compiler) - left.compiler.multiply(left.data, 0, right.data, 0, result.data, 0) - return result + return left.transformDataBuffer { result -> + multiply(left.data, 0, right.data, 0, result, 0) + } } override fun DerivativeStructure.minus(arg: DerivativeStructure): DerivativeStructure { compiler.checkCompatibility(arg.compiler) - val ds = DerivativeStructure(this) - compiler.subtract(data, 0, arg.data, 0, ds.data, 0) - return ds + return transformDataBuffer { result -> + subtract(data, 0, arg.data, 0, result, 0) + } } - override operator fun DerivativeStructure.plus(other: Number): DerivativeStructure { - val ds = DerivativeStructure(this) - ds.data[0] = algebra { ds.data[0] + number(other) } - return ds + override operator fun DerivativeStructure.plus(other: Number): DerivativeStructure = algebra { + transformDataBuffer { + it[0] += number(other) + } } override operator fun DerivativeStructure.minus(other: Number): DerivativeStructure = - this + -other.toDouble() + this + (-other.toDouble()) override operator fun Number.plus(other: DerivativeStructure): DerivativeStructure = other + this override operator fun Number.minus(other: DerivativeStructure): DerivativeStructure = other - this @@ -194,119 +245,85 @@ public class DerivativeStructureField>( override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure { left.compiler.checkCompatibility(right.compiler) - val result = DerivativeStructure(this, left.compiler) - left.compiler.divide(left.data, 0, right.data, 0, result.data, 0) - return result + return left.transformDataBuffer { result -> + left.compiler.divide(left.data, 0, right.data, 0, result, 0) + } } - override fun sin(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.sin(arg.data, 0, result.data, 0) - return result + override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + sin(arg.data, 0, result, 0) } - override fun cos(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.cos(arg.data, 0, result.data, 0) - return result + override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + cos(arg.data, 0, result, 0) } - override fun tan(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.tan(arg.data, 0, result.data, 0) - return result + override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + tan(arg.data, 0, result, 0) } - override fun asin(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.asin(arg.data, 0, result.data, 0) - return result + override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + asin(arg.data, 0, result, 0) } - override fun acos(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.acos(arg.data, 0, result.data, 0) - return result + override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + acos(arg.data, 0, result, 0) } - override fun atan(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.atan(arg.data, 0, result.data, 0) - return result + override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + atan(arg.data, 0, result, 0) } - override fun sinh(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.sinh(arg.data, 0, result.data, 0) - return result + override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + sinh(arg.data, 0, result, 0) } - override fun cosh(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.cosh(arg.data, 0, result.data, 0) - return result + override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + cosh(arg.data, 0, result, 0) } - override fun tanh(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.tanh(arg.data, 0, result.data, 0) - return result + override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + tanh(arg.data, 0, result, 0) } - override fun asinh(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.asinh(arg.data, 0, result.data, 0) - return result + override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + asinh(arg.data, 0, result, 0) } - override fun acosh(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.acosh(arg.data, 0, result.data, 0) - return result + override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + acosh(arg.data, 0, result, 0) } - override fun atanh(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.atanh(arg.data, 0, result.data, 0) - return result + override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + atanh(arg.data, 0, result, 0) } override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { - is Int -> { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.pow(arg.data, 0, pow, result.data, 0) - result + is Int -> arg.transformDataBuffer { result -> + pow(arg.data, 0, pow, result, 0) } - else -> { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.pow(arg.data, 0, pow.toDouble(), result.data, 0) - result + else -> arg.transformDataBuffer { result -> + pow(arg.data, 0, pow.toDouble(), result, 0) } } - override fun sqrt(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.sqrt(arg.data, 0, result.data, 0) - return result + override fun sqrt(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + sqrt(arg.data, 0, result, 0) } public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure { arg.compiler.checkCompatibility(pow.compiler) - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.pow(arg.data, 0, pow.data, 0, result.data, 0) - return result + return arg.transformDataBuffer { result -> + pow(arg.data, 0, pow.data, 0, result, 0) + } } - override fun exp(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.exp(arg.data, 0, result.data, 0) - return result + override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + exp(arg.data, 0, result, 0) } - override fun ln(arg: DerivativeStructure): DerivativeStructure { - val result = DerivativeStructure(this, arg.compiler) - arg.compiler.ln(arg.data, 0, result.data, 0) - return result + override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> + ln(arg.data, 0, result, 0) } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index 715fad07b..10438dd02 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -188,7 +188,7 @@ public interface LinearSpace> { */ public fun > buffered( algebra: A, - bufferFactory: BufferFactory = Buffer.Companion::boxing, + bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::boxing), ): LinearSpace = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory)) @Deprecated("use DoubleField.linearSpace") diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt index 7c612b6a9..60fa81cd8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/misc/annotations.kt @@ -27,5 +27,5 @@ public annotation class UnstableKMathAPI RequiresOptIn.Level.WARNING, ) public annotation class PerformancePitfall( - val message: String = "Potential performance problem" + val message: String = "Potential performance problem", ) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt index 2401f6319..8175bd65e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt @@ -69,7 +69,7 @@ public class MutableBufferND( * Transform structure to a new structure using provided [MutableBufferFactory] and optimizing if argument is [MutableBufferND] */ public inline fun MutableStructureND.mapToMutableBuffer( - factory: MutableBufferFactory = MutableBuffer.Companion::auto, + factory: MutableBufferFactory = MutableBufferFactory(MutableBuffer.Companion::auto), crossinline transform: (T) -> R, ): MutableBufferND { return if (this is MutableBufferND) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt index e934c6370..6e54e1b9d 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt @@ -120,7 +120,7 @@ public interface StructureND : Featured, WithShape { */ public fun buffered( strides: Strides, - bufferFactory: BufferFactory = Buffer.Companion::boxing, + bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::boxing), initializer: (IntArray) -> T, ): BufferND = BufferND(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) @@ -140,7 +140,7 @@ public interface StructureND : Featured, WithShape { public fun buffered( shape: IntArray, - bufferFactory: BufferFactory = Buffer.Companion::boxing, + bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::boxing), initializer: (IntArray) -> T, ): BufferND = buffered(DefaultStrides(shape), bufferFactory, initializer) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt index 0ee591acc..083892105 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt @@ -6,12 +6,10 @@ package space.kscience.kmath.operations import space.kscience.kmath.linear.Point -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.asBuffer - import kotlin.math.* /** @@ -21,7 +19,7 @@ public abstract class DoubleBufferOps : BufferAlgebra, Exte Norm, Double> { override val elementAlgebra: DoubleField get() = DoubleField - override val bufferFactory: BufferFactory get() = ::DoubleBuffer + override val bufferFactory: BufferFactory get() = BufferFactory(::DoubleBuffer) override fun Buffer.map(block: DoubleField.(Double) -> Double): DoubleBuffer = mapInline { DoubleField.block(it) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/bufferOperation.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/bufferOperation.kt index 31b0c2841..652472fcf 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/bufferOperation.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/bufferOperation.kt @@ -61,31 +61,39 @@ public inline fun Buffer.toTypedArray(): Array = Array(size, : /** * Create a new buffer from this one with the given mapping function and using [Buffer.Companion.auto] buffer factory. */ -public inline fun Buffer.map(block: (T) -> R): Buffer = +public inline fun Buffer.map(block: (T) -> R): Buffer = Buffer.auto(size) { block(get(it)) } /** * Create a new buffer from this one with the given mapping function. * Provided [bufferFactory] is used to construct the new buffer. */ -public inline fun Buffer.map( +public inline fun Buffer.map( bufferFactory: BufferFactory, crossinline block: (T) -> R, ): Buffer = bufferFactory(size) { block(get(it)) } /** - * Create a new buffer from this one with the given indexed mapping function. - * Provided [BufferFactory] is used to construct the new buffer. + * Create a new buffer from this one with the given mapping (indexed) function. + * Provided [bufferFactory] is used to construct the new buffer. */ -public inline fun Buffer.mapIndexed( - bufferFactory: BufferFactory = Buffer.Companion::auto, +public inline fun Buffer.mapIndexed( + bufferFactory: BufferFactory, crossinline block: (index: Int, value: T) -> R, ): Buffer = bufferFactory(size) { block(it, get(it)) } +/** + * Create a new buffer from this one with the given indexed mapping function. + * Provided [BufferFactory] is used to construct the new buffer. + */ +public inline fun Buffer.mapIndexed( + crossinline block: (index: Int, value: T) -> R, +): Buffer = BufferFactory(Buffer.Companion::auto).invoke(size) { block(it, get(it)) } + /** * Fold given buffer according to [operation] */ -public inline fun Buffer.fold(initial: R, operation: (acc: R, T) -> R): R { +public inline fun Buffer.fold(initial: R, operation: (acc: R, T) -> R): R { var accumulator = initial for (index in this.indices) accumulator = operation(accumulator, get(index)) return accumulator @@ -95,9 +103,9 @@ public inline fun Buffer.fold(initial: R, operation: (acc: R, T) * Zip two buffers using given [transform]. */ @UnstableKMathAPI -public inline fun Buffer.zip( +public inline fun Buffer.zip( other: Buffer, - bufferFactory: BufferFactory = Buffer.Companion::auto, + bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::auto), crossinline transform: (T1, T2) -> R, ): Buffer { require(size == other.size) { "Buffer size mismatch in zip: expected $size but found ${other.size}" } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt index a1b0307c4..1c79c257a 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/structures/Buffer.kt @@ -14,14 +14,18 @@ import kotlin.reflect.KClass * * @param T the type of buffer. */ -public typealias BufferFactory = (Int, (Int) -> T) -> Buffer +public fun interface BufferFactory { + public operator fun invoke(size: Int, builder: (Int) -> T): Buffer +} /** * Function that produces [MutableBuffer] from its size and function that supplies values. * * @param T the type of buffer. */ -public typealias MutableBufferFactory = (Int, (Int) -> T) -> MutableBuffer +public fun interface MutableBufferFactory: BufferFactory{ + override fun invoke(size: Int, builder: (Int) -> T): MutableBuffer +} /** * A generic read-only random-access structure for both primitives and objects. diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt deleted file mode 100644 index aac327a84..000000000 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt +++ /dev/null @@ -1,1003 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. - */ - -/* This file is generated with buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt */ - -package space.kscience.kmath.ejml - -import org.ejml.data.* -import org.ejml.dense.row.CommonOps_DDRM -import org.ejml.dense.row.CommonOps_FDRM -import org.ejml.dense.row.factory.DecompositionFactory_DDRM -import org.ejml.dense.row.factory.DecompositionFactory_FDRM -import org.ejml.sparse.FillReducing -import org.ejml.sparse.csc.CommonOps_DSCC -import org.ejml.sparse.csc.CommonOps_FSCC -import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC -import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC -import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC -import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC -import space.kscience.kmath.linear.* -import space.kscience.kmath.linear.Matrix -import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.nd.StructureFeature -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.FloatField -import space.kscience.kmath.operations.invoke -import space.kscience.kmath.structures.DoubleBuffer -import space.kscience.kmath.structures.FloatBuffer -import kotlin.reflect.KClass -import kotlin.reflect.cast - -/** - * [EjmlVector] specialization for [Double]. - */ -public class EjmlDoubleVector(override val origin: M) : EjmlVector(origin) { - init { - require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } - } - - override operator fun get(index: Int): Double = origin[0, index] -} - -/** - * [EjmlVector] specialization for [Float]. - */ -public class EjmlFloatVector(override val origin: M) : EjmlVector(origin) { - init { - require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } - } - - override operator fun get(index: Int): Float = origin[0, index] -} - -/** - * [EjmlMatrix] specialization for [Double]. - */ -public class EjmlDoubleMatrix(override val origin: M) : EjmlMatrix(origin) { - override operator fun get(i: Int, j: Int): Double = origin[i, j] -} - -/** - * [EjmlMatrix] specialization for [Float]. - */ -public class EjmlFloatMatrix(override val origin: M) : EjmlMatrix(origin) { - override operator fun get(i: Int, j: Int): Float = origin[i, j] -} - -/** - * [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and - * [DMatrixRMaj] matrices. - */ -public object EjmlLinearSpaceDDRM : EjmlLinearSpace() { - /** - * The [DoubleField] reference. - */ - override val elementAlgebra: DoubleField get() = DoubleField - - @Suppress("UNCHECKED_CAST") - override fun Matrix.toEjml(): EjmlDoubleMatrix = when { - this is EjmlDoubleMatrix<*> && origin is DMatrixRMaj -> this as EjmlDoubleMatrix - else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } - } - - @Suppress("UNCHECKED_CAST") - override fun Point.toEjml(): EjmlDoubleVector = when { - this is EjmlDoubleVector<*> && origin is DMatrixRMaj -> this as EjmlDoubleVector - else -> EjmlDoubleVector(DMatrixRMaj(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } - }) - } - - override fun buildMatrix( - rows: Int, - columns: Int, - initializer: DoubleField.(i: Int, j: Int) -> Double, - ): EjmlDoubleMatrix = DMatrixRMaj(rows, columns).also { - (0 until rows).forEach { row -> - (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } - } - }.wrapMatrix() - - override fun buildVector( - size: Int, - initializer: DoubleField.(Int) -> Double, - ): EjmlDoubleVector = EjmlDoubleVector(DMatrixRMaj(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } - }) - - private fun T.wrapMatrix() = EjmlDoubleMatrix(this) - private fun T.wrapVector() = EjmlDoubleVector(this) - - override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - - override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { - val out = DMatrixRMaj(1, 1) - CommonOps_DDRM.mult(toEjml().origin, other.toEjml().origin, out) - return out.wrapMatrix() - } - - override fun Matrix.dot(vector: Point): EjmlDoubleVector { - val out = DMatrixRMaj(1, 1) - CommonOps_DDRM.mult(toEjml().origin, vector.toEjml().origin, out) - return out.wrapVector() - } - - override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { - val out = DMatrixRMaj(1, 1) - - CommonOps_DDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - ) - - return out.wrapMatrix() - } - - override operator fun Matrix.times(value: Double): EjmlDoubleMatrix { - val res = DMatrixRMaj(1, 1) - CommonOps_DDRM.scale(value, toEjml().origin, res) - return res.wrapMatrix() - } - - override fun Point.unaryMinus(): EjmlDoubleVector { - val res = DMatrixRMaj(1, 1) - CommonOps_DDRM.changeSign(toEjml().origin, res) - return res.wrapVector() - } - - override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { - val out = DMatrixRMaj(1, 1) - - CommonOps_DDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - ) - - return out.wrapMatrix() - } - - override fun Point.plus(other: Point): EjmlDoubleVector { - val out = DMatrixRMaj(1, 1) - - CommonOps_DDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - ) - - return out.wrapVector() - } - - override fun Point.minus(other: Point): EjmlDoubleVector { - val out = DMatrixRMaj(1, 1) - - CommonOps_DDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - ) - - return out.wrapVector() - } - - override fun Double.times(m: Matrix): EjmlDoubleMatrix = m * this - - override fun Point.times(value: Double): EjmlDoubleVector { - val res = DMatrixRMaj(1, 1) - CommonOps_DDRM.scale(value, toEjml().origin, res) - return res.wrapVector() - } - - override fun Double.times(v: Point): EjmlDoubleVector = v * this - - @UnstableKMathAPI - override fun computeFeature(structure: Matrix, type: KClass): F? { - structure.getFeature(type)?.let { return it } - val origin = structure.toEjml().origin - - return when (type) { - InverseMatrixFeature::class -> object : InverseMatrixFeature { - override val inverse: Matrix by lazy { - val res = origin.copy() - CommonOps_DDRM.invert(res) - res.wrapMatrix() - } - } - - DeterminantFeature::class -> object : DeterminantFeature { - override val determinant: Double by lazy { CommonOps_DDRM.det(origin) } - } - - SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { - private val svd by lazy { - DecompositionFactory_DDRM.svd(origin.numRows, origin.numCols, true, true, false) - .apply { decompose(origin.copy()) } - } - - override val u: Matrix by lazy { svd.getU(null, false).wrapMatrix() } - override val s: Matrix by lazy { svd.getW(null).wrapMatrix() } - override val v: Matrix by lazy { svd.getV(null, false).wrapMatrix() } - override val singularValues: Point by lazy { DoubleBuffer(svd.singularValues) } - } - - QRDecompositionFeature::class -> object : QRDecompositionFeature { - private val qr by lazy { - DecompositionFactory_DDRM.qr().apply { decompose(origin.copy()) } - } - - override val q: Matrix by lazy { - qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) - } - - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } - } - - CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { - override val l: Matrix by lazy { - val cholesky = - DecompositionFactory_DDRM.chol(structure.rowNum, true).apply { decompose(origin.copy()) } - - cholesky.getT(null).wrapMatrix().withFeature(LFeature) - } - } - - LupDecompositionFeature::class -> object : LupDecompositionFeature { - private val lup by lazy { - DecompositionFactory_DDRM.lu(origin.numRows, origin.numCols).apply { decompose(origin.copy()) } - } - - override val l: Matrix by lazy { - lup.getLower(null).wrapMatrix().withFeature(LFeature) - } - - override val u: Matrix by lazy { - lup.getUpper(null).wrapMatrix().withFeature(UFeature) - } - - override val p: Matrix by lazy { lup.getRowPivot(null).wrapMatrix() } - } - - else -> null - }?.let{ - type.cast(it) - } - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p matrix. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Matrix): EjmlDoubleMatrix { - val res = DMatrixRMaj(1, 1) - CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res) - return res.wrapMatrix() - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p vector. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Point): EjmlDoubleVector { - val res = DMatrixRMaj(1, 1) - CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res) - return EjmlDoubleVector(res) - } -} - -/** - * [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and - * [FMatrixRMaj] matrices. - */ -public object EjmlLinearSpaceFDRM : EjmlLinearSpace() { - /** - * The [FloatField] reference. - */ - override val elementAlgebra: FloatField get() = FloatField - - @Suppress("UNCHECKED_CAST") - override fun Matrix.toEjml(): EjmlFloatMatrix = when { - this is EjmlFloatMatrix<*> && origin is FMatrixRMaj -> this as EjmlFloatMatrix - else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } - } - - @Suppress("UNCHECKED_CAST") - override fun Point.toEjml(): EjmlFloatVector = when { - this is EjmlFloatVector<*> && origin is FMatrixRMaj -> this as EjmlFloatVector - else -> EjmlFloatVector(FMatrixRMaj(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } - }) - } - - override fun buildMatrix( - rows: Int, - columns: Int, - initializer: FloatField.(i: Int, j: Int) -> Float, - ): EjmlFloatMatrix = FMatrixRMaj(rows, columns).also { - (0 until rows).forEach { row -> - (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } - } - }.wrapMatrix() - - override fun buildVector( - size: Int, - initializer: FloatField.(Int) -> Float, - ): EjmlFloatVector = EjmlFloatVector(FMatrixRMaj(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } - }) - - private fun T.wrapMatrix() = EjmlFloatMatrix(this) - private fun T.wrapVector() = EjmlFloatVector(this) - - override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - - override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { - val out = FMatrixRMaj(1, 1) - CommonOps_FDRM.mult(toEjml().origin, other.toEjml().origin, out) - return out.wrapMatrix() - } - - override fun Matrix.dot(vector: Point): EjmlFloatVector { - val out = FMatrixRMaj(1, 1) - CommonOps_FDRM.mult(toEjml().origin, vector.toEjml().origin, out) - return out.wrapVector() - } - - override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { - val out = FMatrixRMaj(1, 1) - - CommonOps_FDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - ) - - return out.wrapMatrix() - } - - override operator fun Matrix.times(value: Float): EjmlFloatMatrix { - val res = FMatrixRMaj(1, 1) - CommonOps_FDRM.scale(value, toEjml().origin, res) - return res.wrapMatrix() - } - - override fun Point.unaryMinus(): EjmlFloatVector { - val res = FMatrixRMaj(1, 1) - CommonOps_FDRM.changeSign(toEjml().origin, res) - return res.wrapVector() - } - - override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { - val out = FMatrixRMaj(1, 1) - - CommonOps_FDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - ) - - return out.wrapMatrix() - } - - override fun Point.plus(other: Point): EjmlFloatVector { - val out = FMatrixRMaj(1, 1) - - CommonOps_FDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - ) - - return out.wrapVector() - } - - override fun Point.minus(other: Point): EjmlFloatVector { - val out = FMatrixRMaj(1, 1) - - CommonOps_FDRM.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - ) - - return out.wrapVector() - } - - override fun Float.times(m: Matrix): EjmlFloatMatrix = m * this - - override fun Point.times(value: Float): EjmlFloatVector { - val res = FMatrixRMaj(1, 1) - CommonOps_FDRM.scale(value, toEjml().origin, res) - return res.wrapVector() - } - - override fun Float.times(v: Point): EjmlFloatVector = v * this - - @UnstableKMathAPI - override fun computeFeature(structure: Matrix, type: KClass): F? { - structure.getFeature(type)?.let { return it } - val origin = structure.toEjml().origin - - return when (type) { - InverseMatrixFeature::class -> object : InverseMatrixFeature { - override val inverse: Matrix by lazy { - val res = origin.copy() - CommonOps_FDRM.invert(res) - res.wrapMatrix() - } - } - - DeterminantFeature::class -> object : DeterminantFeature { - override val determinant: Float by lazy { CommonOps_FDRM.det(origin) } - } - - SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { - private val svd by lazy { - DecompositionFactory_FDRM.svd(origin.numRows, origin.numCols, true, true, false) - .apply { decompose(origin.copy()) } - } - - override val u: Matrix by lazy { svd.getU(null, false).wrapMatrix() } - override val s: Matrix by lazy { svd.getW(null).wrapMatrix() } - override val v: Matrix by lazy { svd.getV(null, false).wrapMatrix() } - override val singularValues: Point by lazy { FloatBuffer(svd.singularValues) } - } - - QRDecompositionFeature::class -> object : QRDecompositionFeature { - private val qr by lazy { - DecompositionFactory_FDRM.qr().apply { decompose(origin.copy()) } - } - - override val q: Matrix by lazy { - qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) - } - - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } - } - - CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { - override val l: Matrix by lazy { - val cholesky = - DecompositionFactory_FDRM.chol(structure.rowNum, true).apply { decompose(origin.copy()) } - - cholesky.getT(null).wrapMatrix().withFeature(LFeature) - } - } - - LupDecompositionFeature::class -> object : LupDecompositionFeature { - private val lup by lazy { - DecompositionFactory_FDRM.lu(origin.numRows, origin.numCols).apply { decompose(origin.copy()) } - } - - override val l: Matrix by lazy { - lup.getLower(null).wrapMatrix().withFeature(LFeature) - } - - override val u: Matrix by lazy { - lup.getUpper(null).wrapMatrix().withFeature(UFeature) - } - - override val p: Matrix by lazy { lup.getRowPivot(null).wrapMatrix() } - } - - else -> null - }?.let{ - type.cast(it) - } - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p matrix. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Matrix): EjmlFloatMatrix { - val res = FMatrixRMaj(1, 1) - CommonOps_FDRM.solve(FMatrixRMaj(a.toEjml().origin), FMatrixRMaj(b.toEjml().origin), res) - return res.wrapMatrix() - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p vector. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Point): EjmlFloatVector { - val res = FMatrixRMaj(1, 1) - CommonOps_FDRM.solve(FMatrixRMaj(a.toEjml().origin), FMatrixRMaj(b.toEjml().origin), res) - return EjmlFloatVector(res) - } -} - -/** - * [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and - * [DMatrixSparseCSC] matrices. - */ -public object EjmlLinearSpaceDSCC : EjmlLinearSpace() { - /** - * The [DoubleField] reference. - */ - override val elementAlgebra: DoubleField get() = DoubleField - - @Suppress("UNCHECKED_CAST") - override fun Matrix.toEjml(): EjmlDoubleMatrix = when { - this is EjmlDoubleMatrix<*> && origin is DMatrixSparseCSC -> this as EjmlDoubleMatrix - else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } - } - - @Suppress("UNCHECKED_CAST") - override fun Point.toEjml(): EjmlDoubleVector = when { - this is EjmlDoubleVector<*> && origin is DMatrixSparseCSC -> this as EjmlDoubleVector - else -> EjmlDoubleVector(DMatrixSparseCSC(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } - }) - } - - override fun buildMatrix( - rows: Int, - columns: Int, - initializer: DoubleField.(i: Int, j: Int) -> Double, - ): EjmlDoubleMatrix = DMatrixSparseCSC(rows, columns).also { - (0 until rows).forEach { row -> - (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } - } - }.wrapMatrix() - - override fun buildVector( - size: Int, - initializer: DoubleField.(Int) -> Double, - ): EjmlDoubleVector = EjmlDoubleVector(DMatrixSparseCSC(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } - }) - - private fun T.wrapMatrix() = EjmlDoubleMatrix(this) - private fun T.wrapVector() = EjmlDoubleVector(this) - - override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - - override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { - val out = DMatrixSparseCSC(1, 1) - CommonOps_DSCC.mult(toEjml().origin, other.toEjml().origin, out) - return out.wrapMatrix() - } - - override fun Matrix.dot(vector: Point): EjmlDoubleVector { - val out = DMatrixSparseCSC(1, 1) - CommonOps_DSCC.mult(toEjml().origin, vector.toEjml().origin, out) - return out.wrapVector() - } - - override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { - val out = DMatrixSparseCSC(1, 1) - - CommonOps_DSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapMatrix() - } - - override operator fun Matrix.times(value: Double): EjmlDoubleMatrix { - val res = DMatrixSparseCSC(1, 1) - CommonOps_DSCC.scale(value, toEjml().origin, res) - return res.wrapMatrix() - } - - override fun Point.unaryMinus(): EjmlDoubleVector { - val res = DMatrixSparseCSC(1, 1) - CommonOps_DSCC.changeSign(toEjml().origin, res) - return res.wrapVector() - } - - override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { - val out = DMatrixSparseCSC(1, 1) - - CommonOps_DSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapMatrix() - } - - override fun Point.plus(other: Point): EjmlDoubleVector { - val out = DMatrixSparseCSC(1, 1) - - CommonOps_DSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapVector() - } - - override fun Point.minus(other: Point): EjmlDoubleVector { - val out = DMatrixSparseCSC(1, 1) - - CommonOps_DSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapVector() - } - - override fun Double.times(m: Matrix): EjmlDoubleMatrix = m * this - - override fun Point.times(value: Double): EjmlDoubleVector { - val res = DMatrixSparseCSC(1, 1) - CommonOps_DSCC.scale(value, toEjml().origin, res) - return res.wrapVector() - } - - override fun Double.times(v: Point): EjmlDoubleVector = v * this - - @UnstableKMathAPI - override fun computeFeature(structure: Matrix, type: KClass): F? { - structure.getFeature(type)?.let { return it } - val origin = structure.toEjml().origin - - return when (type) { - QRDecompositionFeature::class -> object : QRDecompositionFeature { - private val qr by lazy { - DecompositionFactory_DSCC.qr(FillReducing.NONE).apply { decompose(origin.copy()) } - } - - override val q: Matrix by lazy { - qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) - } - - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } - } - - CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { - override val l: Matrix by lazy { - val cholesky = - DecompositionFactory_DSCC.cholesky().apply { decompose(origin.copy()) } - - (cholesky.getT(null) as DMatrix).wrapMatrix().withFeature(LFeature) - } - } - - LUDecompositionFeature::class, DeterminantFeature::class, InverseMatrixFeature::class -> object : - LUDecompositionFeature, DeterminantFeature, InverseMatrixFeature { - private val lu by lazy { - DecompositionFactory_DSCC.lu(FillReducing.NONE).apply { decompose(origin.copy()) } - } - - override val l: Matrix by lazy { - lu.getLower(null).wrapMatrix().withFeature(LFeature) - } - - override val u: Matrix by lazy { - lu.getUpper(null).wrapMatrix().withFeature(UFeature) - } - - override val inverse: Matrix by lazy { - var a = origin - val inverse = DMatrixRMaj(1, 1) - val solver = LinearSolverFactory_DSCC.lu(FillReducing.NONE) - if (solver.modifiesA()) a = a.copy() - val i = CommonOps_DDRM.identity(a.numRows) - solver.solve(i, inverse) - inverse.wrapMatrix() - } - - override val determinant: Double by lazy { elementAlgebra.number(lu.computeDeterminant().real) } - } - - else -> null - }?.let{ - type.cast(it) - } - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p matrix. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Matrix): EjmlDoubleMatrix { - val res = DMatrixSparseCSC(1, 1) - CommonOps_DSCC.solve(DMatrixSparseCSC(a.toEjml().origin), DMatrixSparseCSC(b.toEjml().origin), res) - return res.wrapMatrix() - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p vector. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Point): EjmlDoubleVector { - val res = DMatrixSparseCSC(1, 1) - CommonOps_DSCC.solve(DMatrixSparseCSC(a.toEjml().origin), DMatrixSparseCSC(b.toEjml().origin), res) - return EjmlDoubleVector(res) - } -} - -/** - * [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and - * [FMatrixSparseCSC] matrices. - */ -public object EjmlLinearSpaceFSCC : EjmlLinearSpace() { - /** - * The [FloatField] reference. - */ - override val elementAlgebra: FloatField get() = FloatField - - @Suppress("UNCHECKED_CAST") - override fun Matrix.toEjml(): EjmlFloatMatrix = when { - this is EjmlFloatMatrix<*> && origin is FMatrixSparseCSC -> this as EjmlFloatMatrix - else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } - } - - @Suppress("UNCHECKED_CAST") - override fun Point.toEjml(): EjmlFloatVector = when { - this is EjmlFloatVector<*> && origin is FMatrixSparseCSC -> this as EjmlFloatVector - else -> EjmlFloatVector(FMatrixSparseCSC(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } - }) - } - - override fun buildMatrix( - rows: Int, - columns: Int, - initializer: FloatField.(i: Int, j: Int) -> Float, - ): EjmlFloatMatrix = FMatrixSparseCSC(rows, columns).also { - (0 until rows).forEach { row -> - (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } - } - }.wrapMatrix() - - override fun buildVector( - size: Int, - initializer: FloatField.(Int) -> Float, - ): EjmlFloatVector = EjmlFloatVector(FMatrixSparseCSC(size, 1).also { - (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } - }) - - private fun T.wrapMatrix() = EjmlFloatMatrix(this) - private fun T.wrapVector() = EjmlFloatVector(this) - - override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } - - override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { - val out = FMatrixSparseCSC(1, 1) - CommonOps_FSCC.mult(toEjml().origin, other.toEjml().origin, out) - return out.wrapMatrix() - } - - override fun Matrix.dot(vector: Point): EjmlFloatVector { - val out = FMatrixSparseCSC(1, 1) - CommonOps_FSCC.mult(toEjml().origin, vector.toEjml().origin, out) - return out.wrapVector() - } - - override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { - val out = FMatrixSparseCSC(1, 1) - - CommonOps_FSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapMatrix() - } - - override operator fun Matrix.times(value: Float): EjmlFloatMatrix { - val res = FMatrixSparseCSC(1, 1) - CommonOps_FSCC.scale(value, toEjml().origin, res) - return res.wrapMatrix() - } - - override fun Point.unaryMinus(): EjmlFloatVector { - val res = FMatrixSparseCSC(1, 1) - CommonOps_FSCC.changeSign(toEjml().origin, res) - return res.wrapVector() - } - - override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { - val out = FMatrixSparseCSC(1, 1) - - CommonOps_FSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapMatrix() - } - - override fun Point.plus(other: Point): EjmlFloatVector { - val out = FMatrixSparseCSC(1, 1) - - CommonOps_FSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra.one, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapVector() - } - - override fun Point.minus(other: Point): EjmlFloatVector { - val out = FMatrixSparseCSC(1, 1) - - CommonOps_FSCC.add( - elementAlgebra.one, - toEjml().origin, - elementAlgebra { -one }, - other.toEjml().origin, - out, - null, - null, - ) - - return out.wrapVector() - } - - override fun Float.times(m: Matrix): EjmlFloatMatrix = m * this - - override fun Point.times(value: Float): EjmlFloatVector { - val res = FMatrixSparseCSC(1, 1) - CommonOps_FSCC.scale(value, toEjml().origin, res) - return res.wrapVector() - } - - override fun Float.times(v: Point): EjmlFloatVector = v * this - - @UnstableKMathAPI - override fun computeFeature(structure: Matrix, type: KClass): F? { - structure.getFeature(type)?.let { return it } - val origin = structure.toEjml().origin - - return when (type) { - QRDecompositionFeature::class -> object : QRDecompositionFeature { - private val qr by lazy { - DecompositionFactory_FSCC.qr(FillReducing.NONE).apply { decompose(origin.copy()) } - } - - override val q: Matrix by lazy { - qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) - } - - override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } - } - - CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { - override val l: Matrix by lazy { - val cholesky = - DecompositionFactory_FSCC.cholesky().apply { decompose(origin.copy()) } - - (cholesky.getT(null) as FMatrix).wrapMatrix().withFeature(LFeature) - } - } - - LUDecompositionFeature::class, DeterminantFeature::class, InverseMatrixFeature::class -> object : - LUDecompositionFeature, DeterminantFeature, InverseMatrixFeature { - private val lu by lazy { - DecompositionFactory_FSCC.lu(FillReducing.NONE).apply { decompose(origin.copy()) } - } - - override val l: Matrix by lazy { - lu.getLower(null).wrapMatrix().withFeature(LFeature) - } - - override val u: Matrix by lazy { - lu.getUpper(null).wrapMatrix().withFeature(UFeature) - } - - override val inverse: Matrix by lazy { - var a = origin - val inverse = FMatrixRMaj(1, 1) - val solver = LinearSolverFactory_FSCC.lu(FillReducing.NONE) - if (solver.modifiesA()) a = a.copy() - val i = CommonOps_FDRM.identity(a.numRows) - solver.solve(i, inverse) - inverse.wrapMatrix() - } - - override val determinant: Float by lazy { elementAlgebra.number(lu.computeDeterminant().real) } - } - - else -> null - }?.let{ - type.cast(it) - } - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p matrix. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Matrix): EjmlFloatMatrix { - val res = FMatrixSparseCSC(1, 1) - CommonOps_FSCC.solve(FMatrixSparseCSC(a.toEjml().origin), FMatrixSparseCSC(b.toEjml().origin), res) - return res.wrapMatrix() - } - - /** - * Solves for *x* in the following equation: *x = [a] -1 · [b]*. - * - * @param a the base matrix. - * @param b n by p vector. - * @return the solution for *x* that is n by p. - */ - public fun solve(a: Matrix, b: Point): EjmlFloatVector { - val res = FMatrixSparseCSC(1, 1) - CommonOps_FSCC.solve(FMatrixSparseCSC(a.toEjml().origin), FMatrixSparseCSC(b.toEjml().origin), res) - return EjmlFloatVector(res) - } -} - diff --git a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt index 90ec29ce3..eafd55513 100644 --- a/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt +++ b/kmath-histograms/src/commonMain/kotlin/space/kscience/kmath/histogram/UniformHistogramGroupND.kt @@ -28,7 +28,7 @@ public class UniformHistogramGroupND>( private val lower: Buffer, private val upper: Buffer, private val binNums: IntArray = IntArray(lower.size) { 20 }, - private val bufferFactory: BufferFactory = Buffer.Companion::boxing, + private val bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::boxing), ) : HistogramGroupND { init { @@ -114,7 +114,7 @@ public class UniformHistogramGroupND>( public fun > Histogram.Companion.uniformNDFromRanges( valueAlgebraND: FieldOpsND, vararg ranges: ClosedFloatingPointRange, - bufferFactory: BufferFactory = Buffer.Companion::boxing, + bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::boxing), ): UniformHistogramGroupND = UniformHistogramGroupND( valueAlgebraND, ranges.map(ClosedFloatingPointRange::start).asBuffer(), @@ -140,7 +140,7 @@ public fun Histogram.Companion.uniformDoubleNDFromRanges( public fun > Histogram.Companion.uniformNDFromRanges( valueAlgebraND: FieldOpsND, vararg ranges: Pair, Int>, - bufferFactory: BufferFactory = Buffer.Companion::boxing, + bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::boxing), ): UniformHistogramGroupND = UniformHistogramGroupND( valueAlgebraND, ListBuffer( diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt index 1dc318517..0de2d8349 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt @@ -6,6 +6,7 @@ package space.kscience.kmath.multik import org.jetbrains.kotlinx.multik.ndarray.data.DataType +import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.StructureND import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExponentialOperations @@ -22,10 +23,13 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra): MultikTensor = sin(arg) / cos(arg) + @PerformancePitfall override fun asin(arg: StructureND): MultikTensor = arg.map { asin(it) } + @PerformancePitfall override fun acos(arg: StructureND): MultikTensor = arg.map { acos(it) } + @PerformancePitfall override fun atan(arg: StructureND): MultikTensor = arg.map { atan(it) } override fun exp(arg: StructureND): MultikTensor = multikMath.mathEx.exp(arg.asMultik().array).wrap() @@ -42,10 +46,13 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra): MultikTensor = arg.map { asinh(it) } + @PerformancePitfall override fun acosh(arg: StructureND): MultikTensor = arg.map { acosh(it) } + @PerformancePitfall override fun atanh(arg: StructureND): MultikTensor = arg.map { atanh(it) } } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt index a88f3e437..890318e31 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt @@ -35,7 +35,7 @@ public fun interface Sampler { public fun Sampler.sampleBuffer( generator: RandomGenerator, size: Int, - bufferFactory: BufferFactory = Buffer.Companion::boxing, + bufferFactory: BufferFactory = BufferFactory(Buffer.Companion::boxing), ): Chain> { require(size > 1) //creating temporary storage once From f5fe53a9f234edd47ecb2e7d73e7700e12f4a01c Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 15 Jul 2022 16:20:28 +0300 Subject: [PATCH 3/6] Grand derivative refactoring. Phase 2 --- .../kscience/kmath/expressions/DSAlgebra.kt | 437 ++++++++++++++++++ .../kmath/expressions/DerivativeStructure.kt | 157 ------- .../DerivativeStructureExpression.kt | 349 -------------- .../DerivativeStructureExpressionTest.kt | 4 +- 4 files changed, 439 insertions(+), 508 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt delete mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt delete mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt new file mode 100644 index 000000000..d9fc46b47 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt @@ -0,0 +1,437 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +package space.kscience.kmath.expressions + +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.structures.MutableBufferFactory +import space.kscience.kmath.structures.asBuffer +import kotlin.math.max +import kotlin.math.min + +/** + * Class representing both the value and the differentials of a function. + * + * This class is the workhorse of the differentiation package. + * + * This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper + * [Doubly Recursive Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf), + * Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used + * throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's + * derivative structures hold all partial derivatives up to any specified order, with respect to any number of free + * parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free + * parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters. + * + * Derived from + * [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java). + */ +@UnstableKMathAPI +public interface DS> { + public val derivativeAlgebra: DSAlgebra + public val data: Buffer +} + +/** + * Get a partial derivative. + * + * @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned). + * @return partial derivative. + * @see value + */ +@UnstableKMathAPI +public fun > DS.getPartialDerivative(vararg orders: Int): T = + data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)] + +/** + * The value part of the derivative structure. + * + * @see getPartialDerivative + */ +@UnstableKMathAPI +public val > DS.value: T get() = data[0] + +@UnstableKMathAPI +public abstract class DSAlgebra>( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val order: Int, + bindings: Map, +) : ExpressionAlgebra> { + + @OptIn(UnstableKMathAPI::class) + private fun bufferForVariable(index: Int, value: T): Buffer { + val buffer = bufferFactory(compiler.size) { algebra.zero } + buffer[0] = value + if (compiler.order > 0) { + // the derivative of the variable with respect to itself is 1. + + val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply { + set(index, 1) + }) + + buffer[indexOfDerivative] = algebra.one + } + return buffer + } + + @UnstableKMathAPI + protected inner class DSImpl internal constructor( + override val data: Buffer, + ) : DS { + override val derivativeAlgebra: DSAlgebra get() = this@DSAlgebra + } + + protected fun DS(data: Buffer): DS = DSImpl(data) + + + /** + * Build an instance representing a variable. + * + * Instances built using this constructor are considered to be the free variables with respect to which + * differentials are computed. As such, their differential with respect to themselves is +1. + */ + public fun variable( + index: Int, + value: T, + ): DS { + require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" } + return DS(bufferForVariable(index, value)) + } + + /** + * Build an instance from all its derivatives. + * + * @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex]. + */ + public fun ofDerivatives( + vararg derivatives: T, + ): DS { + require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" } + val data = derivatives.asBuffer() + + return DS(data) + } + + /** + * A class implementing both [DS] and [Symbol]. + */ + @UnstableKMathAPI + public inner class DSSymbol internal constructor( + index: Int, + symbol: Symbol, + value: T, + ) : Symbol by symbol, DS { + override val derivativeAlgebra: DSAlgebra get() = this@DSAlgebra + override val data: Buffer = bufferForVariable(index, value) + } + + + public val numberOfVariables: Int = bindings.size + + /** + * Get the compiler for number of free parameters and order. + * + * @return cached rules set. + */ + @PublishedApi + internal val compiler: DSCompiler by lazy { + // get the cached compilers + val cache: Array?>>? = null + + // we need to create more compilers + val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0) + val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) + val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } + + if (cache != null) { + // preserve the already created compilers + for (i in cache.indices) { + cache[i].copyInto(newCache[i], endIndex = cache[i].size) + } + } + + // create the array in increasing diagonal order + for (diag in 0..numberOfVariables + order) { + for (o in max(0, diag - numberOfVariables)..min(order, diag)) { + val p: Int = diag - o + if (newCache[p][o] == null) { + val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! + val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! + + newCache[p][o] = DSCompiler( + algebra, + bufferFactory, + p, + o, + valueCompiler, + derivativeCompiler, + ) + } + } + } + + return@lazy newCache[numberOfVariables][order]!! + } + + private val variables: Map = bindings.entries.mapIndexed { index, (key, value) -> + key to DSSymbol( + index, + key, + value, + ) + }.toMap() + + + public override fun const(value: T): DS { + val buffer = bufferFactory(compiler.size) { algebra.zero } + buffer[0] = value + + return DS(buffer) + } + + override fun bindSymbolOrNull(value: String): DSSymbol? = variables[StringSymbol(value)] + + override fun bindSymbol(value: String): DSSymbol = + bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this") + + public fun bindSymbolOrNull(symbol: Symbol): DSSymbol? = variables[symbol.identity] + + public fun bindSymbol(symbol: Symbol): DSSymbol = + bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this") + + public fun DS.derivative(symbols: List): T { + require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" } + val ordersCount = symbols.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray()) + } + + public fun DS.derivative(vararg symbols: Symbol): T = derivative(symbols.toList()) + +} + + +/** + * A ring over [DS]. + * + * @property order The derivation order. + * @param bindings The map of bindings values. All bindings are considered free parameters. + */ +@UnstableKMathAPI +public open class DSRing( + algebra: A, + bufferFactory: MutableBufferFactory, + order: Int, + bindings: Map, +) : DSAlgebra(algebra, bufferFactory, order, bindings), + Ring>, ScaleOperations>, + NumericAlgebra>, + NumbersAddOps> where A : Ring, A : NumericAlgebra, A : ScaleOperations { + + override fun bindSymbolOrNull(value: String): DSSymbol? = + super.bindSymbolOrNull(value) + + override fun DS.unaryMinus(): DS = mapData { -it } + + /** + * Create a copy of given [Buffer] and modify it according to [block] + */ + protected inline fun DS.transformDataBuffer(block: A.(MutableBuffer) -> Unit): DS { + require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" } + val newData = bufferFactory(compiler.size) { data[it] } + algebra.block(newData) + return DS(newData) + } + + protected fun DS.mapData(block: A.(T) -> T): DS { + require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" } + val newData: Buffer = data.map(bufferFactory) { + algebra.block(it) + } + return DS(newData) + } + + protected fun DS.mapDataIndexed(block: (Int, T) -> T): DS { + require(derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" } + val newData: Buffer = data.mapIndexed(bufferFactory, block) + return DS(newData) + } + + override val zero: DS by lazy { + const(algebra.zero) + } + + override val one: DS by lazy { + const(algebra.one) + } + + override fun number(value: Number): DS = const(algebra.number(value)) + + override fun add(left: DS, right: DS): DS = left.transformDataBuffer { result -> + require(right.derivativeAlgebra == this@DSRing) { "All derivative operations should be done in the same algebra" } + compiler.add(left.data, 0, right.data, 0, result, 0) + } + + override fun scale(a: DS, value: Double): DS = a.mapData { + it.times(value) + } + + override fun multiply( + left: DS, + right: DS, + ): DS = left.transformDataBuffer { result -> + compiler.multiply(left.data, 0, right.data, 0, result, 0) + } +// +// override fun DS.minus(arg: DS): DS = transformDataBuffer { result -> +// subtract(data, 0, arg.data, 0, result, 0) +// } + + override operator fun DS.plus(other: Number): DS = transformDataBuffer { + it[0] += number(other) + } + +// +// override operator fun DS.minus(other: Number): DS = +// this + (-other.toDouble()) + + override operator fun Number.plus(other: DS): DS = other + this + override operator fun Number.minus(other: DS): DS = other - this +} + +@UnstableKMathAPI +public class DerivativeStructureRingExpression( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val function: DSRing.() -> DS, +) : DifferentiableExpression where A : Ring, A : ScaleOperations, A : NumericAlgebra { + override operator fun invoke(arguments: Map): T = + DSRing(algebra, bufferFactory, 0, arguments).function().value + + override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + with( + DSRing( + algebra, + bufferFactory, + symbols.size, + arguments + ) + ) { function().derivative(symbols) } + } +} + +/** + * A field over commons-math [DerivativeStructure]. + * + * @property order The derivation order. + * @param bindings The map of bindings values. All bindings are considered free parameters. + */ +@UnstableKMathAPI +public class DSField>( + algebra: A, + bufferFactory: MutableBufferFactory, + order: Int, + bindings: Map, +) : DSRing(algebra, bufferFactory, order, bindings), ExtendedField> { + override fun number(value: Number): DS = const(algebra.number(value)) + + override fun divide(left: DS, right: DS): DS = left.transformDataBuffer { result -> + compiler.divide(left.data, 0, right.data, 0, result, 0) + } + + override fun sin(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.sin(arg.data, 0, result, 0) + } + + override fun cos(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.cos(arg.data, 0, result, 0) + } + + override fun tan(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.tan(arg.data, 0, result, 0) + } + + override fun asin(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.asin(arg.data, 0, result, 0) + } + + override fun acos(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.acos(arg.data, 0, result, 0) + } + + override fun atan(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.atan(arg.data, 0, result, 0) + } + + override fun sinh(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.sinh(arg.data, 0, result, 0) + } + + override fun cosh(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.cosh(arg.data, 0, result, 0) + } + + override fun tanh(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.tanh(arg.data, 0, result, 0) + } + + override fun asinh(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.asinh(arg.data, 0, result, 0) + } + + override fun acosh(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.acosh(arg.data, 0, result, 0) + } + + override fun atanh(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.atanh(arg.data, 0, result, 0) + } + + override fun power(arg: DS, pow: Number): DS = when (pow) { + is Int -> arg.transformDataBuffer { result -> + compiler.pow(arg.data, 0, pow, result, 0) + } + else -> arg.transformDataBuffer { result -> + compiler.pow(arg.data, 0, pow.toDouble(), result, 0) + } + } + + override fun sqrt(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.sqrt(arg.data, 0, result, 0) + } + + public fun power(arg: DS, pow: DS): DS = arg.transformDataBuffer { result -> + compiler.pow(arg.data, 0, pow.data, 0, result, 0) + } + + override fun exp(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.exp(arg.data, 0, result, 0) + } + + override fun ln(arg: DS): DS = arg.transformDataBuffer { result -> + compiler.ln(arg.data, 0, result, 0) + } +} + +@UnstableKMathAPI +public class DerivativeStructureFieldExpression>( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val function: DSField.() -> DS, +) : DifferentiableExpression { + override operator fun invoke(arguments: Map): T = + DSField(algebra, bufferFactory, 0, arguments).function().value + + override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + DSField( + algebra, + bufferFactory, + symbols.size, + arguments, + ).run { function().derivative(symbols) } + } +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt deleted file mode 100644 index 01c045cdb..000000000 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructure.kt +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. - */ - -package space.kscience.kmath.expressions - -import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.Ring -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.asBuffer - -/** - * Class representing both the value and the differentials of a function. - * - * This class is the workhorse of the differentiation package. - * - * This class is an implementation of the extension to Rall's numbers described in Dan Kalman's paper [Doubly Recursive - * Multivariate Automatic Differentiation](http://www1.american.edu/cas/mathstat/People/kalman/pdffiles/mmgautodiff.pdf), - * Mathematics Magazine, vol. 75, no. 3, June 2002. Rall's numbers are an extension to the real numbers used - * throughout mathematical expressions; they hold the derivative together with the value of a function. Dan Kalman's - * derivative structures hold all partial derivatives up to any specified order, with respect to any number of free - * parameters. Rall's numbers therefore can be seen as derivative structures for order one derivative and one free - * parameter, and real numbers can be seen as derivative structures with zero order derivative and no free parameters. - * - * Derived from - * [Commons Math's `DerivativeStructure`](https://github.com/apache/commons-math/blob/924f6c357465b39beb50e3c916d5eb6662194175/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/analysis/differentiation/DerivativeStructure.java). - */ -@UnstableKMathAPI -public open class DerivativeStructure> @PublishedApi internal constructor( - private val derivativeAlgebra: DerivativeStructureAlgebra, - @PublishedApi internal val data: Buffer, -) { - - public val compiler: DSCompiler get() = derivativeAlgebra.compiler - - /** - * The number of free parameters. - */ - public val freeParameters: Int get() = compiler.freeParameters - - /** - * The derivation order. - */ - public val order: Int get() = compiler.order - - /** - * The value part of the derivative structure. - * - * @see getPartialDerivative - */ - public val value: T get() = data[0] - - /** - * Get a partial derivative. - * - * @param orders derivation orders with respect to each variable (if all orders are 0, the value is returned). - * @return partial derivative. - * @see value - */ - public fun getPartialDerivative(vararg orders: Int): T = data[compiler.getPartialDerivativeIndex(*orders)] - - - /** - * Test for the equality of two derivative structures. - * - * Derivative structures are considered equal if they have the same number - * of free parameters, the same derivation order, and the same derivatives. - * - * @return `true` if two derivative structures are equal. - */ - public override fun equals(other: Any?): Boolean { - if (this === other) return true - - if (other is DerivativeStructure<*, *>) { - return ((freeParameters == other.freeParameters) && - (order == other.order) && - data == other.data) - } - - return false - } - - public override fun hashCode(): Int = - 227 + 229 * freeParameters + 233 * order + 239 * data.hashCode() - - public companion object { - - /** - * Build an instance representing a variable. - * - * Instances built using this constructor are considered to be the free variables with respect to which - * differentials are computed. As such, their differential with respect to themselves is +1. - */ - public fun > variable( - derivativeAlgebra: DerivativeStructureAlgebra, - index: Int, - value: T, - ): DerivativeStructure { - val compiler = derivativeAlgebra.compiler - require(index < compiler.freeParameters) { "number is too large: $index >= ${compiler.freeParameters}" } - return DerivativeStructure(derivativeAlgebra, derivativeAlgebra.bufferForVariable(index, value)) - } - - /** - * Build an instance from all its derivatives. - * - * @param derivatives derivatives sorted according to [DSCompiler.getPartialDerivativeIndex]. - */ - public fun > ofDerivatives( - derivativeAlgebra: DerivativeStructureAlgebra, - vararg derivatives: T, - ): DerivativeStructure { - val compiler = derivativeAlgebra.compiler - require(derivatives.size == compiler.size) { "dimension mismatch: ${derivatives.size} and ${compiler.size}" } - val data = derivatives.asBuffer() - - return DerivativeStructure( - derivativeAlgebra, - data - ) - } - } -} - -@OptIn(UnstableKMathAPI::class) -private fun > DerivativeStructureAlgebra.bufferForVariable(index: Int, value: T): Buffer { - val buffer = bufferFactory(compiler.size) { algebra.zero } - buffer[0] = value - if (compiler.order > 0) { - // the derivative of the variable with respect to itself is 1. - - val indexOfDerivative = compiler.getPartialDerivativeIndex(*IntArray(numberOfVariables).apply { - set(index, 1) - }) - - buffer[indexOfDerivative] = algebra.one - } - return buffer -} - -/** - * A class implementing both [DerivativeStructure] and [Symbol]. - */ -@UnstableKMathAPI -public class DerivativeStructureSymbol> internal constructor( - derivativeAlgebra: DerivativeStructureAlgebra, - index: Int, - symbol: Symbol, - value: T, -) : Symbol by symbol, DerivativeStructure( - derivativeAlgebra, derivativeAlgebra.bufferForVariable(index, value) -) { - override fun toString(): String = symbol.toString() - override fun equals(other: Any?): Boolean = (other as? Symbol) == symbol - override fun hashCode(): Int = symbol.hashCode() -} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt deleted file mode 100644 index 638057921..000000000 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpression.kt +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. - */ - -package space.kscience.kmath.expressions - -import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.* -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.MutableBuffer -import space.kscience.kmath.structures.MutableBufferFactory -import kotlin.math.max -import kotlin.math.min - -@UnstableKMathAPI -public abstract class DerivativeStructureAlgebra>( - public val algebra: A, - public val bufferFactory: MutableBufferFactory, - public val order: Int, - bindings: Map, -) : ExpressionAlgebra> { - - public val numberOfVariables: Int = bindings.size - - - /** - * Get the compiler for number of free parameters and order. - * - * @return cached rules set. - */ - @PublishedApi - internal val compiler: DSCompiler by lazy { - // get the cached compilers - val cache: Array?>>? = null - - // we need to create more compilers - val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0) - val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) - val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } - - if (cache != null) { - // preserve the already created compilers - for (i in cache.indices) { - cache[i].copyInto(newCache[i], endIndex = cache[i].size) - } - } - - // create the array in increasing diagonal order - for (diag in 0..numberOfVariables + order) { - for (o in max(0, diag - numberOfVariables)..min(order, diag)) { - val p: Int = diag - o - if (newCache[p][o] == null) { - val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! - val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! - - newCache[p][o] = DSCompiler( - algebra, - bufferFactory, - p, - o, - valueCompiler, - derivativeCompiler, - ) - } - } - } - - return@lazy newCache[numberOfVariables][order]!! - } - - private val variables: Map> = - bindings.entries.mapIndexed { index, (key, value) -> - key to DerivativeStructureSymbol( - this, - index, - key, - value, - ) - }.toMap() - - - - public override fun const(value: T): DerivativeStructure { - val buffer = bufferFactory(compiler.size) { algebra.zero } - buffer[0] = value - - return DerivativeStructure( - this, - buffer - ) - } - - override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[StringSymbol(value)] - - override fun bindSymbol(value: String): DerivativeStructureSymbol = - bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this") - - public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] - - public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol = - bindSymbolOrNull(symbol.identity) ?: error("Symbol '${symbol}' is not supported in $this") - - public fun DerivativeStructure.derivative(symbols: List): T { - require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" } - val ordersCount = symbols.groupBy { it }.mapValues { it.value.size } - return getPartialDerivative(*variables.keys.map { ordersCount[it] ?: 0 }.toIntArray()) - } - - public fun DerivativeStructure.derivative(vararg symbols: Symbol): T = derivative(symbols.toList()) - -} - - -/** - * A ring over [DerivativeStructure]. - * - * @property order The derivation order. - * @param bindings The map of bindings values. All bindings are considered free parameters. - */ -@UnstableKMathAPI -public open class DerivativeStructureRing( - algebra: A, - bufferFactory: MutableBufferFactory, - order: Int, - bindings: Map, -) : DerivativeStructureAlgebra(algebra, bufferFactory, order, bindings), - Ring>, ScaleOperations>, - NumericAlgebra>, - NumbersAddOps> where A : Ring, A : NumericAlgebra, A : ScaleOperations { - - override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = - super.bindSymbolOrNull(value) - - override fun DerivativeStructure.unaryMinus(): DerivativeStructure { - val newData = algebra { data.map(bufferFactory) { -it } } - return DerivativeStructure(this@DerivativeStructureRing, newData) - } - - /** - * Create a copy of given [Buffer] and modify it according to [block] - */ - protected inline fun DerivativeStructure.transformDataBuffer(block: DSCompiler.(MutableBuffer) -> Unit): DerivativeStructure { - val newData = bufferFactory(compiler.size) { data[it] } - compiler.block(newData) - return DerivativeStructure(this@DerivativeStructureRing, newData) - } - - protected fun DerivativeStructure.mapData(block: (T) -> T): DerivativeStructure { - val newData: Buffer = data.map(bufferFactory, block) - return DerivativeStructure(this@DerivativeStructureRing, newData) - } - - protected fun DerivativeStructure.mapDataIndexed(block: (Int, T) -> T): DerivativeStructure { - val newData: Buffer = data.mapIndexed(bufferFactory, block) - return DerivativeStructure(this@DerivativeStructureRing, newData) - } - - override val zero: DerivativeStructure by lazy { - const(algebra.zero) - } - - override val one: DerivativeStructure by lazy { - const(algebra.one) - } - - override fun number(value: Number): DerivativeStructure = const(algebra.number(value)) - - override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure { - left.compiler.checkCompatibility(right.compiler) - return left.transformDataBuffer { result -> - add(left.data, 0, right.data, 0, result, 0) - } - } - - override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = algebra { - a.mapData { it.times(value) } - } - - override fun multiply( - left: DerivativeStructure, - right: DerivativeStructure, - ): DerivativeStructure { - left.compiler.checkCompatibility(right.compiler) - return left.transformDataBuffer { result -> - multiply(left.data, 0, right.data, 0, result, 0) - } - } - - override fun DerivativeStructure.minus(arg: DerivativeStructure): DerivativeStructure { - compiler.checkCompatibility(arg.compiler) - return transformDataBuffer { result -> - subtract(data, 0, arg.data, 0, result, 0) - } - } - - override operator fun DerivativeStructure.plus(other: Number): DerivativeStructure = algebra { - transformDataBuffer { - it[0] += number(other) - } - } - - override operator fun DerivativeStructure.minus(other: Number): DerivativeStructure = - this + (-other.toDouble()) - - override operator fun Number.plus(other: DerivativeStructure): DerivativeStructure = other + this - override operator fun Number.minus(other: DerivativeStructure): DerivativeStructure = other - this -} - -@UnstableKMathAPI -public class DerivativeStructureRingExpression( - public val algebra: A, - public val bufferFactory: MutableBufferFactory, - public val function: DerivativeStructureRing.() -> DerivativeStructure, -) : DifferentiableExpression where A : Ring, A : ScaleOperations, A : NumericAlgebra { - override operator fun invoke(arguments: Map): T = - DerivativeStructureRing(algebra, bufferFactory, 0, arguments).function().value - - override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> - with( - DerivativeStructureRing( - algebra, - bufferFactory, - symbols.size, - arguments - ) - ) { function().derivative(symbols) } - } -} - -/** - * A field over commons-math [DerivativeStructure]. - * - * @property order The derivation order. - * @param bindings The map of bindings values. All bindings are considered free parameters. - */ -@UnstableKMathAPI -public class DerivativeStructureField>( - algebra: A, - bufferFactory: MutableBufferFactory, - order: Int, - bindings: Map, -) : DerivativeStructureRing(algebra, bufferFactory, order, bindings), ExtendedField> { - override fun number(value: Number): DerivativeStructure = const(algebra.number(value)) - - override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure { - left.compiler.checkCompatibility(right.compiler) - return left.transformDataBuffer { result -> - left.compiler.divide(left.data, 0, right.data, 0, result, 0) - } - } - - override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - sin(arg.data, 0, result, 0) - } - - override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - cos(arg.data, 0, result, 0) - } - - override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - tan(arg.data, 0, result, 0) - } - - override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - asin(arg.data, 0, result, 0) - } - - override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - acos(arg.data, 0, result, 0) - } - - override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - atan(arg.data, 0, result, 0) - } - - override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - sinh(arg.data, 0, result, 0) - } - - override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - cosh(arg.data, 0, result, 0) - } - - override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - tanh(arg.data, 0, result, 0) - } - - override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - asinh(arg.data, 0, result, 0) - } - - override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - acosh(arg.data, 0, result, 0) - } - - override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - atanh(arg.data, 0, result, 0) - } - - override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { - is Int -> arg.transformDataBuffer { result -> - pow(arg.data, 0, pow, result, 0) - } - else -> arg.transformDataBuffer { result -> - pow(arg.data, 0, pow.toDouble(), result, 0) - } - } - - override fun sqrt(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - sqrt(arg.data, 0, result, 0) - } - - public fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure { - arg.compiler.checkCompatibility(pow.compiler) - return arg.transformDataBuffer { result -> - pow(arg.data, 0, pow.data, 0, result, 0) - } - } - - override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - exp(arg.data, 0, result, 0) - } - - override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.transformDataBuffer { result -> - ln(arg.data, 0, result, 0) - } -} - -@UnstableKMathAPI -public class DerivativeStructureFieldExpression>( - public val algebra: A, - public val bufferFactory: MutableBufferFactory, - public val function: DerivativeStructureField.() -> DerivativeStructure, -) : DifferentiableExpression { - override operator fun invoke(arguments: Map): T = - DerivativeStructureField(algebra, bufferFactory, 0, arguments).function().value - - override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> - with( - DerivativeStructureField( - algebra, - bufferFactory, - symbols.size, - arguments, - ) - ) { function().derivative(symbols) } - } -} diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt index 429fe310b..fdeda4512 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt @@ -19,10 +19,10 @@ import kotlin.test.assertFails internal inline fun diff( order: Int, vararg parameters: Pair, - block: DerivativeStructureField.() -> Unit, + block: DSField.() -> Unit, ) { contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - DerivativeStructureField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block() + DSField(DoubleField, ::DoubleBuffer, order, mapOf(*parameters)).block() } internal class AutoDiffTest { From 846a6d2620810dc9e15538befba437f0882a77db Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 15 Jul 2022 17:20:00 +0300 Subject: [PATCH 4/6] Grand derivative refactoring. Phase 3 --- .../kscience/kmath/expressions/DSAlgebra.kt | 143 ++++++++++-------- .../kscience/kmath/expressions/DSCompiler.kt | 28 ++-- 2 files changed, 96 insertions(+), 75 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt index d9fc46b47..506fbd001 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt @@ -44,9 +44,29 @@ public interface DS> { * @see value */ @UnstableKMathAPI -public fun > DS.getPartialDerivative(vararg orders: Int): T = +private fun > DS.getPartialDerivative(vararg orders: Int): T = data[derivativeAlgebra.compiler.getPartialDerivativeIndex(*orders)] +/** + * Provide a partial derivative with given symbols. On symbol could me mentioned multiple times + */ +@UnstableKMathAPI +public fun > DS.derivative(symbols: List): T { + require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" } + val ordersCount: Map = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray()) +} + +/** + * Provide a partial derivative with given symbols. On symbol could me mentioned multiple times + */ +@UnstableKMathAPI +public fun > DS.derivative(vararg symbols: Symbol): T { + require(symbols.size <= derivativeAlgebra.order) { "The order of derivative ${symbols.size} exceeds computed order ${derivativeAlgebra.order}" } + val ordersCount: Map = symbols.map { it.identity }.groupBy { it }.mapValues { it.value.size } + return getPartialDerivative(*symbols.map { ordersCount[it] ?: 0 }.toIntArray()) +} + /** * The value part of the derivative structure. * @@ -61,9 +81,67 @@ public abstract class DSAlgebra>( public val bufferFactory: MutableBufferFactory, public val order: Int, bindings: Map, -) : ExpressionAlgebra> { +) : ExpressionAlgebra>, SymbolIndexer { + + /** + * Get the compiler for number of free parameters and order. + * + * @return cached rules set. + */ + @PublishedApi + internal val compiler: DSCompiler by lazy { + // get the cached compilers + val cache: Array?>>? = null + + // we need to create more compilers + val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0) + val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) + val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } + + if (cache != null) { + // preserve the already created compilers + for (i in cache.indices) { + cache[i].copyInto(newCache[i], endIndex = cache[i].size) + } + } + + // create the array in increasing diagonal order + for (diag in 0..numberOfVariables + order) { + for (o in max(0, diag - numberOfVariables)..min(order, diag)) { + val p: Int = diag - o + if (newCache[p][o] == null) { + val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! + val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! + + newCache[p][o] = DSCompiler( + algebra, + bufferFactory, + p, + o, + valueCompiler, + derivativeCompiler, + ) + } + } + } + + return@lazy newCache[numberOfVariables][order]!! + } + + private val variables: Map by lazy { + bindings.entries.mapIndexed { index, (key, value) -> + key to DSSymbol( + index, + key, + value, + ) + }.toMap() + } + override val symbols: List = bindings.map { it.key } + + public val numberOfVariables: Int get() = symbols.size + - @OptIn(UnstableKMathAPI::class) private fun bufferForVariable(index: Int, value: T): Buffer { val buffer = bufferFactory(compiler.size) { algebra.zero } buffer[0] = value @@ -80,7 +158,7 @@ public abstract class DSAlgebra>( } @UnstableKMathAPI - protected inner class DSImpl internal constructor( + private inner class DSImpl( override val data: Buffer, ) : DS { override val derivativeAlgebra: DSAlgebra get() = this@DSAlgebra @@ -130,63 +208,6 @@ public abstract class DSAlgebra>( override val data: Buffer = bufferForVariable(index, value) } - - public val numberOfVariables: Int = bindings.size - - /** - * Get the compiler for number of free parameters and order. - * - * @return cached rules set. - */ - @PublishedApi - internal val compiler: DSCompiler by lazy { - // get the cached compilers - val cache: Array?>>? = null - - // we need to create more compilers - val maxParameters: Int = max(numberOfVariables, cache?.size ?: 0) - val maxOrder: Int = max(order, if (cache == null) 0 else cache[0].size) - val newCache: Array?>> = Array(maxParameters + 1) { arrayOfNulls(maxOrder + 1) } - - if (cache != null) { - // preserve the already created compilers - for (i in cache.indices) { - cache[i].copyInto(newCache[i], endIndex = cache[i].size) - } - } - - // create the array in increasing diagonal order - for (diag in 0..numberOfVariables + order) { - for (o in max(0, diag - numberOfVariables)..min(order, diag)) { - val p: Int = diag - o - if (newCache[p][o] == null) { - val valueCompiler: DSCompiler? = if (p == 0) null else newCache[p - 1][o]!! - val derivativeCompiler: DSCompiler? = if (o == 0) null else newCache[p][o - 1]!! - - newCache[p][o] = DSCompiler( - algebra, - bufferFactory, - p, - o, - valueCompiler, - derivativeCompiler, - ) - } - } - } - - return@lazy newCache[numberOfVariables][order]!! - } - - private val variables: Map = bindings.entries.mapIndexed { index, (key, value) -> - key to DSSymbol( - index, - key, - value, - ) - }.toMap() - - public override fun const(value: T): DS { val buffer = bufferFactory(compiler.size) { algebra.zero } buffer[0] = value diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt index e0050cf03..b5b2988a3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSCompiler.kt @@ -52,20 +52,20 @@ internal fun MutableBuffer.fill(element: T, fromIndex: Int = 0, toIndex: * * @property freeParameters Number of free parameters. * @property order Derivation order. - * @see DerivativeStructure + * @see DS */ -class DSCompiler> internal constructor( - val algebra: A, - val bufferFactory: MutableBufferFactory, - val freeParameters: Int, - val order: Int, +public class DSCompiler> internal constructor( + public val algebra: A, + public val bufferFactory: MutableBufferFactory, + public val freeParameters: Int, + public val order: Int, valueCompiler: DSCompiler?, derivativeCompiler: DSCompiler?, ) { /** * Number of partial derivatives (including the single 0 order derivative element). */ - val sizes: Array by lazy { + public val sizes: Array by lazy { compileSizes( freeParameters, order, @@ -76,7 +76,7 @@ class DSCompiler> internal constructor( /** * Indirection array for partial derivatives. */ - val derivativesIndirection: Array by lazy { + internal val derivativesIndirection: Array by lazy { compileDerivativesIndirection( freeParameters, order, valueCompiler, derivativeCompiler, @@ -86,7 +86,7 @@ class DSCompiler> internal constructor( /** * Indirection array of the lower derivative elements. */ - val lowerIndirection: IntArray by lazy { + internal val lowerIndirection: IntArray by lazy { compileLowerIndirection( freeParameters, order, valueCompiler, derivativeCompiler, @@ -96,7 +96,7 @@ class DSCompiler> internal constructor( /** * Indirection arrays for multiplication. */ - val multIndirection: Array> by lazy { + internal val multIndirection: Array> by lazy { compileMultiplicationIndirection( freeParameters, order, valueCompiler, derivativeCompiler, lowerIndirection, @@ -106,7 +106,7 @@ class DSCompiler> internal constructor( /** * Indirection arrays for function composition. */ - val compositionIndirection: Array> by lazy { + internal val compositionIndirection: Array> by lazy { compileCompositionIndirection( freeParameters, order, valueCompiler, derivativeCompiler, @@ -120,7 +120,7 @@ class DSCompiler> internal constructor( * This number includes the single 0 order derivative element, which is * guaranteed to be stored in the first element of the array. */ - val size: Int get() = sizes[freeParameters][order] + public val size: Int get() = sizes[freeParameters][order] /** * Get the index of a partial derivative in the array. @@ -147,7 +147,7 @@ class DSCompiler> internal constructor( * @return index of the partial derivative. * @see getPartialDerivativeOrders */ - fun getPartialDerivativeIndex(vararg orders: Int): Int { + public fun getPartialDerivativeIndex(vararg orders: Int): Int { // safety check require(orders.size == freeParameters) { "dimension mismatch: ${orders.size} and $freeParameters" } return getPartialDerivativeIndex(freeParameters, order, sizes, *orders) @@ -162,7 +162,7 @@ class DSCompiler> internal constructor( * @return orders derivation orders with respect to each parameter * @see getPartialDerivativeIndex */ - fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index] + public fun getPartialDerivativeOrders(index: Int): IntArray = derivativesIndirection[index] } /** From bfadf5b33d545ac0e4d40f2253e20e0ad88e4ba0 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 15 Jul 2022 17:31:28 +0300 Subject: [PATCH 5/6] Name refactor --- .../kotlin/space/kscience/kmath/expressions/DSAlgebra.kt | 2 +- .../kmath/expressions/DerivativeStructureExpressionTest.kt | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt index 506fbd001..59e6f4f6f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt @@ -439,7 +439,7 @@ public class DSField>( } @UnstableKMathAPI -public class DerivativeStructureFieldExpression>( +public class DSFieldExpression>( public val algebra: A, public val bufferFactory: MutableBufferFactory, public val function: DSField.() -> DS, diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt index fdeda4512..e5bc9805a 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/DerivativeStructureExpressionTest.kt @@ -30,7 +30,7 @@ internal class AutoDiffTest { private val y by symbol @Test - fun derivativeStructureFieldTest() { + fun dsAlgebraTest() { diff(2, x to 1.0, y to 1.0) { val x = bindSymbol(x)//by binding() val y = bindSymbol("y") @@ -44,8 +44,8 @@ internal class AutoDiffTest { } @Test - fun autoDifTest() { - val f = DerivativeStructureFieldExpression(DoubleField, ::DoubleBuffer) { + fun dsExpressionTest() { + val f = DSFieldExpression(DoubleField, ::DoubleBuffer) { val x by binding val y by binding x.pow(2) + 2 * x * y + y.pow(2) + 1 From 18ae964e57a2d1059aaa4f17909da45c3a4a1972 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 15 Jul 2022 17:35:13 +0300 Subject: [PATCH 6/6] Name refactor --- .../space/kscience/kmath/ejml/_generated.kt | 1003 +++++++++++++++++ 1 file changed, 1003 insertions(+) create mode 100644 kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt new file mode 100644 index 000000000..aac327a84 --- /dev/null +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt @@ -0,0 +1,1003 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. + */ + +/* This file is generated with buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt */ + +package space.kscience.kmath.ejml + +import org.ejml.data.* +import org.ejml.dense.row.CommonOps_DDRM +import org.ejml.dense.row.CommonOps_FDRM +import org.ejml.dense.row.factory.DecompositionFactory_DDRM +import org.ejml.dense.row.factory.DecompositionFactory_FDRM +import org.ejml.sparse.FillReducing +import org.ejml.sparse.csc.CommonOps_DSCC +import org.ejml.sparse.csc.CommonOps_FSCC +import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC +import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC +import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC +import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC +import space.kscience.kmath.linear.* +import space.kscience.kmath.linear.Matrix +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.StructureFeature +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.FloatField +import space.kscience.kmath.operations.invoke +import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.FloatBuffer +import kotlin.reflect.KClass +import kotlin.reflect.cast + +/** + * [EjmlVector] specialization for [Double]. + */ +public class EjmlDoubleVector(override val origin: M) : EjmlVector(origin) { + init { + require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } + } + + override operator fun get(index: Int): Double = origin[0, index] +} + +/** + * [EjmlVector] specialization for [Float]. + */ +public class EjmlFloatVector(override val origin: M) : EjmlVector(origin) { + init { + require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } + } + + override operator fun get(index: Int): Float = origin[0, index] +} + +/** + * [EjmlMatrix] specialization for [Double]. + */ +public class EjmlDoubleMatrix(override val origin: M) : EjmlMatrix(origin) { + override operator fun get(i: Int, j: Int): Double = origin[i, j] +} + +/** + * [EjmlMatrix] specialization for [Float]. + */ +public class EjmlFloatMatrix(override val origin: M) : EjmlMatrix(origin) { + override operator fun get(i: Int, j: Int): Float = origin[i, j] +} + +/** + * [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and + * [DMatrixRMaj] matrices. + */ +public object EjmlLinearSpaceDDRM : EjmlLinearSpace() { + /** + * The [DoubleField] reference. + */ + override val elementAlgebra: DoubleField get() = DoubleField + + @Suppress("UNCHECKED_CAST") + override fun Matrix.toEjml(): EjmlDoubleMatrix = when { + this is EjmlDoubleMatrix<*> && origin is DMatrixRMaj -> this as EjmlDoubleMatrix + else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } + } + + @Suppress("UNCHECKED_CAST") + override fun Point.toEjml(): EjmlDoubleVector = when { + this is EjmlDoubleVector<*> && origin is DMatrixRMaj -> this as EjmlDoubleVector + else -> EjmlDoubleVector(DMatrixRMaj(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } + }) + } + + override fun buildMatrix( + rows: Int, + columns: Int, + initializer: DoubleField.(i: Int, j: Int) -> Double, + ): EjmlDoubleMatrix = DMatrixRMaj(rows, columns).also { + (0 until rows).forEach { row -> + (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } + } + }.wrapMatrix() + + override fun buildVector( + size: Int, + initializer: DoubleField.(Int) -> Double, + ): EjmlDoubleVector = EjmlDoubleVector(DMatrixRMaj(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } + }) + + private fun T.wrapMatrix() = EjmlDoubleMatrix(this) + private fun T.wrapVector() = EjmlDoubleVector(this) + + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + + override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { + val out = DMatrixRMaj(1, 1) + CommonOps_DDRM.mult(toEjml().origin, other.toEjml().origin, out) + return out.wrapMatrix() + } + + override fun Matrix.dot(vector: Point): EjmlDoubleVector { + val out = DMatrixRMaj(1, 1) + CommonOps_DDRM.mult(toEjml().origin, vector.toEjml().origin, out) + return out.wrapVector() + } + + override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { + val out = DMatrixRMaj(1, 1) + + CommonOps_DDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + ) + + return out.wrapMatrix() + } + + override operator fun Matrix.times(value: Double): EjmlDoubleMatrix { + val res = DMatrixRMaj(1, 1) + CommonOps_DDRM.scale(value, toEjml().origin, res) + return res.wrapMatrix() + } + + override fun Point.unaryMinus(): EjmlDoubleVector { + val res = DMatrixRMaj(1, 1) + CommonOps_DDRM.changeSign(toEjml().origin, res) + return res.wrapVector() + } + + override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { + val out = DMatrixRMaj(1, 1) + + CommonOps_DDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + ) + + return out.wrapMatrix() + } + + override fun Point.plus(other: Point): EjmlDoubleVector { + val out = DMatrixRMaj(1, 1) + + CommonOps_DDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + ) + + return out.wrapVector() + } + + override fun Point.minus(other: Point): EjmlDoubleVector { + val out = DMatrixRMaj(1, 1) + + CommonOps_DDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + ) + + return out.wrapVector() + } + + override fun Double.times(m: Matrix): EjmlDoubleMatrix = m * this + + override fun Point.times(value: Double): EjmlDoubleVector { + val res = DMatrixRMaj(1, 1) + CommonOps_DDRM.scale(value, toEjml().origin, res) + return res.wrapVector() + } + + override fun Double.times(v: Point): EjmlDoubleVector = v * this + + @UnstableKMathAPI + override fun computeFeature(structure: Matrix, type: KClass): F? { + structure.getFeature(type)?.let { return it } + val origin = structure.toEjml().origin + + return when (type) { + InverseMatrixFeature::class -> object : InverseMatrixFeature { + override val inverse: Matrix by lazy { + val res = origin.copy() + CommonOps_DDRM.invert(res) + res.wrapMatrix() + } + } + + DeterminantFeature::class -> object : DeterminantFeature { + override val determinant: Double by lazy { CommonOps_DDRM.det(origin) } + } + + SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { + private val svd by lazy { + DecompositionFactory_DDRM.svd(origin.numRows, origin.numCols, true, true, false) + .apply { decompose(origin.copy()) } + } + + override val u: Matrix by lazy { svd.getU(null, false).wrapMatrix() } + override val s: Matrix by lazy { svd.getW(null).wrapMatrix() } + override val v: Matrix by lazy { svd.getV(null, false).wrapMatrix() } + override val singularValues: Point by lazy { DoubleBuffer(svd.singularValues) } + } + + QRDecompositionFeature::class -> object : QRDecompositionFeature { + private val qr by lazy { + DecompositionFactory_DDRM.qr().apply { decompose(origin.copy()) } + } + + override val q: Matrix by lazy { + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) + } + + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } + } + + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + override val l: Matrix by lazy { + val cholesky = + DecompositionFactory_DDRM.chol(structure.rowNum, true).apply { decompose(origin.copy()) } + + cholesky.getT(null).wrapMatrix().withFeature(LFeature) + } + } + + LupDecompositionFeature::class -> object : LupDecompositionFeature { + private val lup by lazy { + DecompositionFactory_DDRM.lu(origin.numRows, origin.numCols).apply { decompose(origin.copy()) } + } + + override val l: Matrix by lazy { + lup.getLower(null).wrapMatrix().withFeature(LFeature) + } + + override val u: Matrix by lazy { + lup.getUpper(null).wrapMatrix().withFeature(UFeature) + } + + override val p: Matrix by lazy { lup.getRowPivot(null).wrapMatrix() } + } + + else -> null + }?.let{ + type.cast(it) + } + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p matrix. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Matrix): EjmlDoubleMatrix { + val res = DMatrixRMaj(1, 1) + CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res) + return res.wrapMatrix() + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p vector. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Point): EjmlDoubleVector { + val res = DMatrixRMaj(1, 1) + CommonOps_DDRM.solve(DMatrixRMaj(a.toEjml().origin), DMatrixRMaj(b.toEjml().origin), res) + return EjmlDoubleVector(res) + } +} + +/** + * [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and + * [FMatrixRMaj] matrices. + */ +public object EjmlLinearSpaceFDRM : EjmlLinearSpace() { + /** + * The [FloatField] reference. + */ + override val elementAlgebra: FloatField get() = FloatField + + @Suppress("UNCHECKED_CAST") + override fun Matrix.toEjml(): EjmlFloatMatrix = when { + this is EjmlFloatMatrix<*> && origin is FMatrixRMaj -> this as EjmlFloatMatrix + else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } + } + + @Suppress("UNCHECKED_CAST") + override fun Point.toEjml(): EjmlFloatVector = when { + this is EjmlFloatVector<*> && origin is FMatrixRMaj -> this as EjmlFloatVector + else -> EjmlFloatVector(FMatrixRMaj(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } + }) + } + + override fun buildMatrix( + rows: Int, + columns: Int, + initializer: FloatField.(i: Int, j: Int) -> Float, + ): EjmlFloatMatrix = FMatrixRMaj(rows, columns).also { + (0 until rows).forEach { row -> + (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } + } + }.wrapMatrix() + + override fun buildVector( + size: Int, + initializer: FloatField.(Int) -> Float, + ): EjmlFloatVector = EjmlFloatVector(FMatrixRMaj(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } + }) + + private fun T.wrapMatrix() = EjmlFloatMatrix(this) + private fun T.wrapVector() = EjmlFloatVector(this) + + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + + override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { + val out = FMatrixRMaj(1, 1) + CommonOps_FDRM.mult(toEjml().origin, other.toEjml().origin, out) + return out.wrapMatrix() + } + + override fun Matrix.dot(vector: Point): EjmlFloatVector { + val out = FMatrixRMaj(1, 1) + CommonOps_FDRM.mult(toEjml().origin, vector.toEjml().origin, out) + return out.wrapVector() + } + + override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { + val out = FMatrixRMaj(1, 1) + + CommonOps_FDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + ) + + return out.wrapMatrix() + } + + override operator fun Matrix.times(value: Float): EjmlFloatMatrix { + val res = FMatrixRMaj(1, 1) + CommonOps_FDRM.scale(value, toEjml().origin, res) + return res.wrapMatrix() + } + + override fun Point.unaryMinus(): EjmlFloatVector { + val res = FMatrixRMaj(1, 1) + CommonOps_FDRM.changeSign(toEjml().origin, res) + return res.wrapVector() + } + + override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { + val out = FMatrixRMaj(1, 1) + + CommonOps_FDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + ) + + return out.wrapMatrix() + } + + override fun Point.plus(other: Point): EjmlFloatVector { + val out = FMatrixRMaj(1, 1) + + CommonOps_FDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + ) + + return out.wrapVector() + } + + override fun Point.minus(other: Point): EjmlFloatVector { + val out = FMatrixRMaj(1, 1) + + CommonOps_FDRM.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + ) + + return out.wrapVector() + } + + override fun Float.times(m: Matrix): EjmlFloatMatrix = m * this + + override fun Point.times(value: Float): EjmlFloatVector { + val res = FMatrixRMaj(1, 1) + CommonOps_FDRM.scale(value, toEjml().origin, res) + return res.wrapVector() + } + + override fun Float.times(v: Point): EjmlFloatVector = v * this + + @UnstableKMathAPI + override fun computeFeature(structure: Matrix, type: KClass): F? { + structure.getFeature(type)?.let { return it } + val origin = structure.toEjml().origin + + return when (type) { + InverseMatrixFeature::class -> object : InverseMatrixFeature { + override val inverse: Matrix by lazy { + val res = origin.copy() + CommonOps_FDRM.invert(res) + res.wrapMatrix() + } + } + + DeterminantFeature::class -> object : DeterminantFeature { + override val determinant: Float by lazy { CommonOps_FDRM.det(origin) } + } + + SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { + private val svd by lazy { + DecompositionFactory_FDRM.svd(origin.numRows, origin.numCols, true, true, false) + .apply { decompose(origin.copy()) } + } + + override val u: Matrix by lazy { svd.getU(null, false).wrapMatrix() } + override val s: Matrix by lazy { svd.getW(null).wrapMatrix() } + override val v: Matrix by lazy { svd.getV(null, false).wrapMatrix() } + override val singularValues: Point by lazy { FloatBuffer(svd.singularValues) } + } + + QRDecompositionFeature::class -> object : QRDecompositionFeature { + private val qr by lazy { + DecompositionFactory_FDRM.qr().apply { decompose(origin.copy()) } + } + + override val q: Matrix by lazy { + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) + } + + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } + } + + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + override val l: Matrix by lazy { + val cholesky = + DecompositionFactory_FDRM.chol(structure.rowNum, true).apply { decompose(origin.copy()) } + + cholesky.getT(null).wrapMatrix().withFeature(LFeature) + } + } + + LupDecompositionFeature::class -> object : LupDecompositionFeature { + private val lup by lazy { + DecompositionFactory_FDRM.lu(origin.numRows, origin.numCols).apply { decompose(origin.copy()) } + } + + override val l: Matrix by lazy { + lup.getLower(null).wrapMatrix().withFeature(LFeature) + } + + override val u: Matrix by lazy { + lup.getUpper(null).wrapMatrix().withFeature(UFeature) + } + + override val p: Matrix by lazy { lup.getRowPivot(null).wrapMatrix() } + } + + else -> null + }?.let{ + type.cast(it) + } + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p matrix. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Matrix): EjmlFloatMatrix { + val res = FMatrixRMaj(1, 1) + CommonOps_FDRM.solve(FMatrixRMaj(a.toEjml().origin), FMatrixRMaj(b.toEjml().origin), res) + return res.wrapMatrix() + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p vector. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Point): EjmlFloatVector { + val res = FMatrixRMaj(1, 1) + CommonOps_FDRM.solve(FMatrixRMaj(a.toEjml().origin), FMatrixRMaj(b.toEjml().origin), res) + return EjmlFloatVector(res) + } +} + +/** + * [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and + * [DMatrixSparseCSC] matrices. + */ +public object EjmlLinearSpaceDSCC : EjmlLinearSpace() { + /** + * The [DoubleField] reference. + */ + override val elementAlgebra: DoubleField get() = DoubleField + + @Suppress("UNCHECKED_CAST") + override fun Matrix.toEjml(): EjmlDoubleMatrix = when { + this is EjmlDoubleMatrix<*> && origin is DMatrixSparseCSC -> this as EjmlDoubleMatrix + else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } + } + + @Suppress("UNCHECKED_CAST") + override fun Point.toEjml(): EjmlDoubleVector = when { + this is EjmlDoubleVector<*> && origin is DMatrixSparseCSC -> this as EjmlDoubleVector + else -> EjmlDoubleVector(DMatrixSparseCSC(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } + }) + } + + override fun buildMatrix( + rows: Int, + columns: Int, + initializer: DoubleField.(i: Int, j: Int) -> Double, + ): EjmlDoubleMatrix = DMatrixSparseCSC(rows, columns).also { + (0 until rows).forEach { row -> + (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } + } + }.wrapMatrix() + + override fun buildVector( + size: Int, + initializer: DoubleField.(Int) -> Double, + ): EjmlDoubleVector = EjmlDoubleVector(DMatrixSparseCSC(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } + }) + + private fun T.wrapMatrix() = EjmlDoubleMatrix(this) + private fun T.wrapVector() = EjmlDoubleVector(this) + + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + + override fun Matrix.dot(other: Matrix): EjmlDoubleMatrix { + val out = DMatrixSparseCSC(1, 1) + CommonOps_DSCC.mult(toEjml().origin, other.toEjml().origin, out) + return out.wrapMatrix() + } + + override fun Matrix.dot(vector: Point): EjmlDoubleVector { + val out = DMatrixSparseCSC(1, 1) + CommonOps_DSCC.mult(toEjml().origin, vector.toEjml().origin, out) + return out.wrapVector() + } + + override operator fun Matrix.minus(other: Matrix): EjmlDoubleMatrix { + val out = DMatrixSparseCSC(1, 1) + + CommonOps_DSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapMatrix() + } + + override operator fun Matrix.times(value: Double): EjmlDoubleMatrix { + val res = DMatrixSparseCSC(1, 1) + CommonOps_DSCC.scale(value, toEjml().origin, res) + return res.wrapMatrix() + } + + override fun Point.unaryMinus(): EjmlDoubleVector { + val res = DMatrixSparseCSC(1, 1) + CommonOps_DSCC.changeSign(toEjml().origin, res) + return res.wrapVector() + } + + override fun Matrix.plus(other: Matrix): EjmlDoubleMatrix { + val out = DMatrixSparseCSC(1, 1) + + CommonOps_DSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapMatrix() + } + + override fun Point.plus(other: Point): EjmlDoubleVector { + val out = DMatrixSparseCSC(1, 1) + + CommonOps_DSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapVector() + } + + override fun Point.minus(other: Point): EjmlDoubleVector { + val out = DMatrixSparseCSC(1, 1) + + CommonOps_DSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapVector() + } + + override fun Double.times(m: Matrix): EjmlDoubleMatrix = m * this + + override fun Point.times(value: Double): EjmlDoubleVector { + val res = DMatrixSparseCSC(1, 1) + CommonOps_DSCC.scale(value, toEjml().origin, res) + return res.wrapVector() + } + + override fun Double.times(v: Point): EjmlDoubleVector = v * this + + @UnstableKMathAPI + override fun computeFeature(structure: Matrix, type: KClass): F? { + structure.getFeature(type)?.let { return it } + val origin = structure.toEjml().origin + + return when (type) { + QRDecompositionFeature::class -> object : QRDecompositionFeature { + private val qr by lazy { + DecompositionFactory_DSCC.qr(FillReducing.NONE).apply { decompose(origin.copy()) } + } + + override val q: Matrix by lazy { + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) + } + + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } + } + + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + override val l: Matrix by lazy { + val cholesky = + DecompositionFactory_DSCC.cholesky().apply { decompose(origin.copy()) } + + (cholesky.getT(null) as DMatrix).wrapMatrix().withFeature(LFeature) + } + } + + LUDecompositionFeature::class, DeterminantFeature::class, InverseMatrixFeature::class -> object : + LUDecompositionFeature, DeterminantFeature, InverseMatrixFeature { + private val lu by lazy { + DecompositionFactory_DSCC.lu(FillReducing.NONE).apply { decompose(origin.copy()) } + } + + override val l: Matrix by lazy { + lu.getLower(null).wrapMatrix().withFeature(LFeature) + } + + override val u: Matrix by lazy { + lu.getUpper(null).wrapMatrix().withFeature(UFeature) + } + + override val inverse: Matrix by lazy { + var a = origin + val inverse = DMatrixRMaj(1, 1) + val solver = LinearSolverFactory_DSCC.lu(FillReducing.NONE) + if (solver.modifiesA()) a = a.copy() + val i = CommonOps_DDRM.identity(a.numRows) + solver.solve(i, inverse) + inverse.wrapMatrix() + } + + override val determinant: Double by lazy { elementAlgebra.number(lu.computeDeterminant().real) } + } + + else -> null + }?.let{ + type.cast(it) + } + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p matrix. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Matrix): EjmlDoubleMatrix { + val res = DMatrixSparseCSC(1, 1) + CommonOps_DSCC.solve(DMatrixSparseCSC(a.toEjml().origin), DMatrixSparseCSC(b.toEjml().origin), res) + return res.wrapMatrix() + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p vector. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Point): EjmlDoubleVector { + val res = DMatrixSparseCSC(1, 1) + CommonOps_DSCC.solve(DMatrixSparseCSC(a.toEjml().origin), DMatrixSparseCSC(b.toEjml().origin), res) + return EjmlDoubleVector(res) + } +} + +/** + * [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and + * [FMatrixSparseCSC] matrices. + */ +public object EjmlLinearSpaceFSCC : EjmlLinearSpace() { + /** + * The [FloatField] reference. + */ + override val elementAlgebra: FloatField get() = FloatField + + @Suppress("UNCHECKED_CAST") + override fun Matrix.toEjml(): EjmlFloatMatrix = when { + this is EjmlFloatMatrix<*> && origin is FMatrixSparseCSC -> this as EjmlFloatMatrix + else -> buildMatrix(rowNum, colNum) { i, j -> get(i, j) } + } + + @Suppress("UNCHECKED_CAST") + override fun Point.toEjml(): EjmlFloatVector = when { + this is EjmlFloatVector<*> && origin is FMatrixSparseCSC -> this as EjmlFloatVector + else -> EjmlFloatVector(FMatrixSparseCSC(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = get(row) } + }) + } + + override fun buildMatrix( + rows: Int, + columns: Int, + initializer: FloatField.(i: Int, j: Int) -> Float, + ): EjmlFloatMatrix = FMatrixSparseCSC(rows, columns).also { + (0 until rows).forEach { row -> + (0 until columns).forEach { col -> it[row, col] = elementAlgebra.initializer(row, col) } + } + }.wrapMatrix() + + override fun buildVector( + size: Int, + initializer: FloatField.(Int) -> Float, + ): EjmlFloatVector = EjmlFloatVector(FMatrixSparseCSC(size, 1).also { + (0 until it.numRows).forEach { row -> it[row, 0] = elementAlgebra.initializer(row) } + }) + + private fun T.wrapMatrix() = EjmlFloatMatrix(this) + private fun T.wrapVector() = EjmlFloatVector(this) + + override fun Matrix.unaryMinus(): Matrix = this * elementAlgebra { -one } + + override fun Matrix.dot(other: Matrix): EjmlFloatMatrix { + val out = FMatrixSparseCSC(1, 1) + CommonOps_FSCC.mult(toEjml().origin, other.toEjml().origin, out) + return out.wrapMatrix() + } + + override fun Matrix.dot(vector: Point): EjmlFloatVector { + val out = FMatrixSparseCSC(1, 1) + CommonOps_FSCC.mult(toEjml().origin, vector.toEjml().origin, out) + return out.wrapVector() + } + + override operator fun Matrix.minus(other: Matrix): EjmlFloatMatrix { + val out = FMatrixSparseCSC(1, 1) + + CommonOps_FSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapMatrix() + } + + override operator fun Matrix.times(value: Float): EjmlFloatMatrix { + val res = FMatrixSparseCSC(1, 1) + CommonOps_FSCC.scale(value, toEjml().origin, res) + return res.wrapMatrix() + } + + override fun Point.unaryMinus(): EjmlFloatVector { + val res = FMatrixSparseCSC(1, 1) + CommonOps_FSCC.changeSign(toEjml().origin, res) + return res.wrapVector() + } + + override fun Matrix.plus(other: Matrix): EjmlFloatMatrix { + val out = FMatrixSparseCSC(1, 1) + + CommonOps_FSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapMatrix() + } + + override fun Point.plus(other: Point): EjmlFloatVector { + val out = FMatrixSparseCSC(1, 1) + + CommonOps_FSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra.one, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapVector() + } + + override fun Point.minus(other: Point): EjmlFloatVector { + val out = FMatrixSparseCSC(1, 1) + + CommonOps_FSCC.add( + elementAlgebra.one, + toEjml().origin, + elementAlgebra { -one }, + other.toEjml().origin, + out, + null, + null, + ) + + return out.wrapVector() + } + + override fun Float.times(m: Matrix): EjmlFloatMatrix = m * this + + override fun Point.times(value: Float): EjmlFloatVector { + val res = FMatrixSparseCSC(1, 1) + CommonOps_FSCC.scale(value, toEjml().origin, res) + return res.wrapVector() + } + + override fun Float.times(v: Point): EjmlFloatVector = v * this + + @UnstableKMathAPI + override fun computeFeature(structure: Matrix, type: KClass): F? { + structure.getFeature(type)?.let { return it } + val origin = structure.toEjml().origin + + return when (type) { + QRDecompositionFeature::class -> object : QRDecompositionFeature { + private val qr by lazy { + DecompositionFactory_FSCC.qr(FillReducing.NONE).apply { decompose(origin.copy()) } + } + + override val q: Matrix by lazy { + qr.getQ(null, false).wrapMatrix().withFeature(OrthogonalFeature) + } + + override val r: Matrix by lazy { qr.getR(null, false).wrapMatrix().withFeature(UFeature) } + } + + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + override val l: Matrix by lazy { + val cholesky = + DecompositionFactory_FSCC.cholesky().apply { decompose(origin.copy()) } + + (cholesky.getT(null) as FMatrix).wrapMatrix().withFeature(LFeature) + } + } + + LUDecompositionFeature::class, DeterminantFeature::class, InverseMatrixFeature::class -> object : + LUDecompositionFeature, DeterminantFeature, InverseMatrixFeature { + private val lu by lazy { + DecompositionFactory_FSCC.lu(FillReducing.NONE).apply { decompose(origin.copy()) } + } + + override val l: Matrix by lazy { + lu.getLower(null).wrapMatrix().withFeature(LFeature) + } + + override val u: Matrix by lazy { + lu.getUpper(null).wrapMatrix().withFeature(UFeature) + } + + override val inverse: Matrix by lazy { + var a = origin + val inverse = FMatrixRMaj(1, 1) + val solver = LinearSolverFactory_FSCC.lu(FillReducing.NONE) + if (solver.modifiesA()) a = a.copy() + val i = CommonOps_FDRM.identity(a.numRows) + solver.solve(i, inverse) + inverse.wrapMatrix() + } + + override val determinant: Float by lazy { elementAlgebra.number(lu.computeDeterminant().real) } + } + + else -> null + }?.let{ + type.cast(it) + } + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p matrix. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Matrix): EjmlFloatMatrix { + val res = FMatrixSparseCSC(1, 1) + CommonOps_FSCC.solve(FMatrixSparseCSC(a.toEjml().origin), FMatrixSparseCSC(b.toEjml().origin), res) + return res.wrapMatrix() + } + + /** + * Solves for *x* in the following equation: *x = [a] -1 · [b]*. + * + * @param a the base matrix. + * @param b n by p vector. + * @return the solution for *x* that is n by p. + */ + public fun solve(a: Matrix, b: Point): EjmlFloatVector { + val res = FMatrixSparseCSC(1, 1) + CommonOps_FSCC.solve(FMatrixSparseCSC(a.toEjml().origin), FMatrixSparseCSC(b.toEjml().origin), res) + return EjmlFloatVector(res) + } +} +