Fixed LUP solver implementation

This commit is contained in:
Alexander Nozik 2019-01-17 12:42:56 +03:00
parent 271d7886df
commit b4e8b9050f
6 changed files with 154 additions and 112 deletions

View File

@ -1,52 +1,28 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import scientifik.kmath.structures.MutableBuffer.Companion.boxing import scientifik.kmath.structures.MutableBuffer.Companion.boxing
/**
* Matrix LUP decomposition
*/
interface LUPDecompositionFeature<T : Any> : DeterminantFeature<T> {
/**
* A reference to L-matrix
*/
val l: Matrix<T>
/**
* A reference to u-matrix
*/
val u: Matrix<T>
/**
* Pivoting points for each row
*/
val pivot: IntArray
/**
* Permutation matrix based on [pivot]
*/
val p: Matrix<T>
}
class LUPDecomposition<T : Comparable<T>>(
private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>( private val elementContext: Ring<T>,
val context: R, private val lu: NDStructure<T>,
val lu: NDStructure<T>, val pivot: IntArray,
override val pivot: IntArray,
private val even: Boolean private val even: Boolean
) : LUPDecompositionFeature<T> { ) : DeterminantFeature<T> {
/** /**
* Returns the matrix L of the decomposition. * Returns the matrix L of the decomposition.
* *
* L is a lower-triangular matrix * L is a lower-triangular matrix with [Ring.one] in diagonal
* @return the L matrix (or null if decomposed matrix is singular)
*/ */
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
when { when {
j < i -> lu[i, j] j < i -> lu[i, j]
j == i -> context.one j == i -> elementContext.one
else -> context.zero else -> elementContext.zero
} }
} }
@ -54,26 +30,21 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
/** /**
* Returns the matrix U of the decomposition. * Returns the matrix U of the decomposition.
* *
* U is an upper-triangular matrix * U is an upper-triangular matrix including the diagonal
* @return the U matrix (or null if decomposed matrix is singular)
*/ */
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
if (j >= i) lu[i, j] else context.zero if (j >= i) lu[i, j] else elementContext.zero
} }
/** /**
* Returns the P rows permutation matrix. * Returns the P rows permutation matrix.
* *
* P is a sparse matrix with exactly one element set to 1.0 in * P is a sparse matrix with exactly one element set to [Ring.one] in
* each row and each column, all other elements being set to 0.0. * each row and each column, all other elements being set to [Ring.zero].
*
* The positions of the 1 elements are given by the [ pivot permutation vector][.getPivot].
* @return the P rows permutation matrix (or null if decomposed matrix is singular)
* @see .getPivot
*/ */
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
if (j == pivot[i]) context.one else context.zero if (j == pivot[i]) elementContext.one else elementContext.zero
} }
@ -82,7 +53,7 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
* @return determinant of the matrix * @return determinant of the matrix
*/ */
override val determinant: T by lazy { override val determinant: T by lazy {
with(context) { with(elementContext) {
(0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] }
} }
} }
@ -90,14 +61,12 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
} }
/** class LUSolver<T : Comparable<T>, F : Field<T>>(
* Implementation based on Apache common-maths LU-decomposition override val context: MatrixContext<T, F>,
*/
class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
val context: F,
val bufferFactory: MutableBufferFactory<T> = ::boxing, val bufferFactory: MutableBufferFactory<T> = ::boxing,
val singularityCheck: (T) -> Boolean val singularityCheck: (T) -> Boolean
) { ) : LinearSolver<T, F> {
/** /**
* In-place transformation for [MutableNDStructure], using given transformation for each element * In-place transformation for [MutableNDStructure], using given transformation for each element
@ -106,9 +75,10 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
this[intArrayOf(i, j)] = value this[intArrayOf(i, j)] = value
} }
private fun abs(value: T) = if (value > context.zero) value else with(context) { -value } private fun abs(value: T) =
if (value > context.elementContext.zero) value else with(context.elementContext) { -value }
fun decompose(matrix: Matrix<T>): LUPDecompositionFeature<T> { fun buildDecomposition(matrix: Matrix<T>): LUPDecomposition<T> {
if (matrix.rowNum != matrix.colNum) { if (matrix.rowNum != matrix.colNum) {
error("LU decomposition supports only square matrices") error("LU decomposition supports only square matrices")
} }
@ -121,7 +91,7 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
} }
with(context) { with(context.elementContext) {
// Initialize permutation array and parity // Initialize permutation array and parity
for (row in 0 until m) { for (row in 0 until m) {
pivot[row] = row pivot[row] = row
@ -174,23 +144,22 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
lu[row, col] = lu[row, col] / luDiag lu[row, col] = lu[row, col] / luDiag
} }
} }
return LUPDecomposition(context, lu, pivot, even) return LUPDecomposition(context.elementContext, lu, pivot, even)
} }
} }
companion object { /**
val real: LUPDecompositionBuilder<Double, RealField> = LUPDecompositionBuilder(RealField) { it < 1e-11 } * Produce a matrix with added decomposition feature
*/
fun decompose(matrix: Matrix<T>): Matrix<T> {
if (matrix.hasFeature<LUPDecomposition<*>>()) {
return matrix
} else {
val decomposition = buildDecomposition(matrix)
return VirtualMatrix.wrap(matrix, decomposition)
}
} }
}
class LUSolver<T : Comparable<T>, F : Field<T>>(
override val context: MatrixContext<T, F>,
val bufferFactory: MutableBufferFactory<T> = ::boxing,
val singularityCheck: (T) -> Boolean
) : LinearSolver<T, F> {
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> { override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
if (b.rowNum != a.colNum) { if (b.rowNum != a.colNum) {
@ -198,10 +167,7 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
} }
// Use existing decomposition if it is provided by matrix // Use existing decomposition if it is provided by matrix
@Suppress("UNCHECKED_CAST") val decomposition = a.getFeature() ?: buildDecomposition(a)
val decomposition = a.features.find { it is LUPDecompositionFeature<*> }?.let {
it as LUPDecompositionFeature<T>
} ?: LUPDecompositionBuilder(context.elementContext, bufferFactory, singularityCheck).decompose(a)
with(decomposition) { with(decomposition) {
with(context.elementContext) { with(context.elementContext) {
@ -219,12 +185,14 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
} }
} }
// Solve UX = Y // Solve UX = Y
for (col in a.rowNum - 1 downTo 0) { for (col in a.rowNum - 1 downTo 0) {
for(j in 0 until b.colNum){
bp[col,j]/= u[col,col]
}
for (i in 0 until col) { for (i in 0 until col) {
for (j in 0 until b.colNum) { for (j in 0 until b.colNum) {
bp[i, j] -= bp[col, j] / u[col, col] * u[i, col] bp[i, j] -= bp[col, j] * u[i, col]
} }
} }
} }

View File

@ -5,6 +5,7 @@ import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory
import scientifik.kmath.structures.Buffer.Companion.boxing import scientifik.kmath.structures.Buffer.Companion.boxing
import kotlin.math.sqrt
interface MatrixContext<T : Any, R : Ring<T>> { interface MatrixContext<T : Any, R : Ring<T>> {
@ -146,43 +147,56 @@ interface Matrix<T : Any> : NDStructure<T> {
companion object { companion object {
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
MatrixContext.real.produce(rows, columns, initializer) MatrixContext.real.produce(rows, columns, initializer)
/**
* Build a square matrix from given elements.
*/
fun <T : Any> build(vararg elements: T): Matrix<T> {
val buffer = elements.asBuffer()
val size: Int = sqrt(elements.size.toDouble()).toInt()
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
val structure = Mutable2DStructure(size, size, buffer)
return StructureMatrix(structure)
}
} }
} }
/**
* Check if matrix has the given feature class
*/
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { T::class.isInstance(it) } != null
/**
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull()
/** /**
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created * Diagonal matrix of ones. The matrix is virtual no actual matrix is created
*/ */
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j-> fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
if (i == j) elementContext.one else elementContext.zero VirtualMatrix<T>(rows, columns) { i, j ->
} if (i == j) elementContext.one else elementContext.zero
}
/** /**
* A virtual matrix of zeroes * A virtual matrix of zeroes
*/ */
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j-> fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
elementContext.zero VirtualMatrix<T>(rows, columns) { i, j ->
} elementContext.zero
}
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
inline class TransposedMatrix<T : Any>(val original: Matrix<T>) : Matrix<T> {
override val rowNum: Int get() = original.colNum
override val colNum: Int get() = original.rowNum
override val features: Set<MatrixFeature> get() = emptySet() //TODO retain some features
override fun get(i: Int, j: Int): T = original[j, i]
override fun elements(): Sequence<Pair<IntArray, T>> =
original.elements().map { (key, value) -> intArrayOf(key[1], key[0]) to value }
}
/** /**
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
*/ */
fun <T : Any, R : Ring<T>> Matrix<T>.transpose(): Matrix<T> { fun <T : Any, R : Ring<T>> Matrix<T>.transpose(): Matrix<T> {
return if (this is TransposedMatrix) { return this.getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
original this.colNum,
} else { this.rowNum,
TransposedMatrix(this) setOf(TransposedFeature(this))
} ) { i, j -> get(j, i) }
} }

View File

@ -1,10 +1,7 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.*
import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.get
import scientifik.kmath.structures.ndStructure
/** /**
* Basic implementation of Matrix space based on [NDStructure] * Basic implementation of Matrix space based on [NDStructure]
@ -24,7 +21,7 @@ class StructureMatrixContext<T : Any, R : Ring<T>>(
} }
class StructureMatrix<T : Any>( class StructureMatrix<T : Any>(
val structure: NDStructure<T>, val structure: NDStructure<out T>,
override val features: Set<MatrixFeature> = emptySet() override val features: Set<MatrixFeature> = emptySet()
) : Matrix<T> { ) : Matrix<T> {
@ -51,8 +48,7 @@ class StructureMatrix<T : Any>(
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
return when (other) { return when (other) {
is StructureMatrix<*> -> return this.structure == other.structure is NDStructure<*> -> return NDStructure.equals(this, other)
is Matrix<*> -> elements().all { (index, value) -> value == other[index] }
else -> false else -> false
} }
} }
@ -63,5 +59,16 @@ class StructureMatrix<T : Any>(
return result return result
} }
override fun toString(): String {
return if (rowNum <= 5 && colNum <= 5) {
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") {
it.asSequence().joinToString(separator = "\t") { it.toString() }
}
} else {
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
}
}
} }

View File

@ -27,4 +27,18 @@ class VirtualMatrix<T : Any>(
} }
companion object {
/**
* Wrap a matrix adding additional features to it
*/
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): Matrix<T> {
return if (matrix is VirtualMatrix) {
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
} else {
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j ->
matrix[i, j]
}
}
}
}
} }

View File

@ -11,6 +11,16 @@ interface NDStructure<T> {
operator fun get(index: IntArray): T operator fun get(index: IntArray): T
fun elements(): Sequence<Pair<IntArray, T>> fun elements(): Sequence<Pair<IntArray, T>>
companion object {
fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
return when {
st1 === st2 -> true
st1 is BufferNDStructure<*> && st2 is BufferNDStructure<*> && st1.strides == st2.strides -> st1.buffer.contentEquals(st2.buffer)
else -> st1.elements().all { (index, value) -> value == st2[index] }
}
}
}
} }
operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index) operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
@ -140,7 +150,7 @@ interface NDBuffer<T> : NDStructure<T> {
/** /**
* Boxing generic [NDStructure] * Boxing generic [NDStructure]
*/ */
data class BufferNDStructure<T>( class BufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : NDBuffer<T> { ) : NDBuffer<T> {
@ -160,7 +170,7 @@ data class BufferNDStructure<T>(
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
return when { return when {
this === other -> true this === other -> true
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer) other is BufferNDStructure<*> && this.strides == other.strides -> this.buffer.contentEquals(other.buffer)
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] } other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
else -> false else -> false
} }
@ -193,7 +203,11 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
* *
* Strides should be reused if possible * Strides should be reused if possible
*/ */
fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, initializer: (IntArray) -> T) = fun <T> ndStructure(
strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T
) =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
/** /**
@ -202,7 +216,11 @@ fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = Buffer.C
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T> ndStructure(shape: IntArray, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, initializer: (IntArray) -> T) = fun <T> ndStructure(
shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T
) =
ndStructure(DefaultStrides(shape), bufferFactory, initializer) ndStructure(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =

View File

@ -11,10 +11,31 @@ class RealLUSolverTest {
assertEquals(matrix, inverted) assertEquals(matrix, inverted)
} }
// @Test @Test
// fun testInvert() { fun testInvert() {
// val matrix = realMatrix(2,2){} val matrix = Matrix.build(
// val inverted = LUSolver.real.inverse(matrix) 3.0, 1.0,
// assertEquals(matrix, inverted) 1.0, 3.0
// } )
val decomposed = LUSolver.real.decompose(matrix)
val decomposition = decomposed.getFeature<LUPDecomposition<Double>>()!!
//Check determinant
assertEquals(8.0, decomposition.determinant)
//Check decomposition
with(MatrixContext.real) {
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
}
val inverted = LUSolver.real.inverse(decomposed)
val expected = Matrix.build(
0.375, -0.125,
-0.125, 0.375
)
assertEquals(expected, inverted)
}
} }