LUP cleanup
This commit is contained in:
parent
fbee95ab8b
commit
79642a869d
@ -4,12 +4,14 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
@file:Suppress("UnusedReceiverParameter")
|
@file:Suppress("UnusedReceiverParameter")
|
||||||
|
@file:OptIn(PerformancePitfall::class)
|
||||||
|
|
||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.attributes.Attributes
|
import space.kscience.attributes.Attributes
|
||||||
import space.kscience.attributes.PolymorphicAttribute
|
import space.kscience.attributes.PolymorphicAttribute
|
||||||
import space.kscience.attributes.safeTypeOf
|
import space.kscience.attributes.safeTypeOf
|
||||||
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
@ -78,22 +80,22 @@ internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
|
|||||||
/**
|
/**
|
||||||
* Create a lup decomposition of generic matrix.
|
* Create a lup decomposition of generic matrix.
|
||||||
*/
|
*/
|
||||||
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
public fun <T : Comparable<T>> Field<T>.lup(
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
checkSingular: (T) -> Boolean,
|
checkSingular: (T) -> Boolean,
|
||||||
): GenericLupDecomposition<T> = elementAlgebra {
|
): GenericLupDecomposition<T> {
|
||||||
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
|
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
|
||||||
val m = matrix.colNum
|
val m = matrix.colNum
|
||||||
val pivot = IntArray(matrix.rowNum)
|
val pivot = IntArray(matrix.rowNum)
|
||||||
|
|
||||||
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
|
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
|
||||||
|
|
||||||
val lu: MutableStructure2D<T> = MutableBufferND(
|
val lu = MutableBufferND(
|
||||||
strides,
|
strides,
|
||||||
bufferAlgebra.buffer(strides.linearSize) { offset ->
|
bufferAlgebra.buffer(strides.linearSize) { offset ->
|
||||||
matrix[strides.index(offset)]
|
matrix[strides.index(offset)]
|
||||||
}
|
}
|
||||||
).as2D()
|
)
|
||||||
|
|
||||||
|
|
||||||
// Initialize the permutation array and parity
|
// Initialize the permutation array and parity
|
||||||
@ -108,7 +110,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
// upper
|
// upper
|
||||||
for (row in 0 until col) {
|
for (row in 0 until col) {
|
||||||
var sum = lu[row, col]
|
var sum = lu[row, col]
|
||||||
for (i in 0 until row) sum -= lu[row, i] * lu[i, col]
|
for (i in 0 until row){
|
||||||
|
sum -= lu[row, i] * lu[i, col]
|
||||||
|
}
|
||||||
lu[row, col] = sum
|
lu[row, col] = sum
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -118,7 +122,9 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
|
|
||||||
for (row in col until m) {
|
for (row in col until m) {
|
||||||
var sum = lu[row, col]
|
var sum = lu[row, col]
|
||||||
for (i in 0 until col) sum -= lu[row, i] * lu[i, col]
|
for (i in 0 until col){
|
||||||
|
sum -= lu[row, i] * lu[i, col]
|
||||||
|
}
|
||||||
lu[row, col] = sum
|
lu[row, col] = sum
|
||||||
|
|
||||||
// maintain the best permutation choice
|
// maintain the best permutation choice
|
||||||
@ -151,31 +157,29 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return GenericLupDecomposition(elementAlgebra, lu, pivot.asBuffer(), even)
|
return GenericLupDecomposition(this, lu.as2D(), pivot.asBuffer(), even)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public fun LinearSpace<Double, Float64Field>.lup(
|
public fun Field<Float64>.lup(
|
||||||
matrix: Matrix<Double>,
|
matrix: Matrix<Double>,
|
||||||
singularityThreshold: Double = 1e-11,
|
singularityThreshold: Double = 1e-11,
|
||||||
): GenericLupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
|
): GenericLupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
|
||||||
|
|
||||||
internal fun <T> LinearSpace<T, Field<T>>.solve(
|
private fun <T> Field<T>.solve(
|
||||||
lup: LupDecomposition<T>,
|
lup: LupDecomposition<T>,
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
): Matrix<T> = elementAlgebra {
|
): Matrix<T> {
|
||||||
require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" }
|
require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" }
|
||||||
|
|
||||||
// with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) {
|
|
||||||
|
|
||||||
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
|
val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum))
|
||||||
|
|
||||||
// Apply permutations to b
|
// Apply permutations to b
|
||||||
val bp: MutableStructure2D<T> = MutableBufferND(
|
val bp = MutableBufferND(
|
||||||
strides,
|
strides,
|
||||||
bufferAlgebra.buffer(strides.linearSize) { offset -> zero }
|
bufferAlgebra.buffer(strides.linearSize) { _ -> zero }
|
||||||
).as2D()
|
)
|
||||||
|
|
||||||
|
|
||||||
for (row in 0 until matrix.rowNum) {
|
for (row in 0 until matrix.rowNum) {
|
||||||
@ -211,8 +215,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return buildMatrix(matrix.rowNum, matrix.colNum) { i, j -> bp[i, j] }
|
return bp.as2D()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -223,7 +226,7 @@ internal fun <T> LinearSpace<T, Field<T>>.solve(
|
|||||||
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
|
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
|
||||||
singularityCheck: (T) -> Boolean,
|
singularityCheck: (T) -> Boolean,
|
||||||
): LinearSolver<T> = object : LinearSolver<T> {
|
): LinearSolver<T> = object : LinearSolver<T> {
|
||||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> = elementAlgebra{
|
||||||
// Use existing decomposition if it is provided by matrix or linear space itself
|
// Use existing decomposition if it is provided by matrix or linear space itself
|
||||||
val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck)
|
val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck)
|
||||||
return solve(decomposition, b)
|
return solve(decomposition, b)
|
||||||
|
Loading…
Reference in New Issue
Block a user