LUP cleanup

This commit is contained in:
Alexander Nozik 2024-02-18 14:00:38 +03:00
parent fbee95ab8b
commit 79642a869d

View File

@ -4,12 +4,14 @@
*/
@file:Suppress("UnusedReceiverParameter")
@file:OptIn(PerformancePitfall::class)
package space.kscience.kmath.linear
import space.kscience.attributes.Attributes
import space.kscience.attributes.PolymorphicAttribute
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
@ -78,22 +80,22 @@ internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
/**
* Create a lup decomposition of generic matrix.
*/
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
public fun <T : Comparable<T>> Field<T>.lup(
matrix: Matrix<T>,
checkSingular: (T) -> Boolean,
): GenericLupDecomposition<T> = elementAlgebra {
): GenericLupDecomposition<T> {
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
val m = matrix.colNum
val pivot = IntArray(matrix.rowNum)
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
val lu: MutableStructure2D<T> = MutableBufferND(
val lu = MutableBufferND(
strides,
bufferAlgebra.buffer(strides.linearSize) { offset ->
matrix[strides.index(offset)]
}
).as2D()
)
// Initialize the permutation array and parity
@ -108,7 +110,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
// upper
for (row in 0 until col) {
var sum = lu[row, col]
for (i in 0 until row) sum -= lu[row, i] * lu[i, col]
for (i in 0 until row){
sum -= lu[row, i] * lu[i, col]
}
lu[row, col] = sum
}
@ -118,7 +122,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
for (row in col until m) {
var sum = lu[row, col]
for (i in 0 until col) sum -= lu[row, i] * lu[i, col]
for (i in 0 until col){
sum -= lu[row, i] * lu[i, col]
}
lu[row, col] = sum
// maintain the best permutation choice
@ -151,31 +157,29 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
}
return GenericLupDecomposition(elementAlgebra, lu, pivot.asBuffer(), even)
return GenericLupDecomposition(this, lu.as2D(), pivot.asBuffer(), even)
}
public fun LinearSpace<Double, Float64Field>.lup(
public fun Field<Float64>.lup(
matrix: Matrix<Double>,
singularityThreshold: Double = 1e-11,
): GenericLupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
internal fun <T> LinearSpace<T, Field<T>>.solve(
private fun <T> Field<T>.solve(
lup: LupDecomposition<T>,
matrix: Matrix<T>,
): Matrix<T> = elementAlgebra {
): Matrix<T> {
require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" }
// with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) {
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
// Apply permutations to b
val bp: MutableStructure2D<T> = MutableBufferND(
val bp = MutableBufferND(
strides,
bufferAlgebra.buffer(strides.linearSize) { offset -> zero }
).as2D()
bufferAlgebra.buffer(strides.linearSize) { _ -> zero }
)
for (row in 0 until matrix.rowNum) {
@ -211,8 +215,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
}
}
return buildMatrix(matrix.rowNum, matrix.colNum) { i, j -> bp[i, j] }
return bp.as2D()
}
@ -223,7 +226,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
singularityCheck: (T) -> Boolean,
): LinearSolver<T> = object : LinearSolver<T> {
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> = elementAlgebra{
// Use existing decomposition if it is provided by matrix or linear space itself
val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck)
return solve(decomposition, b)