From 79642a869daa533e7204d07b6161116fbc9e889b Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 18 Feb 2024 14:00:38 +0300 Subject: [PATCH] LUP cleanup --- .../kscience/kmath/linear/LupDecomposition.kt | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt index 090153bd5..cff72b2f5 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt @@ -4,12 +4,14 @@ */ @file:Suppress("UnusedReceiverParameter") +@file:OptIn(PerformancePitfall::class) package space.kscience.kmath.linear import space.kscience.attributes.Attributes import space.kscience.attributes.PolymorphicAttribute import space.kscience.attributes.safeTypeOf +import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.* import space.kscience.kmath.operations.* @@ -78,22 +80,22 @@ internal fun > LinearSpace>.abs(value: T): T = /** * Create a lup decomposition of generic matrix. */ -public fun > LinearSpace>.lup( +public fun > Field.lup( matrix: Matrix, checkSingular: (T) -> Boolean, -): GenericLupDecomposition = elementAlgebra { +): GenericLupDecomposition { require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" } val m = matrix.colNum val pivot = IntArray(matrix.rowNum) val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum)) - val lu: MutableStructure2D = MutableBufferND( + val lu = MutableBufferND( strides, bufferAlgebra.buffer(strides.linearSize) { offset -> matrix[strides.index(offset)] } - ).as2D() + ) // Initialize the permutation array and parity @@ -108,7 +110,9 @@ public fun > LinearSpace>.lup( // upper for (row in 0 until col) { var sum = lu[row, col] - for (i in 0 until row) sum -= lu[row, i] * lu[i, col] + for (i in 0 until row){ + sum -= lu[row, i] * lu[i, col] + } lu[row, col] = sum } @@ -118,7 +122,9 @@ public fun > LinearSpace>.lup( for (row in col until m) { var sum = lu[row, col] - for (i in 0 until col) sum -= lu[row, i] * lu[i, col] + for (i in 0 until col){ + sum -= lu[row, i] * lu[i, col] + } lu[row, col] = sum // maintain the best permutation choice @@ -151,31 +157,29 @@ public fun > LinearSpace>.lup( } - return GenericLupDecomposition(elementAlgebra, lu, pivot.asBuffer(), even) + return GenericLupDecomposition(this, lu.as2D(), pivot.asBuffer(), even) } -public fun LinearSpace.lup( +public fun Field.lup( matrix: Matrix, singularityThreshold: Double = 1e-11, ): GenericLupDecomposition = lup(matrix) { it < singularityThreshold } -internal fun LinearSpace>.solve( +private fun Field.solve( lup: LupDecomposition, matrix: Matrix, -): Matrix = elementAlgebra { +): Matrix { require(matrix.rowNum == lup.l.rowNum) { "Matrix dimension mismatch. Expected ${lup.l.rowNum}, but got ${matrix.colNum}" } -// with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) { - val strides = RowStrides(ShapeND(matrix.rowNum, matrix.colNum)) // Apply permutations to b - val bp: MutableStructure2D = MutableBufferND( + val bp = MutableBufferND( strides, - bufferAlgebra.buffer(strides.linearSize) { offset -> zero } - ).as2D() + bufferAlgebra.buffer(strides.linearSize) { _ -> zero } + ) for (row in 0 until matrix.rowNum) { @@ -211,8 +215,7 @@ internal fun LinearSpace>.solve( } } - return buildMatrix(matrix.rowNum, matrix.colNum) { i, j -> bp[i, j] } - + return bp.as2D() } @@ -223,7 +226,7 @@ internal fun LinearSpace>.solve( public fun > LinearSpace>.lupSolver( singularityCheck: (T) -> Boolean, ): LinearSolver = object : LinearSolver { - override fun solve(a: Matrix, b: Matrix): Matrix { + override fun solve(a: Matrix, b: Matrix): Matrix = elementAlgebra{ // Use existing decomposition if it is provided by matrix or linear space itself val decomposition = a.getOrComputeAttribute(LUP) ?: lup(a, singularityCheck) return solve(decomposition, b)