From b4e8b9050f1cd1f864c50017057fee2f13f7c650 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 17 Jan 2019 12:42:56 +0300 Subject: [PATCH] Fixed LUP solver implementation --- ...LUDecomposition.kt => LUPDecomposition.kt} | 114 +++++++----------- .../kotlin/scientifik/kmath/linear/Matrix.kt | 58 +++++---- .../kmath/linear/StructureMatrix.kt | 21 ++-- .../scientifik/kmath/linear/VirtualMatrix.kt | 14 +++ .../kmath/structures/NDStructure.kt | 26 +++- .../kmath/linear/RealLUSolverTest.kt | 33 ++++- 6 files changed, 154 insertions(+), 112 deletions(-) rename kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/{LUDecomposition.kt => LUPDecomposition.kt} (65%) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt similarity index 65% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index 651613489..dbc45618d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -1,52 +1,28 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Field -import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring import scientifik.kmath.structures.* import scientifik.kmath.structures.MutableBuffer.Companion.boxing -/** - * Matrix LUP decomposition - */ -interface LUPDecompositionFeature : DeterminantFeature { - /** - * A reference to L-matrix - */ - val l: Matrix - /** - * A reference to u-matrix - */ - val u: Matrix - /** - * Pivoting points for each row - */ - val pivot: IntArray - /** - * Permutation matrix based on [pivot] - */ - val p: Matrix -} - -private class LUPDecomposition, R : Ring>( - val context: R, - val lu: NDStructure, - override val pivot: IntArray, +class LUPDecomposition>( + private val elementContext: Ring, + private val lu: NDStructure, + val pivot: IntArray, private val even: Boolean -) : LUPDecompositionFeature { +) : DeterminantFeature { /** * Returns the matrix L of the decomposition. * - * L is a lower-triangular matrix - * @return the L matrix (or null if decomposed matrix is singular) + * L is a lower-triangular matrix with [Ring.one] in diagonal */ - override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> when { j < i -> lu[i, j] - j == i -> context.one - else -> context.zero + j == i -> elementContext.one + else -> elementContext.zero } } @@ -54,26 +30,21 @@ private class LUPDecomposition, R : Ring>( /** * Returns the matrix U of the decomposition. * - * U is an upper-triangular matrix - * @return the U matrix (or null if decomposed matrix is singular) + * U is an upper-triangular matrix including the diagonal */ - override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> - if (j >= i) lu[i, j] else context.zero + val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + if (j >= i) lu[i, j] else elementContext.zero } /** * Returns the P rows permutation matrix. * - * P is a sparse matrix with exactly one element set to 1.0 in - * each row and each column, all other elements being set to 0.0. - * - * The positions of the 1 elements are given by the [ pivot permutation vector][.getPivot]. - * @return the P rows permutation matrix (or null if decomposed matrix is singular) - * @see .getPivot + * P is a sparse matrix with exactly one element set to [Ring.one] in + * each row and each column, all other elements being set to [Ring.zero]. */ - override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> - if (j == pivot[i]) context.one else context.zero + val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + if (j == pivot[i]) elementContext.one else elementContext.zero } @@ -82,7 +53,7 @@ private class LUPDecomposition, R : Ring>( * @return determinant of the matrix */ override val determinant: T by lazy { - with(context) { + with(elementContext) { (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } } } @@ -90,14 +61,12 @@ private class LUPDecomposition, R : Ring>( } -/** - * Implementation based on Apache common-maths LU-decomposition - */ -class LUPDecompositionBuilder, F : Field>( - val context: F, +class LUSolver, F : Field>( + override val context: MatrixContext, val bufferFactory: MutableBufferFactory = ::boxing, val singularityCheck: (T) -> Boolean -) { +) : LinearSolver { + /** * In-place transformation for [MutableNDStructure], using given transformation for each element @@ -106,9 +75,10 @@ class LUPDecompositionBuilder, F : Field>( this[intArrayOf(i, j)] = value } - private fun abs(value: T) = if (value > context.zero) value else with(context) { -value } + private fun abs(value: T) = + if (value > context.elementContext.zero) value else with(context.elementContext) { -value } - fun decompose(matrix: Matrix): LUPDecompositionFeature { + fun buildDecomposition(matrix: Matrix): LUPDecomposition { if (matrix.rowNum != matrix.colNum) { error("LU decomposition supports only square matrices") } @@ -121,7 +91,7 @@ class LUPDecompositionBuilder, F : Field>( } - with(context) { + with(context.elementContext) { // Initialize permutation array and parity for (row in 0 until m) { pivot[row] = row @@ -174,23 +144,22 @@ class LUPDecompositionBuilder, F : Field>( lu[row, col] = lu[row, col] / luDiag } } - return LUPDecomposition(context, lu, pivot, even) + return LUPDecomposition(context.elementContext, lu, pivot, even) } } - companion object { - val real: LUPDecompositionBuilder = LUPDecompositionBuilder(RealField) { it < 1e-11 } + /** + * Produce a matrix with added decomposition feature + */ + fun decompose(matrix: Matrix): Matrix { + if (matrix.hasFeature>()) { + return matrix + } else { + val decomposition = buildDecomposition(matrix) + return VirtualMatrix.wrap(matrix, decomposition) + } } -} - - -class LUSolver, F : Field>( - override val context: MatrixContext, - val bufferFactory: MutableBufferFactory = ::boxing, - val singularityCheck: (T) -> Boolean -) : LinearSolver { - override fun solve(a: Matrix, b: Matrix): Matrix { if (b.rowNum != a.colNum) { @@ -198,10 +167,7 @@ class LUSolver, F : Field>( } // Use existing decomposition if it is provided by matrix - @Suppress("UNCHECKED_CAST") - val decomposition = a.features.find { it is LUPDecompositionFeature<*> }?.let { - it as LUPDecompositionFeature - } ?: LUPDecompositionBuilder(context.elementContext, bufferFactory, singularityCheck).decompose(a) + val decomposition = a.getFeature() ?: buildDecomposition(a) with(decomposition) { with(context.elementContext) { @@ -219,12 +185,14 @@ class LUSolver, F : Field>( } } - // Solve UX = Y for (col in a.rowNum - 1 downTo 0) { + for(j in 0 until b.colNum){ + bp[col,j]/= u[col,col] + } for (i in 0 until col) { for (j in 0 until b.colNum) { - bp[i, j] -= bp[col, j] / u[col, col] * u[i, col] + bp[i, j] -= bp[col, j] * u[i, col] } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index d9a14864b..6cd5e985f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -5,6 +5,7 @@ import scientifik.kmath.operations.Ring import scientifik.kmath.structures.* import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory import scientifik.kmath.structures.Buffer.Companion.boxing +import kotlin.math.sqrt interface MatrixContext> { @@ -146,43 +147,56 @@ interface Matrix : NDStructure { companion object { fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = MatrixContext.real.produce(rows, columns, initializer) + + /** + * Build a square matrix from given elements. + */ + fun build(vararg elements: T): Matrix { + val buffer = elements.asBuffer() + val size: Int = sqrt(elements.size.toDouble()).toInt() + if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") + val structure = Mutable2DStructure(size, size, buffer) + return StructureMatrix(structure) + } } } +/** + * Check if matrix has the given feature class + */ +inline fun Matrix<*>.hasFeature(): Boolean = features.find { T::class.isInstance(it) } != null + +/** + * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria + */ +inline fun Matrix<*>.getFeature(): T? = features.filterIsInstance().firstOrNull() + /** * Diagonal matrix of ones. The matrix is virtual no actual matrix is created */ -fun > MatrixContext.one(rows: Int, columns: Int): Matrix = VirtualMatrix(rows,columns){i, j-> - if (i == j) elementContext.one else elementContext.zero -} +fun > MatrixContext.one(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { i, j -> + if (i == j) elementContext.one else elementContext.zero + } /** * A virtual matrix of zeroes */ -fun > MatrixContext.zero(rows: Int, columns: Int): Matrix = VirtualMatrix(rows,columns){i, j-> - elementContext.zero -} +fun > MatrixContext.zero(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { i, j -> + elementContext.zero + } - -inline class TransposedMatrix(val original: Matrix) : Matrix { - override val rowNum: Int get() = original.colNum - override val colNum: Int get() = original.rowNum - override val features: Set get() = emptySet() //TODO retain some features - - override fun get(i: Int, j: Int): T = original[j, i] - - override fun elements(): Sequence> = - original.elements().map { (key, value) -> intArrayOf(key[1], key[0]) to value } -} +class TransposedFeature(val original: Matrix) : MatrixFeature /** * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` */ fun > Matrix.transpose(): Matrix { - return if (this is TransposedMatrix) { - original - } else { - TransposedMatrix(this) - } + return this.getFeature>()?.original ?: VirtualMatrix( + this.colNum, + this.rowNum, + setOf(TransposedFeature(this)) + ) { i, j -> get(j, i) } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt index aa061fae2..21876a9fe 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt @@ -1,10 +1,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Ring -import scientifik.kmath.structures.BufferFactory -import scientifik.kmath.structures.NDStructure -import scientifik.kmath.structures.get -import scientifik.kmath.structures.ndStructure +import scientifik.kmath.structures.* /** * Basic implementation of Matrix space based on [NDStructure] @@ -24,7 +21,7 @@ class StructureMatrixContext>( } class StructureMatrix( - val structure: NDStructure, + val structure: NDStructure, override val features: Set = emptySet() ) : Matrix { @@ -51,8 +48,7 @@ class StructureMatrix( override fun equals(other: Any?): Boolean { if (this === other) return true return when (other) { - is StructureMatrix<*> -> return this.structure == other.structure - is Matrix<*> -> elements().all { (index, value) -> value == other[index] } + is NDStructure<*> -> return NDStructure.equals(this, other) else -> false } } @@ -63,5 +59,16 @@ class StructureMatrix( return result } + override fun toString(): String { + return if (rowNum <= 5 && colNum <= 5) { + "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + + rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { + it.asSequence().joinToString(separator = "\t") { it.toString() } + } + } else { + "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" + } + } + } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt index cb7b0a08e..98655ad48 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -27,4 +27,18 @@ class VirtualMatrix( } + companion object { + /** + * Wrap a matrix adding additional features to it + */ + fun wrap(matrix: Matrix, vararg features: MatrixFeature): Matrix { + return if (matrix is VirtualMatrix) { + VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator) + } else { + VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> + matrix[i, j] + } + } + } + } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 6c545d1af..4c437bf88 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -11,6 +11,16 @@ interface NDStructure { operator fun get(index: IntArray): T fun elements(): Sequence> + + companion object { + fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { + return when { + st1 === st2 -> true + st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(st2.buffer) + else -> st1.elements().all { (index, value) -> value == st2[index] } + } + } + } } operator fun NDStructure.get(vararg index: Int): T = get(index) @@ -140,7 +150,7 @@ interface NDBuffer : NDStructure { /** * Boxing generic [NDStructure] */ -data class BufferNDStructure( +class BufferNDStructure( override val strides: Strides, override val buffer: Buffer ) : NDBuffer { @@ -160,7 +170,7 @@ data class BufferNDStructure( override fun equals(other: Any?): Boolean { return when { this === other -> true - other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer) + other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer) other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] } else -> false } @@ -193,7 +203,11 @@ inline fun NDStructure.mapToBuffer( * * Strides should be reused if possible */ -fun ndStructure(strides: Strides, bufferFactory: BufferFactory = Buffer.Companion::boxing, initializer: (IntArray) -> T) = +fun ndStructure( + strides: Strides, + bufferFactory: BufferFactory = Buffer.Companion::boxing, + initializer: (IntArray) -> T +) = BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) /** @@ -202,7 +216,11 @@ fun ndStructure(strides: Strides, bufferFactory: BufferFactory = Buffer.C inline fun inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) -fun ndStructure(shape: IntArray, bufferFactory: BufferFactory = Buffer.Companion::boxing, initializer: (IntArray) -> T) = +fun ndStructure( + shape: IntArray, + bufferFactory: BufferFactory = Buffer.Companion::boxing, + initializer: (IntArray) -> T +) = ndStructure(DefaultStrides(shape), bufferFactory, initializer) inline fun inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index 54a308761..ea104355c 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -11,10 +11,31 @@ class RealLUSolverTest { assertEquals(matrix, inverted) } -// @Test -// fun testInvert() { -// val matrix = realMatrix(2,2){} -// val inverted = LUSolver.real.inverse(matrix) -// assertEquals(matrix, inverted) -// } + @Test + fun testInvert() { + val matrix = Matrix.build( + 3.0, 1.0, + 1.0, 3.0 + ) + + val decomposed = LUSolver.real.decompose(matrix) + val decomposition = decomposed.getFeature>()!! + + //Check determinant + assertEquals(8.0, decomposition.determinant) + + //Check decomposition + with(MatrixContext.real) { + assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u) + } + + val inverted = LUSolver.real.inverse(decomposed) + + val expected = Matrix.build( + 0.375, -0.125, + -0.125, 0.375 + ) + + assertEquals(expected, inverted) + } } \ No newline at end of file