LUP cleanup

This commit is contained in:
Alexander Nozik 2024-02-18 14:00:38 +03:00
parent fbee95ab8b
commit 79642a869d

View File

@ -4,12 +4,14 @@
*/ */
@file:Suppress("UnusedReceiverParameter") @file:Suppress("UnusedReceiverParameter")
@file:OptIn(PerformancePitfall::class)
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.attributes.Attributes import space.kscience.attributes.Attributes
import space.kscience.attributes.PolymorphicAttribute import space.kscience.attributes.PolymorphicAttribute
import space.kscience.attributes.safeTypeOf import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
@ -78,22 +80,22 @@ internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
/** /**
* Create a lup decomposition of generic matrix. * Create a lup decomposition of generic matrix.
*/ */
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup( public fun <T : Comparable<T>> Field<T>.lup(
matrix: Matrix<T>, matrix: Matrix<T>,
checkSingular: (T) -> Boolean, checkSingular: (T) -> Boolean,
): GenericLupDecomposition<T> = elementAlgebra { ): GenericLupDecomposition<T> {
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" } require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
val m = matrix.colNum val m = matrix.colNum
val pivot = IntArray(matrix.rowNum) val pivot = IntArray(matrix.rowNum)
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum)) val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
val lu: MutableStructure2D<T> = MutableBufferND( val lu = MutableBufferND(
strides, strides,
bufferAlgebra.buffer(strides.linearSize) { offset -> bufferAlgebra.buffer(strides.linearSize) { offset ->
matrix[strides.index(offset)] matrix[strides.index(offset)]
} }
).as2D() )
// Initialize the permutation array and parity // Initialize the permutation array and parity
@ -108,7 +110,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
// upper // upper
for (row in 0 until col) { for (row in 0 until col) {
var sum = lu[row, col] var sum = lu[row, col]
for (i in 0 until row) sum -= lu[row, i] * lu[i, col] for (i in 0 until row){
sum -= lu[row, i] * lu[i, col]
}
lu[row, col] = sum lu[row, col] = sum
} }
@ -118,7 +122,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
for (row in col until m) { for (row in col until m) {
var sum = lu[row, col] var sum = lu[row, col]
for (i in 0 until col) sum -= lu[row, i] * lu[i, col] for (i in 0 until col){
sum -= lu[row, i] * lu[i, col]
}
lu[row, col] = sum lu[row, col] = sum
// maintain the best permutation choice // maintain the best permutation choice
@ -151,31 +157,29 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
} }
return GenericLupDecomposition(elementAlgebra, lu, pivot.asBuffer(), even) return GenericLupDecomposition(this, lu.as2D(), pivot.asBuffer(), even)
} }
public fun LinearSpace<Double, Float64Field>.lup( public fun Field<Float64>.lup(
matrix: Matrix<Double>, matrix: Matrix<Double>,
singularityThreshold: Double = 1e-11, singularityThreshold: Double = 1e-11,
): GenericLupDecomposition<Double> = lup(matrix) { it < singularityThreshold } ): GenericLupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
internal fun <T> LinearSpace<T, Field<T>>.solve( private fun <T> Field<T>.solve(
lup: LupDecomposition<T>, lup: LupDecomposition<T>,
matrix: Matrix<T>, matrix: Matrix<T>,
): Matrix<T> = elementAlgebra { ): Matrix<T> {
require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" } require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" }
// with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) {
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum)) val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
// Apply permutations to b // Apply permutations to b
val bp: MutableStructure2D<T> = MutableBufferND( val bp = MutableBufferND(
strides, strides,
bufferAlgebra.buffer(strides.linearSize) { offset -> zero } bufferAlgebra.buffer(strides.linearSize) { _ -> zero }
).as2D() )
for (row in 0 until matrix.rowNum) { for (row in 0 until matrix.rowNum) {
@ -211,8 +215,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
} }
} }
return buildMatrix(matrix.rowNum, matrix.colNum) { i, j -> bp[i, j] } return bp.as2D()
} }
@ -223,7 +226,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver( public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
singularityCheck: (T) -> Boolean, singularityCheck: (T) -> Boolean,
): LinearSolver<T> = object : LinearSolver<T> { ): LinearSolver<T> = object : LinearSolver<T> {
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> { override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> = elementAlgebra{
// Use existing decomposition if it is provided by matrix or linear space itself // Use existing decomposition if it is provided by matrix or linear space itself
val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck) val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck)
return solve(decomposition, b) return solve(decomposition, b)