Fixed LUP solver implementation
This commit is contained in:
parent
271d7886df
commit
b4e8b9050f
@ -1,52 +1,28 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.*
|
||||
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
|
||||
|
||||
/**
|
||||
* Matrix LUP decomposition
|
||||
*/
|
||||
interface LUPDecompositionFeature<T : Any> : DeterminantFeature<T> {
|
||||
/**
|
||||
* A reference to L-matrix
|
||||
*/
|
||||
val l: Matrix<T>
|
||||
/**
|
||||
* A reference to u-matrix
|
||||
*/
|
||||
val u: Matrix<T>
|
||||
/**
|
||||
* Pivoting points for each row
|
||||
*/
|
||||
val pivot: IntArray
|
||||
/**
|
||||
* Permutation matrix based on [pivot]
|
||||
*/
|
||||
val p: Matrix<T>
|
||||
}
|
||||
|
||||
|
||||
private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
||||
val context: R,
|
||||
val lu: NDStructure<T>,
|
||||
override val pivot: IntArray,
|
||||
class LUPDecomposition<T : Comparable<T>>(
|
||||
private val elementContext: Ring<T>,
|
||||
private val lu: NDStructure<T>,
|
||||
val pivot: IntArray,
|
||||
private val even: Boolean
|
||||
) : LUPDecompositionFeature<T> {
|
||||
) : DeterminantFeature<T> {
|
||||
|
||||
/**
|
||||
* Returns the matrix L of the decomposition.
|
||||
*
|
||||
* L is a lower-triangular matrix
|
||||
* @return the L matrix (or null if decomposed matrix is singular)
|
||||
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
||||
*/
|
||||
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
when {
|
||||
j < i -> lu[i, j]
|
||||
j == i -> context.one
|
||||
else -> context.zero
|
||||
j == i -> elementContext.one
|
||||
else -> elementContext.zero
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,26 +30,21 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
||||
/**
|
||||
* Returns the matrix U of the decomposition.
|
||||
*
|
||||
* U is an upper-triangular matrix
|
||||
* @return the U matrix (or null if decomposed matrix is singular)
|
||||
* U is an upper-triangular matrix including the diagonal
|
||||
*/
|
||||
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j >= i) lu[i, j] else context.zero
|
||||
val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j >= i) lu[i, j] else elementContext.zero
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns the P rows permutation matrix.
|
||||
*
|
||||
* P is a sparse matrix with exactly one element set to 1.0 in
|
||||
* each row and each column, all other elements being set to 0.0.
|
||||
*
|
||||
* The positions of the 1 elements are given by the [ pivot permutation vector][.getPivot].
|
||||
* @return the P rows permutation matrix (or null if decomposed matrix is singular)
|
||||
* @see .getPivot
|
||||
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
||||
* each row and each column, all other elements being set to [Ring.zero].
|
||||
*/
|
||||
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j == pivot[i]) context.one else context.zero
|
||||
val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j == pivot[i]) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
|
||||
@ -82,7 +53,7 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
||||
* @return determinant of the matrix
|
||||
*/
|
||||
override val determinant: T by lazy {
|
||||
with(context) {
|
||||
with(elementContext) {
|
||||
(0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] }
|
||||
}
|
||||
}
|
||||
@ -90,14 +61,12 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Implementation based on Apache common-maths LU-decomposition
|
||||
*/
|
||||
class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
||||
val context: F,
|
||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
override val context: MatrixContext<T, F>,
|
||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||
val singularityCheck: (T) -> Boolean
|
||||
) {
|
||||
) : LinearSolver<T, F> {
|
||||
|
||||
|
||||
/**
|
||||
* In-place transformation for [MutableNDStructure], using given transformation for each element
|
||||
@ -106,9 +75,10 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
||||
this[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
private fun abs(value: T) = if (value > context.zero) value else with(context) { -value }
|
||||
private fun abs(value: T) =
|
||||
if (value > context.elementContext.zero) value else with(context.elementContext) { -value }
|
||||
|
||||
fun decompose(matrix: Matrix<T>): LUPDecompositionFeature<T> {
|
||||
fun buildDecomposition(matrix: Matrix<T>): LUPDecomposition<T> {
|
||||
if (matrix.rowNum != matrix.colNum) {
|
||||
error("LU decomposition supports only square matrices")
|
||||
}
|
||||
@ -121,7 +91,7 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
||||
}
|
||||
|
||||
|
||||
with(context) {
|
||||
with(context.elementContext) {
|
||||
// Initialize permutation array and parity
|
||||
for (row in 0 until m) {
|
||||
pivot[row] = row
|
||||
@ -174,34 +144,30 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
||||
lu[row, col] = lu[row, col] / luDiag
|
||||
}
|
||||
}
|
||||
return LUPDecomposition(context, lu, pivot, even)
|
||||
return LUPDecomposition(context.elementContext, lu, pivot, even)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
val real: LUPDecompositionBuilder<Double, RealField> = LUPDecompositionBuilder(RealField) { it < 1e-11 }
|
||||
/**
|
||||
* Produce a matrix with added decomposition feature
|
||||
*/
|
||||
fun decompose(matrix: Matrix<T>): Matrix<T> {
|
||||
if (matrix.hasFeature<LUPDecomposition<*>>()) {
|
||||
return matrix
|
||||
} else {
|
||||
val decomposition = buildDecomposition(matrix)
|
||||
return VirtualMatrix.wrap(matrix, decomposition)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
override val context: MatrixContext<T, F>,
|
||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||
val singularityCheck: (T) -> Boolean
|
||||
) : LinearSolver<T, F> {
|
||||
|
||||
|
||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||
if (b.rowNum != a.colNum) {
|
||||
error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}")
|
||||
}
|
||||
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val decomposition = a.features.find { it is LUPDecompositionFeature<*> }?.let {
|
||||
it as LUPDecompositionFeature<T>
|
||||
} ?: LUPDecompositionBuilder(context.elementContext, bufferFactory, singularityCheck).decompose(a)
|
||||
val decomposition = a.getFeature() ?: buildDecomposition(a)
|
||||
|
||||
with(decomposition) {
|
||||
with(context.elementContext) {
|
||||
@ -219,12 +185,14 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Solve UX = Y
|
||||
for (col in a.rowNum - 1 downTo 0) {
|
||||
for(j in 0 until b.colNum){
|
||||
bp[col,j]/= u[col,col]
|
||||
}
|
||||
for (i in 0 until col) {
|
||||
for (j in 0 until b.colNum) {
|
||||
bp[i, j] -= bp[col, j] / u[col, col] * u[i, col]
|
||||
bp[i, j] -= bp[col, j] * u[i, col]
|
||||
}
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.*
|
||||
import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory
|
||||
import scientifik.kmath.structures.Buffer.Companion.boxing
|
||||
import kotlin.math.sqrt
|
||||
|
||||
|
||||
interface MatrixContext<T : Any, R : Ring<T>> {
|
||||
@ -146,13 +147,35 @@ interface Matrix<T : Any> : NDStructure<T> {
|
||||
companion object {
|
||||
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||
MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
fun <T : Any> build(vararg elements: T): Matrix<T> {
|
||||
val buffer = elements.asBuffer()
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
||||
val structure = Mutable2DStructure(size, size, buffer)
|
||||
return StructureMatrix(structure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if matrix has the given feature class
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { T::class.isInstance(it) } != null
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull()
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j->
|
||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
@ -160,29 +183,20 @@ fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Mat
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j->
|
||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||
elementContext.zero
|
||||
}
|
||||
|
||||
|
||||
inline class TransposedMatrix<T : Any>(val original: Matrix<T>) : Matrix<T> {
|
||||
override val rowNum: Int get() = original.colNum
|
||||
override val colNum: Int get() = original.rowNum
|
||||
override val features: Set<MatrixFeature> get() = emptySet() //TODO retain some features
|
||||
|
||||
override fun get(i: Int, j: Int): T = original[j, i]
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
original.elements().map { (key, value) -> intArrayOf(key[1], key[0]) to value }
|
||||
}
|
||||
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
/**
|
||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> Matrix<T>.transpose(): Matrix<T> {
|
||||
return if (this is TransposedMatrix) {
|
||||
original
|
||||
} else {
|
||||
TransposedMatrix(this)
|
||||
}
|
||||
return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||
this.colNum,
|
||||
this.rowNum,
|
||||
setOf(TransposedFeature(this))
|
||||
) { i, j -> get(j, i) }
|
||||
}
|
@ -1,10 +1,7 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
import scientifik.kmath.structures.NDStructure
|
||||
import scientifik.kmath.structures.get
|
||||
import scientifik.kmath.structures.ndStructure
|
||||
import scientifik.kmath.structures.*
|
||||
|
||||
/**
|
||||
* Basic implementation of Matrix space based on [NDStructure]
|
||||
@ -24,7 +21,7 @@ class StructureMatrixContext<T : Any, R : Ring<T>>(
|
||||
}
|
||||
|
||||
class StructureMatrix<T : Any>(
|
||||
val structure: NDStructure<T>,
|
||||
val structure: NDStructure<out T>,
|
||||
override val features: Set<MatrixFeature> = emptySet()
|
||||
) : Matrix<T> {
|
||||
|
||||
@ -51,8 +48,7 @@ class StructureMatrix<T : Any>(
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
return when (other) {
|
||||
is StructureMatrix<*> -> return this.structure == other.structure
|
||||
is Matrix<*> -> elements().all { (index, value) -> value == other[index] }
|
||||
is NDStructure<*> -> return NDStructure.equals(this, other)
|
||||
else -> false
|
||||
}
|
||||
}
|
||||
@ -63,5 +59,16 @@ class StructureMatrix<T : Any>(
|
||||
return result
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return if (rowNum <= 5 && colNum <= 5) {
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
||||
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") {
|
||||
it.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||
}
|
||||
} else {
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -27,4 +27,18 @@ class VirtualMatrix<T : Any>(
|
||||
}
|
||||
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Wrap a matrix adding additional features to it
|
||||
*/
|
||||
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): Matrix<T> {
|
||||
return if (matrix is VirtualMatrix) {
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
||||
} else {
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j ->
|
||||
matrix[i, j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -11,6 +11,16 @@ interface NDStructure<T> {
|
||||
operator fun get(index: IntArray): T
|
||||
|
||||
fun elements(): Sequence<Pair<IntArray, T>>
|
||||
|
||||
companion object {
|
||||
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||
return when {
|
||||
st1 === st2 -> true
|
||||
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(st2.buffer)
|
||||
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
||||
@ -140,7 +150,7 @@ interface NDBuffer<T> : NDStructure<T> {
|
||||
/**
|
||||
* Boxing generic [NDStructure]
|
||||
*/
|
||||
data class BufferNDStructure<T>(
|
||||
class BufferNDStructure<T>(
|
||||
override val strides: Strides,
|
||||
override val buffer: Buffer<T>
|
||||
) : NDBuffer<T> {
|
||||
@ -160,7 +170,7 @@ data class BufferNDStructure<T>(
|
||||
override fun equals(other: Any?): Boolean {
|
||||
return when {
|
||||
this === other -> true
|
||||
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer)
|
||||
other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer)
|
||||
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
|
||||
else -> false
|
||||
}
|
||||
@ -193,7 +203,11 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
||||
*
|
||||
* Strides should be reused if possible
|
||||
*/
|
||||
fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, initializer: (IntArray) -> T) =
|
||||
fun <T> ndStructure(
|
||||
strides: Strides,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
/**
|
||||
@ -202,7 +216,11 @@ fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = Buffer.C
|
||||
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
fun <T> ndStructure(shape: IntArray, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, initializer: (IntArray) -> T) =
|
||||
fun <T> ndStructure(
|
||||
shape: IntArray,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
ndStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
|
@ -11,10 +11,31 @@ class RealLUSolverTest {
|
||||
assertEquals(matrix, inverted)
|
||||
}
|
||||
|
||||
// @Test
|
||||
// fun testInvert() {
|
||||
// val matrix = realMatrix(2,2){}
|
||||
// val inverted = LUSolver.real.inverse(matrix)
|
||||
// assertEquals(matrix, inverted)
|
||||
// }
|
||||
@Test
|
||||
fun testInvert() {
|
||||
val matrix = Matrix.build(
|
||||
3.0, 1.0,
|
||||
1.0, 3.0
|
||||
)
|
||||
|
||||
val decomposed = LUSolver.real.decompose(matrix)
|
||||
val decomposition = decomposed.getFeature<LUPDecomposition<Double>>()!!
|
||||
|
||||
//Check determinant
|
||||
assertEquals(8.0, decomposition.determinant)
|
||||
|
||||
//Check decomposition
|
||||
with(MatrixContext.real) {
|
||||
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
||||
}
|
||||
|
||||
val inverted = LUSolver.real.inverse(decomposed)
|
||||
|
||||
val expected = Matrix.build(
|
||||
0.375, -0.125,
|
||||
-0.125, 0.375
|
||||
)
|
||||
|
||||
assertEquals(expected, inverted)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user