LUP cleanup
This commit is contained in:
parent
fbee95ab8b
commit
79642a869d
@ -4,12 +4,14 @@
|
||||
*/
|
||||
|
||||
@file:Suppress("UnusedReceiverParameter")
|
||||
@file:OptIn(PerformancePitfall::class)
|
||||
|
||||
package space.kscience.kmath.linear
|
||||
|
||||
import space.kscience.attributes.Attributes
|
||||
import space.kscience.attributes.PolymorphicAttribute
|
||||
import space.kscience.attributes.safeTypeOf
|
||||
import space.kscience.kmath.PerformancePitfall
|
||||
import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.*
|
||||
@ -78,22 +80,22 @@ internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
|
||||
/**
|
||||
* Create a lup decomposition of generic matrix.
|
||||
*/
|
||||
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
||||
public fun <T : Comparable<T>> Field<T>.lup(
|
||||
matrix: Matrix<T>,
|
||||
checkSingular: (T) -> Boolean,
|
||||
): GenericLupDecomposition<T> = elementAlgebra {
|
||||
): GenericLupDecomposition<T> {
|
||||
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
|
||||
val m = matrix.colNum
|
||||
val pivot = IntArray(matrix.rowNum)
|
||||
|
||||
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
|
||||
|
||||
val lu: MutableStructure2D<T> = MutableBufferND(
|
||||
val lu = MutableBufferND(
|
||||
strides,
|
||||
bufferAlgebra.buffer(strides.linearSize) { offset ->
|
||||
matrix[strides.index(offset)]
|
||||
}
|
||||
).as2D()
|
||||
)
|
||||
|
||||
|
||||
// Initialize the permutation array and parity
|
||||
@ -108,7 +110,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
||||
// upper
|
||||
for (row in 0 until col) {
|
||||
var sum = lu[row, col]
|
||||
for (i in 0 until row) sum -= lu[row, i] * lu[i, col]
|
||||
for (i in 0 until row){
|
||||
sum -= lu[row, i] * lu[i, col]
|
||||
}
|
||||
lu[row, col] = sum
|
||||
}
|
||||
|
||||
@ -118,7 +122,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
||||
|
||||
for (row in col until m) {
|
||||
var sum = lu[row, col]
|
||||
for (i in 0 until col) sum -= lu[row, i] * lu[i, col]
|
||||
for (i in 0 until col){
|
||||
sum -= lu[row, i] * lu[i, col]
|
||||
}
|
||||
lu[row, col] = sum
|
||||
|
||||
// maintain the best permutation choice
|
||||
@ -151,31 +157,29 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
||||
}
|
||||
|
||||
|
||||
return GenericLupDecomposition(elementAlgebra, lu, pivot.asBuffer(), even)
|
||||
return GenericLupDecomposition(this, lu.as2D(), pivot.asBuffer(), even)
|
||||
|
||||
}
|
||||
|
||||
|
||||
public fun LinearSpace<Double, Float64Field>.lup(
|
||||
public fun Field<Float64>.lup(
|
||||
matrix: Matrix<Double>,
|
||||
singularityThreshold: Double = 1e-11,
|
||||
): GenericLupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
|
||||
|
||||
internal fun <T> LinearSpace<T, Field<T>>.solve(
|
||||
private fun <T> Field<T>.solve(
|
||||
lup: LupDecomposition<T>,
|
||||
matrix: Matrix<T>,
|
||||
): Matrix<T> = elementAlgebra {
|
||||
): Matrix<T> {
|
||||
require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" }
|
||||
|
||||
// with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) {
|
||||
|
||||
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
|
||||
|
||||
// Apply permutations to b
|
||||
val bp: MutableStructure2D<T> = MutableBufferND(
|
||||
val bp = MutableBufferND(
|
||||
strides,
|
||||
bufferAlgebra.buffer(strides.linearSize) { offset -> zero }
|
||||
).as2D()
|
||||
bufferAlgebra.buffer(strides.linearSize) { _ -> zero }
|
||||
)
|
||||
|
||||
|
||||
for (row in 0 until matrix.rowNum) {
|
||||
@ -211,8 +215,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
|
||||
}
|
||||
}
|
||||
|
||||
return buildMatrix(matrix.rowNum, matrix.colNum) { i, j -> bp[i, j] }
|
||||
|
||||
return bp.as2D()
|
||||
}
|
||||
|
||||
|
||||
@ -223,7 +226,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
|
||||
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
|
||||
singularityCheck: (T) -> Boolean,
|
||||
): LinearSolver<T> = object : LinearSolver<T> {
|
||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> = elementAlgebra{
|
||||
// Use existing decomposition if it is provided by matrix or linear space itself
|
||||
val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck)
|
||||
return solve(decomposition, b)
|
||||
|
Loading…
Reference in New Issue
Block a user