From cedb8a816e39896f76948cd384c06a55acef5832 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 16 Jan 2019 19:13:09 +0300 Subject: [PATCH] LUDecomposition finished. Not tested --- build.gradle.kts | 2 +- .../kmath/linear/LUDecomposition.kt | 102 +++++++++--------- .../scientifik/kmath/linear/LinearAlgrebra.kt | 26 +++-- .../kotlin/scientifik/kmath/linear/Matrix.kt | 20 ++-- .../kmath/linear/Mutable2DStructure.kt | 40 +++++++ .../kmath/linear/StructureMatrix.kt | 2 +- .../scientifik/kmath/structures/Buffers.kt | 2 - .../scientifik/kmath/linear/MatrixTest.kt | 5 +- .../kmath/linear/RealLUSolverTest.kt | 4 +- 9 files changed, 121 insertions(+), 82 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Mutable2DStructure.kt diff --git a/build.gradle.kts b/build.gradle.kts index 37e61237e..21656f173 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,5 +1,5 @@ buildscript { - extra["kotlinVersion"] = "1.3.20-eap-52" + extra["kotlinVersion"] = "1.3.20-eap-100" extra["ioVersion"] = "0.1.2" extra["coroutinesVersion"] = "1.1.0" diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt index 62e9ba903..651613489 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt @@ -57,7 +57,7 @@ private class LUPDecomposition, R : Ring>( * U is an upper-triangular matrix * @return the U matrix (or null if decomposed matrix is singular) */ - override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j >= i) lu[i, j] else context.zero } @@ -93,7 +93,11 @@ private class LUPDecomposition, R : Ring>( /** * Implementation based on Apache common-maths LU-decomposition */ -class LUPDecompositionBuilder, F : Field>(val context: F, val bufferFactory: MutableBufferFactory = ::boxing, val singularityCheck: (T) -> Boolean) { +class LUPDecompositionBuilder, F : Field>( + val context: F, + val bufferFactory: MutableBufferFactory = ::boxing, + val singularityCheck: (T) -> Boolean +) { /** * In-place transformation for [MutableNDStructure], using given transformation for each element @@ -105,23 +109,16 @@ class LUPDecompositionBuilder, F : Field>(val context: F, v private fun abs(value: T) = if (value > context.zero) value else with(context) { -value } fun decompose(matrix: Matrix): LUPDecompositionFeature { - // Use existing decomposition if it is provided by matrix - matrix.features.find { it is LUPDecompositionFeature<*> }?.let { - @Suppress("UNCHECKED_CAST") - return it as LUPDecompositionFeature - } - if (matrix.rowNum != matrix.colNum) { error("LU decomposition supports only square matrices") } val m = matrix.colNum val pivot = IntArray(matrix.rowNum) - //TODO replace by custom optimized 2d structure - val lu: MutableNDStructure = mutableNdStructure( - intArrayOf(matrix.rowNum, matrix.colNum), - bufferFactory - ) { index: IntArray -> matrix[index[0], index[1]] } + + val lu = Mutable2DStructure.create(matrix.rowNum, matrix.colNum, bufferFactory) { i, j -> + matrix[i, j] + } with(context) { @@ -188,57 +185,56 @@ class LUPDecompositionBuilder, F : Field>(val context: F, v } -class LUSolver, F : Field>(val singularityCheck: (T) -> Boolean) : LinearSolver { +class LUSolver, F : Field>( + override val context: MatrixContext, + val bufferFactory: MutableBufferFactory = ::boxing, + val singularityCheck: (T) -> Boolean +) : LinearSolver { override fun solve(a: Matrix, b: Matrix): Matrix { - val decomposition = LUPDecompositionBuilder(ring, singularityCheck).decompose(a) - if (b.rowNum != a.colNum) { error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}") } + // Use existing decomposition if it is provided by matrix + @Suppress("UNCHECKED_CAST") + val decomposition = a.features.find { it is LUPDecompositionFeature<*> }?.let { + it as LUPDecompositionFeature + } ?: LUPDecompositionBuilder(context.elementContext, bufferFactory, singularityCheck).decompose(a) -// val bp = Array(a.rowNum) { Array(b.colNum){ring.zero} } -// for (row in 0 until a.rowNum) { -// val bpRow = bp[row] -// val pRow = decomposition.pivot[row] -// for (col in 0 until b.colNum) { -// bpRow[col] = b[pRow, col] -// } -// } - - // Apply permutations to b - val bp = produce(a.rowNum, a.colNum) { i, j -> b[decomposition.pivot[i], j] } - - // Solve LY = b - for (col in 0 until a.rowNum) { - val bpCol = bp[col] - for (i in col + 1 until a.rowNum) { - val bpI = bp[i] - val luICol = decomposition.lu[i, col] - for (j in 0 until b.colNum) { - bpI[j] -= bpCol[j] * luICol + with(decomposition) { + with(context.elementContext) { + // Apply permutations to b + val bp = Mutable2DStructure.create(a.rowNum, a.colNum, bufferFactory) { i, j -> + b[pivot[i], j] } + + // Solve LY = b + for (col in 0 until a.rowNum) { + for (i in col + 1 until a.rowNum) { + for (j in 0 until b.colNum) { + bp[i, j] -= bp[col, j] * l[i, col] + } + } + } + + + // Solve UX = Y + for (col in a.rowNum - 1 downTo 0) { + for (i in 0 until col) { + for (j in 0 until b.colNum) { + bp[i, j] -= bp[col, j] / u[col, col] * u[i, col] + } + } + } + + return context.produce(a.rowNum, a.colNum) { i, j -> bp[i, j] } } } + } - // Solve UX = Y - for (col in a.rowNum - 1 downTo 0) { - val bpCol = bp[col] - val luDiag = decomposition.lu[col, col] - for (j in 0 until b.colNum) { - bpCol[j] /= luDiag - } - for (i in 0 until col) { - val bpI = bp[i] - val luICol = decomposition.lu[i, col] - for (j in 0 until b.colNum) { - bpI[j] -= bpCol[j] * luICol - } - } - } - - return produce(a.rowNum, a.colNum) { i, j -> bp[i][j] } + companion object { + val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt index 56ee448ab..fb542c16b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt @@ -4,26 +4,20 @@ import scientifik.kmath.operations.Field import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring +import scientifik.kmath.structures.VirtualBuffer import scientifik.kmath.structures.asSequence /** * A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors */ -interface LinearSolver> : MatrixContext { - /** - * Convert matrix to vector if it is possible - */ - fun Matrix.toVector(): Point = - if (this.colNum == 1) { - point(rowNum){ get(it, 0) } - } else error("Can't convert matrix with more than one column to vector") +interface LinearSolver> { + val context: MatrixContext - fun Point.toMatrix(): Matrix = produce(size, 1) { i, _ -> get(i) } fun solve(a: Matrix, b: Matrix): Matrix fun solve(a: Matrix, b: Point): Point = solve(a, b.toMatrix()).toVector() - fun inverse(a: Matrix): Matrix = solve(a, one(a.rowNum, a.colNum)) + fun inverse(a: Matrix): Matrix = solve(a, context.one(a.rowNum, a.colNum)) } /** @@ -41,3 +35,15 @@ object VectorL2Norm : Norm, Double> { typealias RealVector = Vector typealias RealMatrix = Matrix + + + +/** + * Convert matrix to vector if it is possible + */ +fun Matrix.toVector(): Point = + if (this.colNum == 1) { + VirtualBuffer(rowNum){ get(it, 0) } + } else error("Can't convert matrix with more than one column to vector") + +fun Point.toMatrix(): Matrix = VirtualMatrix(size, 1) { i, _ -> get(i) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index e7b791a86..8cd9a6b8f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -11,7 +11,7 @@ interface MatrixContext> { /** * The ring context for matrix elements */ - val ring: R + val elementContext: R /** * Produce a matrix with this context and given dimensions @@ -25,7 +25,7 @@ interface MatrixContext> { fun scale(a: Matrix, k: Number): Matrix { //TODO create a special wrapper class for scaled matrices - return produce(a.rowNum, a.colNum) { i, j -> ring.run { a[i, j] * k } } + return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } } } infix fun Matrix.dot(other: Matrix): Matrix { @@ -34,7 +34,7 @@ interface MatrixContext> { return produce(rowNum, other.colNum) { i, j -> val row = rows[i] val column = other.columns[j] - with(ring) { + with(elementContext) { row.asSequence().zip(column.asSequence(), ::multiply).sum() } } @@ -45,27 +45,27 @@ interface MatrixContext> { if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})") return point(rowNum) { i -> val row = rows[i] - with(ring) { + with(elementContext) { row.asSequence().zip(vector.asSequence(), ::multiply).sum() } } } operator fun Matrix.unaryMinus() = - produce(rowNum, colNum) { i, j -> ring.run { -get(i, j) } } + produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } operator fun Matrix.plus(b: Matrix): Matrix { if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]") - return produce(rowNum, colNum) { i, j -> ring.run { get(i, j) + b[i, j] } } + return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } } operator fun Matrix.minus(b: Matrix): Matrix { if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]") - return produce(rowNum, colNum) { i, j -> ring.run { get(i, j) + b[i, j] } } + return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } } operator fun Matrix.times(number: Number): Matrix = - produce(rowNum, colNum) { i, j -> ring.run { get(i, j) * number } } + produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } } operator fun Number.times(m: Matrix): Matrix = m * this @@ -157,7 +157,7 @@ fun > MatrixContext.one(rows: Int, columns: Int): Mat override val rowNum: Int get() = rows override val colNum: Int get() = columns override val features: Set get() = setOf(DiagonalFeature, UnitFeature) - override fun get(i: Int, j: Int): T = if (i == j) ring.one else ring.zero + override fun get(i: Int, j: Int): T = if (i == j) elementContext.one else elementContext.zero } } @@ -169,7 +169,7 @@ fun > MatrixContext.zero(rows: Int, columns: Int): Ma override val rowNum: Int get() = rows override val colNum: Int get() = columns override val features: Set get() = setOf(ZeroFeature) - override fun get(i: Int, j: Int): T = ring.zero + override fun get(i: Int, j: Int): T = elementContext.zero } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Mutable2DStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Mutable2DStructure.kt new file mode 100644 index 000000000..0d35c57e6 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Mutable2DStructure.kt @@ -0,0 +1,40 @@ +package scientifik.kmath.linear + +import scientifik.kmath.structures.MutableBuffer +import scientifik.kmath.structures.MutableBufferFactory +import scientifik.kmath.structures.MutableNDStructure + +class Mutable2DStructure(val rowNum: Int, val colNum: Int, val buffer: MutableBuffer) : MutableNDStructure { + override val shape: IntArray + get() = intArrayOf(rowNum, colNum) + + operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] + + override fun get(index: IntArray): T = get(index[0], index[1]) + + override fun elements(): Sequence> = sequence { + for (i in 0 until rowNum) { + for (j in 0 until colNum) { + yield(intArrayOf(i, j) to get(i, j)) + } + } + } + + operator fun set(i: Int, j: Int, value: T) { + buffer[i * colNum + j] = value + } + + override fun set(index: IntArray, value: T) = set(index[0], index[1], value) + + companion object { + fun create( + rowNum: Int, + colNum: Int, + bufferFactory: MutableBufferFactory, + init: (i: Int, j: Int) -> T + ): Mutable2DStructure { + val buffer = bufferFactory(rowNum * colNum) { offset -> init(offset / colNum, offset % colNum) } + return Mutable2DStructure(rowNum, colNum, buffer) + } + } +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt index a8a1d692c..1109297be 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt @@ -10,7 +10,7 @@ import scientifik.kmath.structures.ndStructure * Basic implementation of Matrix space based on [NDStructure] */ class StructureMatrixContext>( - override val ring: R, + override val elementContext: R, private val bufferFactory: BufferFactory ) : MatrixContext { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index 1cd2dc502..8f77c1d81 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -124,8 +124,6 @@ inline class MutableListBuffer(private val list: MutableList) : MutableBuf override fun copy(): MutableBuffer = MutableListBuffer(ArrayList(list)) } -fun MutableList.asBuffer() = MutableListBuffer(this) - class ArrayBuffer(private val array: Array) : MutableBuffer { //Can't inline because array is invariant override val size: Int diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt index fbbe06b4f..28934ab7f 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt @@ -22,9 +22,8 @@ class MatrixTest { @Test fun testTranspose() { - val matrix = MatrixContext.real(3, 3).one + val matrix = MatrixContext.real.one(3, 3) val transposed = matrix.transpose() - assertEquals(matrix.context, transposed.context) assertEquals((matrix as StructureMatrix).structure, (transposed as StructureMatrix).structure) assertEquals(matrix, transposed) } @@ -37,7 +36,7 @@ class MatrixTest { val matrix1 = vector1.toMatrix() val matrix2 = vector2.toMatrix().transpose() - val product = matrix1 dot matrix2 + val product = MatrixContext.real.run { matrix1 dot matrix2 } assertEquals(5.0, product[1, 0]) diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index f167f2f8f..daf749e8a 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -6,8 +6,8 @@ import kotlin.test.assertEquals class RealLUSolverTest { @Test fun testInvertOne() { - val matrix = MatrixContext.real(2, 2).one - val inverted = RealLUSolver.inverse(matrix) + val matrix = MatrixContext.real.one(2, 2) + val inverted = LUSolver.real.inverse(matrix) assertEquals(matrix, inverted) }