forked from kscience/kmath
LUDecomposition finished. Not tested
This commit is contained in:
parent
58e939e0cf
commit
cedb8a816e
@ -1,5 +1,5 @@
|
|||||||
buildscript {
|
buildscript {
|
||||||
extra["kotlinVersion"] = "1.3.20-eap-52"
|
extra["kotlinVersion"] = "1.3.20-eap-100"
|
||||||
extra["ioVersion"] = "0.1.2"
|
extra["ioVersion"] = "0.1.2"
|
||||||
extra["coroutinesVersion"] = "1.1.0"
|
extra["coroutinesVersion"] = "1.1.0"
|
||||||
|
|
||||||
|
@ -93,7 +93,11 @@ private class LUPDecomposition<T : Comparable<T>, R : Ring<T>>(
|
|||||||
/**
|
/**
|
||||||
* Implementation based on Apache common-maths LU-decomposition
|
* Implementation based on Apache common-maths LU-decomposition
|
||||||
*/
|
*/
|
||||||
class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(val context: F, val bufferFactory: MutableBufferFactory<T> = ::boxing, val singularityCheck: (T) -> Boolean) {
|
class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(
|
||||||
|
val context: F,
|
||||||
|
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||||
|
val singularityCheck: (T) -> Boolean
|
||||||
|
) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In-place transformation for [MutableNDStructure], using given transformation for each element
|
* In-place transformation for [MutableNDStructure], using given transformation for each element
|
||||||
@ -105,23 +109,16 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(val context: F, v
|
|||||||
private fun abs(value: T) = if (value > context.zero) value else with(context) { -value }
|
private fun abs(value: T) = if (value > context.zero) value else with(context) { -value }
|
||||||
|
|
||||||
fun decompose(matrix: Matrix<T>): LUPDecompositionFeature<T> {
|
fun decompose(matrix: Matrix<T>): LUPDecompositionFeature<T> {
|
||||||
// Use existing decomposition if it is provided by matrix
|
|
||||||
matrix.features.find { it is LUPDecompositionFeature<*> }?.let {
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
return it as LUPDecompositionFeature<T>
|
|
||||||
}
|
|
||||||
|
|
||||||
if (matrix.rowNum != matrix.colNum) {
|
if (matrix.rowNum != matrix.colNum) {
|
||||||
error("LU decomposition supports only square matrices")
|
error("LU decomposition supports only square matrices")
|
||||||
}
|
}
|
||||||
|
|
||||||
val m = matrix.colNum
|
val m = matrix.colNum
|
||||||
val pivot = IntArray(matrix.rowNum)
|
val pivot = IntArray(matrix.rowNum)
|
||||||
//TODO replace by custom optimized 2d structure
|
|
||||||
val lu: MutableNDStructure<T> = mutableNdStructure(
|
val lu = Mutable2DStructure.create(matrix.rowNum, matrix.colNum, bufferFactory) { i, j ->
|
||||||
intArrayOf(matrix.rowNum, matrix.colNum),
|
matrix[i, j]
|
||||||
bufferFactory
|
}
|
||||||
) { index: IntArray -> matrix[index[0], index[1]] }
|
|
||||||
|
|
||||||
|
|
||||||
with(context) {
|
with(context) {
|
||||||
@ -188,57 +185,56 @@ class LUPDecompositionBuilder<T : Comparable<T>, F : Field<T>>(val context: F, v
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LUSolver<T : Comparable<T>, F : Field<T>>(val singularityCheck: (T) -> Boolean) : LinearSolver<T, F> {
|
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||||
|
override val context: MatrixContext<T, F>,
|
||||||
|
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||||
|
val singularityCheck: (T) -> Boolean
|
||||||
|
) : LinearSolver<T, F> {
|
||||||
|
|
||||||
|
|
||||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
val decomposition = LUPDecompositionBuilder(ring, singularityCheck).decompose(a)
|
|
||||||
|
|
||||||
if (b.rowNum != a.colNum) {
|
if (b.rowNum != a.colNum) {
|
||||||
error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}")
|
error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use existing decomposition if it is provided by matrix
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
val decomposition = a.features.find { it is LUPDecompositionFeature<*> }?.let {
|
||||||
|
it as LUPDecompositionFeature<T>
|
||||||
|
} ?: LUPDecompositionBuilder(context.elementContext, bufferFactory, singularityCheck).decompose(a)
|
||||||
|
|
||||||
// val bp = Array(a.rowNum) { Array<T>(b.colNum){ring.zero} }
|
with(decomposition) {
|
||||||
// for (row in 0 until a.rowNum) {
|
with(context.elementContext) {
|
||||||
// val bpRow = bp[row]
|
|
||||||
// val pRow = decomposition.pivot[row]
|
|
||||||
// for (col in 0 until b.colNum) {
|
|
||||||
// bpRow[col] = b[pRow, col]
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Apply permutations to b
|
// Apply permutations to b
|
||||||
val bp = produce(a.rowNum, a.colNum) { i, j -> b[decomposition.pivot[i], j] }
|
val bp = Mutable2DStructure.create(a.rowNum, a.colNum, bufferFactory) { i, j ->
|
||||||
|
b[pivot[i], j]
|
||||||
|
}
|
||||||
|
|
||||||
// Solve LY = b
|
// Solve LY = b
|
||||||
for (col in 0 until a.rowNum) {
|
for (col in 0 until a.rowNum) {
|
||||||
val bpCol = bp[col]
|
|
||||||
for (i in col + 1 until a.rowNum) {
|
for (i in col + 1 until a.rowNum) {
|
||||||
val bpI = bp[i]
|
|
||||||
val luICol = decomposition.lu[i, col]
|
|
||||||
for (j in 0 until b.colNum) {
|
for (j in 0 until b.colNum) {
|
||||||
bpI[j] -= bpCol[j] * luICol
|
bp[i, j] -= bp[col, j] * l[i, col]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Solve UX = Y
|
// Solve UX = Y
|
||||||
for (col in a.rowNum - 1 downTo 0) {
|
for (col in a.rowNum - 1 downTo 0) {
|
||||||
val bpCol = bp[col]
|
|
||||||
val luDiag = decomposition.lu[col, col]
|
|
||||||
for (j in 0 until b.colNum) {
|
|
||||||
bpCol[j] /= luDiag
|
|
||||||
}
|
|
||||||
for (i in 0 until col) {
|
for (i in 0 until col) {
|
||||||
val bpI = bp[i]
|
|
||||||
val luICol = decomposition.lu[i, col]
|
|
||||||
for (j in 0 until b.colNum) {
|
for (j in 0 until b.colNum) {
|
||||||
bpI[j] -= bpCol[j] * luICol
|
bp[i, j] -= bp[col, j] / u[col, col] * u[i, col]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return produce(a.rowNum, a.colNum) { i, j -> bp[i][j] }
|
return context.produce(a.rowNum, a.colNum) { i, j -> bp[i, j] }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,26 +4,20 @@ import scientifik.kmath.operations.Field
|
|||||||
import scientifik.kmath.operations.Norm
|
import scientifik.kmath.operations.Norm
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.structures.VirtualBuffer
|
||||||
import scientifik.kmath.structures.asSequence
|
import scientifik.kmath.structures.asSequence
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
||||||
*/
|
*/
|
||||||
interface LinearSolver<T : Any, R : Ring<T>> : MatrixContext<T, R> {
|
interface LinearSolver<T : Any, R : Ring<T>> {
|
||||||
/**
|
val context: MatrixContext<T,R>
|
||||||
* Convert matrix to vector if it is possible
|
|
||||||
*/
|
|
||||||
fun Matrix<T>.toVector(): Point<T> =
|
|
||||||
if (this.colNum == 1) {
|
|
||||||
point(rowNum){ get(it, 0) }
|
|
||||||
} else error("Can't convert matrix with more than one column to vector")
|
|
||||||
|
|
||||||
fun Point<T>.toMatrix(): Matrix<T> = produce(size, 1) { i, _ -> get(i) }
|
|
||||||
|
|
||||||
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
||||||
fun inverse(a: Matrix<T>): Matrix<T> = solve(a, one(a.rowNum, a.colNum))
|
fun inverse(a: Matrix<T>): Matrix<T> = solve(a, context.one(a.rowNum, a.colNum))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -41,3 +35,15 @@ object VectorL2Norm : Norm<Point<out Number>, Double> {
|
|||||||
|
|
||||||
typealias RealVector = Vector<Double, RealField>
|
typealias RealVector = Vector<Double, RealField>
|
||||||
typealias RealMatrix = Matrix<Double>
|
typealias RealMatrix = Matrix<Double>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert matrix to vector if it is possible
|
||||||
|
*/
|
||||||
|
fun <T: Any> Matrix<T>.toVector(): Point<T> =
|
||||||
|
if (this.colNum == 1) {
|
||||||
|
VirtualBuffer(rowNum){ get(it, 0) }
|
||||||
|
} else error("Can't convert matrix with more than one column to vector")
|
||||||
|
|
||||||
|
fun <T: Any> Point<T>.toMatrix(): Matrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
@ -11,7 +11,7 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
|||||||
/**
|
/**
|
||||||
* The ring context for matrix elements
|
* The ring context for matrix elements
|
||||||
*/
|
*/
|
||||||
val ring: R
|
val elementContext: R
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produce a matrix with this context and given dimensions
|
* Produce a matrix with this context and given dimensions
|
||||||
@ -25,7 +25,7 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
|||||||
|
|
||||||
fun scale(a: Matrix<T>, k: Number): Matrix<T> {
|
fun scale(a: Matrix<T>, k: Number): Matrix<T> {
|
||||||
//TODO create a special wrapper class for scaled matrices
|
//TODO create a special wrapper class for scaled matrices
|
||||||
return produce(a.rowNum, a.colNum) { i, j -> ring.run { a[i, j] * k } }
|
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } }
|
||||||
}
|
}
|
||||||
|
|
||||||
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||||
@ -34,7 +34,7 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
|||||||
return produce(rowNum, other.colNum) { i, j ->
|
return produce(rowNum, other.colNum) { i, j ->
|
||||||
val row = rows[i]
|
val row = rows[i]
|
||||||
val column = other.columns[j]
|
val column = other.columns[j]
|
||||||
with(ring) {
|
with(elementContext) {
|
||||||
row.asSequence().zip(column.asSequence(), ::multiply).sum()
|
row.asSequence().zip(column.asSequence(), ::multiply).sum()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -45,27 +45,27 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
|||||||
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
||||||
return point(rowNum) { i ->
|
return point(rowNum) { i ->
|
||||||
val row = rows[i]
|
val row = rows[i]
|
||||||
with(ring) {
|
with(elementContext) {
|
||||||
row.asSequence().zip(vector.asSequence(), ::multiply).sum()
|
row.asSequence().zip(vector.asSequence(), ::multiply).sum()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Matrix<T>.unaryMinus() =
|
operator fun Matrix<T>.unaryMinus() =
|
||||||
produce(rowNum, colNum) { i, j -> ring.run { -get(i, j) } }
|
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||||
|
|
||||||
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
||||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]")
|
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]")
|
||||||
return produce(rowNum, colNum) { i, j -> ring.run { get(i, j) + b[i, j] } }
|
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
||||||
return produce(rowNum, colNum) { i, j -> ring.run { get(i, j) + b[i, j] } }
|
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> ring.run { get(i, j) * number } }
|
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
|
||||||
|
|
||||||
operator fun Number.times(m: Matrix<T>): Matrix<T> = m * this
|
operator fun Number.times(m: Matrix<T>): Matrix<T> = m * this
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Mat
|
|||||||
override val rowNum: Int get() = rows
|
override val rowNum: Int get() = rows
|
||||||
override val colNum: Int get() = columns
|
override val colNum: Int get() = columns
|
||||||
override val features: Set<MatrixFeature> get() = setOf(DiagonalFeature, UnitFeature)
|
override val features: Set<MatrixFeature> get() = setOf(DiagonalFeature, UnitFeature)
|
||||||
override fun get(i: Int, j: Int): T = if (i == j) ring.one else ring.zero
|
override fun get(i: Int, j: Int): T = if (i == j) elementContext.one else elementContext.zero
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,7 +169,7 @@ fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Ma
|
|||||||
override val rowNum: Int get() = rows
|
override val rowNum: Int get() = rows
|
||||||
override val colNum: Int get() = columns
|
override val colNum: Int get() = columns
|
||||||
override val features: Set<MatrixFeature> get() = setOf(ZeroFeature)
|
override val features: Set<MatrixFeature> get() = setOf(ZeroFeature)
|
||||||
override fun get(i: Int, j: Int): T = ring.zero
|
override fun get(i: Int, j: Int): T = elementContext.zero
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.MutableBuffer
|
||||||
|
import scientifik.kmath.structures.MutableBufferFactory
|
||||||
|
import scientifik.kmath.structures.MutableNDStructure
|
||||||
|
|
||||||
|
class Mutable2DStructure<T>(val rowNum: Int, val colNum: Int, val buffer: MutableBuffer<T>) : MutableNDStructure<T> {
|
||||||
|
override val shape: IntArray
|
||||||
|
get() = intArrayOf(rowNum, colNum)
|
||||||
|
|
||||||
|
operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||||
|
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||||
|
for (i in 0 until rowNum) {
|
||||||
|
for (j in 0 until colNum) {
|
||||||
|
yield(intArrayOf(i, j) to get(i, j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
operator fun set(i: Int, j: Int, value: T) {
|
||||||
|
buffer[i * colNum + j] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun set(index: IntArray, value: T) = set(index[0], index[1], value)
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
fun <T> create(
|
||||||
|
rowNum: Int,
|
||||||
|
colNum: Int,
|
||||||
|
bufferFactory: MutableBufferFactory<T>,
|
||||||
|
init: (i: Int, j: Int) -> T
|
||||||
|
): Mutable2DStructure<T> {
|
||||||
|
val buffer = bufferFactory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) }
|
||||||
|
return Mutable2DStructure(rowNum, colNum, buffer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ import scientifik.kmath.structures.ndStructure
|
|||||||
* Basic implementation of Matrix space based on [NDStructure]
|
* Basic implementation of Matrix space based on [NDStructure]
|
||||||
*/
|
*/
|
||||||
class StructureMatrixContext<T : Any, R : Ring<T>>(
|
class StructureMatrixContext<T : Any, R : Ring<T>>(
|
||||||
override val ring: R,
|
override val elementContext: R,
|
||||||
private val bufferFactory: BufferFactory<T>
|
private val bufferFactory: BufferFactory<T>
|
||||||
) : MatrixContext<T, R> {
|
) : MatrixContext<T, R> {
|
||||||
|
|
||||||
|
@ -124,8 +124,6 @@ inline class MutableListBuffer<T>(private val list: MutableList<T>) : MutableBuf
|
|||||||
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> MutableList<T>.asBuffer() = MutableListBuffer(this)
|
|
||||||
|
|
||||||
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
|
||||||
//Can't inline because array is invariant
|
//Can't inline because array is invariant
|
||||||
override val size: Int
|
override val size: Int
|
||||||
|
@ -22,9 +22,8 @@ class MatrixTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testTranspose() {
|
fun testTranspose() {
|
||||||
val matrix = MatrixContext.real(3, 3).one
|
val matrix = MatrixContext.real.one(3, 3)
|
||||||
val transposed = matrix.transpose()
|
val transposed = matrix.transpose()
|
||||||
assertEquals(matrix.context, transposed.context)
|
|
||||||
assertEquals((matrix as StructureMatrix).structure, (transposed as StructureMatrix).structure)
|
assertEquals((matrix as StructureMatrix).structure, (transposed as StructureMatrix).structure)
|
||||||
assertEquals(matrix, transposed)
|
assertEquals(matrix, transposed)
|
||||||
}
|
}
|
||||||
@ -37,7 +36,7 @@ class MatrixTest {
|
|||||||
|
|
||||||
val matrix1 = vector1.toMatrix()
|
val matrix1 = vector1.toMatrix()
|
||||||
val matrix2 = vector2.toMatrix().transpose()
|
val matrix2 = vector2.toMatrix().transpose()
|
||||||
val product = matrix1 dot matrix2
|
val product = MatrixContext.real.run { matrix1 dot matrix2 }
|
||||||
|
|
||||||
|
|
||||||
assertEquals(5.0, product[1, 0])
|
assertEquals(5.0, product[1, 0])
|
||||||
|
@ -6,8 +6,8 @@ import kotlin.test.assertEquals
|
|||||||
class RealLUSolverTest {
|
class RealLUSolverTest {
|
||||||
@Test
|
@Test
|
||||||
fun testInvertOne() {
|
fun testInvertOne() {
|
||||||
val matrix = MatrixContext.real(2, 2).one
|
val matrix = MatrixContext.real.one(2, 2)
|
||||||
val inverted = RealLUSolver.inverse(matrix)
|
val inverted = LUSolver.real.inverse(matrix)
|
||||||
assertEquals(matrix, inverted)
|
assertEquals(matrix, inverted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user