From 7e83b080ade5111d64ccf91e8a3b07bae26aa341 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 16 Jan 2019 14:52:27 +0300 Subject: [PATCH] Matrix revision --- .../kmath/linear/LUDecomposition.kt | 258 +++++++++--------- .../scientifik/kmath/linear/LinearAlgrebra.kt | 57 ++-- .../kotlin/scientifik/kmath/linear/Matrix.kt | 217 ++++++++------- .../kmath/linear/StructureMatrix.kt | 50 ++++ .../scientifik/kmath/linear/VirtualMatrix.kt | 10 + .../scientifik/kmath/operations/Algebra.kt | 2 +- .../scientifik/kmath/structures/Buffers.kt | 21 ++ .../scientifik/kmath/linear/MatrixTest.kt | 2 +- .../kmath/linear/RealLUSolverTest.kt | 2 +- 9 files changed, 345 insertions(+), 274 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt index 2ceb5bcd3..7b645c74b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt @@ -2,31 +2,39 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Field import scientifik.kmath.operations.RealField +import scientifik.kmath.operations.Ring +import scientifik.kmath.structures.* import scientifik.kmath.structures.MutableBuffer.Companion.boxing -import scientifik.kmath.structures.MutableNDStructure -import scientifik.kmath.structures.NDStructure -import scientifik.kmath.structures.get -import scientifik.kmath.structures.mutableNdStructure -import kotlin.math.absoluteValue /** - * Implementation based on Apache common-maths LU-decomposition + * Matrix LUP decomposition */ -abstract class LUDecomposition, F : Field>(val matrix: Matrix) { +interface LUPDecompositionFeature : DeterminantFeature { + /** + * A reference to L-matrix + */ + val l: Matrix + /** + * A reference to u-matrix + */ + val u: Matrix + /** + * Pivoting points for each row + */ + val pivot: IntArray + /** + * Permutation matrix based on [pivot] + */ + val p: Matrix +} - private val field get() = matrix.context.ring - /** Entries of LU decomposition. */ - internal val lu: NDStructure - /** Pivot permutation associated with LU decomposition. */ - internal val pivot: IntArray - /** Parity of the permutation associated with the LU decomposition. */ - private var even: Boolean = false - init { - val pair = calculateLU() - lu = pair.first - pivot = pair.second - } +private class LUPDecomposition, R : Ring>( + val context: R, + val lu: NDStructure, + override val pivot: IntArray, + private val even: Boolean +) : LUPDecompositionFeature { /** * Returns the matrix L of the decomposition. @@ -34,13 +42,11 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr * L is a lower-triangular matrix * @return the L matrix (or null if decomposed matrix is singular) */ - val l: Matrix by lazy { - matrix.context.produce(matrix.numRows, matrix.numCols) { i, j -> - when { - j < i -> lu[i, j] - j == i -> matrix.context.ring.one - else -> matrix.context.ring.zero - } + override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + when { + j < i -> lu[i, j] + j == i -> context.one + else -> context.zero } } @@ -51,12 +57,11 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr * U is an upper-triangular matrix * @return the U matrix (or null if decomposed matrix is singular) */ - val u: Matrix by lazy { - matrix.context.produce(matrix.numRows, matrix.numCols) { i, j -> - if (j >= i) lu[i, j] else field.zero - } + override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + if (j >= i) lu[i, j] else context.zero } + /** * Returns the P rows permutation matrix. * @@ -67,59 +72,64 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr * @return the P rows permutation matrix (or null if decomposed matrix is singular) * @see .getPivot */ - val p: Matrix by lazy { - matrix.context.produce(matrix.numRows, matrix.numCols) { i, j -> - //TODO ineffective. Need sparse matrix for that - if (j == pivot[i]) field.one else field.zero - } + override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + if (j == pivot[i]) context.one else context.zero } + /** * Return the determinant of the matrix * @return determinant of the matrix */ - val determinant: T - get() { - with(matrix.context.ring) { - var determinant = if (even) one else -one - for (i in 0 until matrix.numRows) { - determinant *= lu[i, i] - } - return determinant - } + override val determinant: T by lazy { + with(context) { + (0 until lu.shape[0]).fold(if (even) one else -one) { value, i -> value * lu[i, i] } } + } + +} + + +/** + * Implementation based on Apache common-maths LU-decomposition + */ +class LUPDecompositionBuilder, F : Field>(val context: F, val bufferFactory: MutableBufferFactory = ::boxing, val singularityCheck: (T) -> Boolean) { /** * In-place transformation for [MutableNDStructure], using given transformation for each element */ - operator fun MutableNDStructure.set(i: Int, j: Int, value: T) { + private operator fun MutableNDStructure.set(i: Int, j: Int, value: T) { this[intArrayOf(i, j)] = value } - abstract fun isSingular(value: T): Boolean + private fun abs(value: T) = if (value > context.zero) value else with(context) { -value } - private fun abs(value: T) = if (value > matrix.context.ring.zero) value else with(matrix.context.ring) { -value } + fun decompose(matrix: Matrix): LUPDecompositionFeature { + // Use existing decomposition if it is provided by matrix + matrix.features.find { it is LUPDecompositionFeature<*> }?.let { + @Suppress("UNCHECKED_CAST") + return it as LUPDecompositionFeature + } - private fun calculateLU(): Pair, IntArray> { - if (matrix.numRows != matrix.numCols) { + if (matrix.rowNum != matrix.colNum) { error("LU decomposition supports only square matrices") } - val m = matrix.numCols - val pivot = IntArray(matrix.numRows) - //TODO fix performance + val m = matrix.colNum + val pivot = IntArray(matrix.rowNum) + //TODO replace by custom optimized 2d structure val lu: MutableNDStructure = mutableNdStructure( - intArrayOf(matrix.numRows, matrix.numCols), - ::boxing + intArrayOf(matrix.rowNum, matrix.colNum), + bufferFactory ) { index: IntArray -> matrix[index[0], index[1]] } - with(matrix.context.ring) { + with(context) { // Initialize permutation array and parity for (row in 0 until m) { pivot[row] = row } - even = true + var even = true // Loop over columns for (col in 0 until m) { @@ -139,22 +149,18 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr for (i in 0 until col) { sum -= lu[row, i] * lu[i, col] } - //luRow[col] = sum lu[row, col] = sum abs(sum) } ?: col // Singularity check - if (isSingular(lu[max, col])) { + if (singularityCheck(lu[max, col])) { error("Singular matrix") } // Pivot if necessary if (max != col) { - //var tmp = zero - //val luMax = lu[max] - //val luCol = lu[col] for (i in 0 until m) { lu[max, i] = lu[col, i] lu[col, i] = lu[max, i] @@ -169,86 +175,70 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr val luDiag = lu[col, col] for (row in col + 1 until m) { lu[row, col] = lu[row, col] / luDiag -// lu[row, col] /= luDiag } } + return LUPDecomposition(context, lu, pivot, even) } - return Pair(lu, pivot) - } - - /** - * Returns the pivot permutation vector. - * @return the pivot permutation vector - * @see .getP - */ - fun getPivot(): IntArray = pivot.copyOf() - -} - -class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold: Double = DEFAULT_TOO_SMALL) : - LUDecomposition(matrix) { - override fun isSingular(value: Double): Boolean { - return value.absoluteValue < singularityThreshold } companion object { - /** Default bound to determine effective singularity in LU decomposition. */ - private const val DEFAULT_TOO_SMALL = 1e-11 + val real: LUPDecompositionBuilder = LUPDecompositionBuilder(RealField) { it < 1e-11 } } + } -/** Specialized solver. */ -object RealLUSolver : LinearSolver { - - fun decompose(mat: Matrix, threshold: Double = 1e-11): RealLUDecomposition = - RealLUDecomposition(mat, threshold) - - override fun solve(a: RealMatrix, b: RealMatrix): RealMatrix { - val decomposition = decompose(a, a.context.ring.zero) - - if (b.numRows != a.numCols) { - error("Matrix dimension mismatch expected ${a.numRows}, but got ${b.numCols}") - } - - // Apply permutations to b - val bp = Array(a.numRows) { DoubleArray(b.numCols) } - for (row in 0 until a.numRows) { - val bpRow = bp[row] - val pRow = decomposition.pivot[row] - for (col in 0 until b.numCols) { - bpRow[col] = b[pRow, col] - } - } - - // Solve LY = b - for (col in 0 until a.numRows) { - val bpCol = bp[col] - for (i in col + 1 until a.numRows) { - val bpI = bp[i] - val luICol = decomposition.lu[i, col] - for (j in 0 until b.numCols) { - bpI[j] -= bpCol[j] * luICol - } - } - } - - // Solve UX = Y - for (col in a.numRows - 1 downTo 0) { - val bpCol = bp[col] - val luDiag = decomposition.lu[col, col] - for (j in 0 until b.numCols) { - bpCol[j] /= luDiag - } - for (i in 0 until col) { - val bpI = bp[i] - val luICol = decomposition.lu[i, col] - for (j in 0 until b.numCols) { - bpI[j] -= bpCol[j] * luICol - } - } - } - - return a.context.produce(a.numRows, a.numCols) { i, j -> bp[i][j] } - } -} +//class LUSolver, F : Field>(val singularityCheck: (T) -> Boolean) : LinearSolver { +// +// +// override fun solve(a: Matrix, b: Matrix): Matrix { +// val decomposition = LUPDecompositionBuilder(ring, singularityCheck).decompose(a) +// +// if (b.rowNum != a.colNum) { +// error("Matrix dimension mismatch expected ${a.rowNum}, but got ${b.colNum}") +// } +// +// +//// val bp = Array(a.rowNum) { Array(b.colNum){ring.zero} } +//// for (row in 0 until a.rowNum) { +//// val bpRow = bp[row] +//// val pRow = decomposition.pivot[row] +//// for (col in 0 until b.colNum) { +//// bpRow[col] = b[pRow, col] +//// } +//// } +// +// // Apply permutations to b +// val bp = produce(a.rowNum, a.colNum) { i, j -> b[decomposition.pivot[i], j] } +// +// // Solve LY = b +// for (col in 0 until a.rowNum) { +// val bpCol = bp[col] +// for (i in col + 1 until a.rowNum) { +// val bpI = bp[i] +// val luICol = decomposition.lu[i, col] +// for (j in 0 until b.colNum) { +// bpI[j] -= bpCol[j] * luICol +// } +// } +// } +// +// // Solve UX = Y +// for (col in a.rowNum - 1 downTo 0) { +// val bpCol = bp[col] +// val luDiag = decomposition.lu[col, col] +// for (j in 0 until b.colNum) { +// bpCol[j] /= luDiag +// } +// for (i in 0 until col) { +// val bpI = bp[i] +// val luICol = decomposition.lu[i, col] +// for (j in 0 until b.colNum) { +// bpI[j] -= bpCol[j] * luICol +// } +// } +// } +// +// return produce(a.rowNum, a.colNum) { i, j -> bp[i][j] } +// } +//} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt index a3d2feccc..56ee448ab 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt @@ -4,18 +4,26 @@ import scientifik.kmath.operations.Field import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring -import scientifik.kmath.structures.Buffer.Companion.boxing import scientifik.kmath.structures.asSequence - /** * A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors */ -interface LinearSolver> { - fun solve(a: Matrix, b: Matrix): Matrix - fun solve(a: Matrix, b: Vector): Vector = solve(a, b.toMatrix()).toVector() - fun inverse(a: Matrix): Matrix = solve(a, a.context.one) +interface LinearSolver> : MatrixContext { + /** + * Convert matrix to vector if it is possible + */ + fun Matrix.toVector(): Point = + if (this.colNum == 1) { + point(rowNum){ get(it, 0) } + } else error("Can't convert matrix with more than one column to vector") + + fun Point.toMatrix(): Matrix = produce(size, 1) { i, _ -> get(i) } + + fun solve(a: Matrix, b: Matrix): Matrix + fun solve(a: Matrix, b: Point): Point = solve(a, b.toMatrix()).toVector() + fun inverse(a: Matrix): Matrix = solve(a, one(a.rowNum, a.colNum)) } /** @@ -26,39 +34,10 @@ fun Array.toVector(field: Field) = Vector.generic(size, field) { fun DoubleArray.toVector() = Vector.real(this.size) { this[it] } fun List.toVector() = Vector.real(this.size) { this[it] } -/** - * Convert matrix to vector if it is possible - */ -fun > Matrix.toVector(): Vector = - if (this.numCols == 1) { -// if (this is ArrayMatrix) { -// //Reuse existing underlying array -// ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array) -// } else { -// //Generic vector -// vector(rows, context.field) { get(it, 0) } -// } - Vector.generic(numRows, context.ring) { get(it, 0) } - } else error("Can't convert matrix with more than one column to vector") - -fun > Vector.toMatrix(): Matrix { -// val context = StructureMatrixContext(size, 1, context.space) -// -// return if (this is ArrayVector) { -// //Reuse existing underlying array -// StructureMatrix(context,this.buffer) -// } else { -// //Generic vector -// matrix(size, 1, context.field) { i, j -> get(i) } -// } - //return Matrix.of(size, 1, context.space) { i, _ -> get(i) } - return StructureMatrixSpace(size, 1, context.space, ::boxing).produce { i, _ -> get(i) } -} - -object VectorL2Norm : Norm, Double> { - override fun norm(arg: Vector): Double = - kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() }) +object VectorL2Norm : Norm, Double> { + override fun norm(arg: Point): Double = + kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() }) } typealias RealVector = Vector -typealias RealMatrix = Matrix +typealias RealMatrix = Matrix diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 0edf748e6..e7b791a86 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -2,64 +2,91 @@ package scientifik.kmath.linear import scientifik.kmath.operations.RealField import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.SpaceElement import scientifik.kmath.structures.* import scientifik.kmath.structures.Buffer.Companion.DoubleBufferFactory import scientifik.kmath.structures.Buffer.Companion.boxing -interface MatrixSpace> : Space> { +interface MatrixContext> { /** * The ring context for matrix elements */ val ring: R - val rowNum: Int - val colNum: Int - /** * Produce a matrix with this context and given dimensions */ - fun produce(rows: Int = rowNum, columns: Int = colNum, initializer: (i: Int, j: Int) -> T): Matrix + fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix /** * Produce a point compatible with matrix space */ fun point(size: Int, initializer: (Int) -> T): Point - override val zero: Matrix get() = produce { _, _ -> ring.zero } + fun scale(a: Matrix, k: Number): Matrix { + //TODO create a special wrapper class for scaled matrices + return produce(a.rowNum, a.colNum) { i, j -> ring.run { a[i, j] * k } } + } - val one get() = produce { i, j -> if (i == j) ring.one else ring.zero } + infix fun Matrix.dot(other: Matrix): Matrix { + //TODO add typed error + if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") + return produce(rowNum, other.colNum) { i, j -> + val row = rows[i] + val column = other.columns[j] + with(ring) { + row.asSequence().zip(column.asSequence(), ::multiply).sum() + } + } + } - override fun add(a: Matrix, b: Matrix): Matrix = - produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } } + infix fun Matrix.dot(vector: Point): Point { + //TODO add typed error + if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})") + return point(rowNum) { i -> + val row = rows[i] + with(ring) { + row.asSequence().zip(vector.asSequence(), ::multiply).sum() + } + } + } + + operator fun Matrix.unaryMinus() = + produce(rowNum, colNum) { i, j -> ring.run { -get(i, j) } } + + operator fun Matrix.plus(b: Matrix): Matrix { + if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]") + return produce(rowNum, colNum) { i, j -> ring.run { get(i, j) + b[i, j] } } + } + + operator fun Matrix.minus(b: Matrix): Matrix { + if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]") + return produce(rowNum, colNum) { i, j -> ring.run { get(i, j) + b[i, j] } } + } + + operator fun Matrix.times(number: Number): Matrix = + produce(rowNum, colNum) { i, j -> ring.run { get(i, j) * number } } + + operator fun Number.times(m: Matrix): Matrix = m * this - override fun multiply(a: Matrix, k: Number): Matrix = - produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } } companion object { /** * Non-boxing double matrix */ - fun real(rows: Int, columns: Int): MatrixSpace = - StructureMatrixSpace(rows, columns, RealField, DoubleBufferFactory) + val real: MatrixContext = StructureMatrixContext(RealField, DoubleBufferFactory) /** * A structured matrix with custom buffer */ - fun > buffered( - rows: Int, - columns: Int, - ring: R, - bufferFactory: BufferFactory = ::boxing - ): MatrixSpace = StructureMatrixSpace(rows, columns, ring, bufferFactory) + fun > buffered(ring: R, bufferFactory: BufferFactory = ::boxing): MatrixContext = + StructureMatrixContext(ring, bufferFactory) /** * Automatic buffered matrix, unboxed if it is possible */ - inline fun > auto(rows: Int, columns: Int, ring: R): MatrixSpace = - buffered(rows, columns, ring, Buffer.Companion::auto) + inline fun > auto(ring: R): MatrixContext = + buffered(ring, Buffer.Companion::auto) } } @@ -69,108 +96,102 @@ interface MatrixSpace> : Space> { */ interface MatrixFeature +object DiagonalFeature : MatrixFeature + +object ZeroFeature : MatrixFeature + +object UnitFeature : MatrixFeature + +interface InverseMatrixFeature : MatrixFeature { + val inverse: Matrix +} + +interface DeterminantFeature : MatrixFeature { + val determinant: T +} + /** * Specialized 2-d structure */ -interface Matrix> : NDStructure, SpaceElement, Matrix, MatrixSpace> { +interface Matrix : NDStructure { + val rowNum: Int + val colNum: Int + + val features: Set + operator fun get(i: Int, j: Int): T override fun get(index: IntArray): T = get(index[0], index[1]) - val numRows get() = context.rowNum - val numCols get() = context.colNum + override val shape: IntArray get() = intArrayOf(rowNum, colNum) - //TODO replace by lazy buffers val rows: Point> - get() = ListBuffer((0 until numRows).map { i -> - context.point(numCols) { j -> get(i, j) } - }) + get() = VirtualBuffer(rowNum) { i -> + VirtualBuffer(colNum) { j -> get(i, j) } + } val columns: Point> - get() = ListBuffer((0 until numCols).map { j -> - context.point(numRows) { i -> get(i, j) } - }) + get() = VirtualBuffer(colNum) { j -> + VirtualBuffer(rowNum) { i -> get(i, j) } + } - val features: Set + override fun elements(): Sequence> = sequence { + for (i in (0 until rowNum)) { + for (j in (0 until colNum)) { + yield(intArrayOf(i, j) to get(i, j)) + } + } + } companion object { fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = - MatrixSpace.real(rows, columns).produce(rows, columns, initializer) + MatrixContext.real.produce(rows, columns, initializer) } } -infix fun > Matrix.dot(other: Matrix): Matrix { - //TODO add typed error - if (this.numCols != other.numRows) error("Matrix dot operation dimension mismatch: ($numRows, $numCols) x (${other.numRows}, ${other.numCols})") - return context.produce(numRows, other.numCols) { i, j -> - val row = rows[i] - val column = other.columns[j] - with(context.ring) { - row.asSequence().zip(column.asSequence(), ::multiply).sum() - } +/** + * Diagonal matrix of ones. The matrix is virtual no actual matrix is created + */ +fun > MatrixContext.one(rows: Int, columns: Int): Matrix { + return object : Matrix { + override val rowNum: Int get() = rows + override val colNum: Int get() = columns + override val features: Set get() = setOf(DiagonalFeature, UnitFeature) + override fun get(i: Int, j: Int): T = if (i == j) ring.one else ring.zero } } -infix fun > Matrix.dot(vector: Point): Point { - //TODO add typed error - if (this.numCols != vector.size) error("Matrix dot vector operation dimension mismatch: ($numRows, $numCols) x (${vector.size})") - return context.point(numRows) { i -> - val row = rows[i] - with(context.ring) { - row.asSequence().zip(vector.asSequence(), ::multiply).sum() - } +/** + * A virtual matrix of zeroes + */ +fun > MatrixContext.zero(rows: Int, columns: Int): Matrix { + return object : Matrix { + override val rowNum: Int get() = rows + override val colNum: Int get() = columns + override val features: Set get() = setOf(ZeroFeature) + override fun get(i: Int, j: Int): T = ring.zero } } -data class StructureMatrixSpace>( - override val rowNum: Int, - override val colNum: Int, - override val ring: R, - private val bufferFactory: BufferFactory -) : MatrixSpace { - val shape: IntArray = intArrayOf(rowNum, colNum) +inline class TransposedMatrix(val original: Matrix) : Matrix { + override val rowNum: Int get() = original.colNum + override val colNum: Int get() = original.rowNum + override val features: Set get() = emptySet() //TODO retain some features - private val strides = DefaultStrides(shape) + override fun get(i: Int, j: Int): T = original[j, i] - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix { - return if (rows == rowNum && columns == colNum) { - val structure = ndStructure(strides, bufferFactory) { initializer(it[0], it[1]) } - StructureMatrix(this, structure) - } else { - val context = StructureMatrixSpace(rows, columns, ring, bufferFactory) - val structure = ndStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) } - StructureMatrix(context, structure) - } - } - - override fun point(size: Int, initializer: (Int) -> T): Point = bufferFactory(size, initializer) + override fun elements(): Sequence> = + original.elements().map { (key, value) -> intArrayOf(key[1], key[0]) to value } } -data class StructureMatrix>( - override val context: StructureMatrixSpace, - val structure: NDStructure, - override val features: Set = emptySet() -) : Matrix { - init { - if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) { - error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found") - } +/** + * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` + */ +fun > Matrix.transpose(): Matrix { + return if (this is TransposedMatrix) { + original + } else { + TransposedMatrix(this) } - - override fun unwrap(): Matrix = this - - override fun Matrix.wrap(): Matrix = this - - override val shape: IntArray get() = structure.shape - - override fun get(index: IntArray): T = structure[index] - - override fun get(i: Int, j: Int): T = structure[i, j] - - override fun elements(): Sequence> = structure.elements() -} - -//TODO produce transposed matrix via reference without creating new space and structure -fun > Matrix.transpose(): Matrix = - context.produce(numCols, numRows) { i, j -> get(j, i) } \ No newline at end of file +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt new file mode 100644 index 000000000..a8a1d692c --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt @@ -0,0 +1,50 @@ +package scientifik.kmath.linear + +import scientifik.kmath.operations.Ring +import scientifik.kmath.structures.BufferFactory +import scientifik.kmath.structures.NDStructure +import scientifik.kmath.structures.get +import scientifik.kmath.structures.ndStructure + +/** + * Basic implementation of Matrix space based on [NDStructure] + */ +class StructureMatrixContext>( + override val ring: R, + private val bufferFactory: BufferFactory +) : MatrixContext { + + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix { + val structure = + ndStructure(intArrayOf(rows, columns), bufferFactory) { index -> initializer(index[0], index[1]) } + return StructureMatrix(structure) + } + + override fun point(size: Int, initializer: (Int) -> T): Point = bufferFactory(size, initializer) +} + +data class StructureMatrix( + val structure: NDStructure, + override val features: Set = emptySet() +) : Matrix { + + init { + if (structure.shape.size != 2) { + error("Dimension mismatch for matrix structure") + } + } + + override val rowNum: Int + get() = structure.shape[0] + override val colNum: Int + get() = structure.shape[1] + + + override val shape: IntArray get() = structure.shape + + override fun get(index: IntArray): T = structure[index] + + override fun get(i: Int, j: Int): T = structure[i, j] + + override fun elements(): Sequence> = structure.elements() +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt new file mode 100644 index 000000000..59c89b197 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -0,0 +1,10 @@ +package scientifik.kmath.linear + +class VirtualMatrix( + override val rowNum: Int, + override val colNum: Int, + override val features: Set = emptySet(), + val generator: (i: Int, j: Int) -> T +) : Matrix { + override fun get(i: Int, j: Int): T = generator(i, j) +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 1534f4c82..8aa3a0e05 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -33,8 +33,8 @@ interface Space { operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble()) operator fun Number.times(b: T) = b * this - fun Iterable.sum(): T = fold(zero) { left, right -> add(left,right) } + fun Iterable.sum(): T = fold(zero) { left, right -> add(left,right) } fun Sequence.sum(): T = fold(zero) { left, right -> add(left, right) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index f7c04bdbe..a8a9d7ed3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -95,6 +95,8 @@ inline class ListBuffer(private val list: List) : Buffer { override fun iterator(): Iterator = list.iterator() } +fun List.asBuffer() = ListBuffer(this) + inline class MutableListBuffer(private val list: MutableList) : MutableBuffer { override val size: Int @@ -191,6 +193,25 @@ inline class ReadOnlyBuffer(private val buffer: MutableBuffer) : Buffer override fun iterator(): Iterator = buffer.iterator() } +/** + * A buffer with content calculated on-demand. The calculated contect is not stored, so it is recalculated on each call. + * Useful when one needs single element from the buffer. + */ +class VirtualBuffer(override val size: Int, private val generator: (Int) -> T) : Buffer { + override fun get(index: Int): T = generator(index) + + override fun iterator(): Iterator = (0 until size).asSequence().map(generator).iterator() + + override fun contentEquals(other: Buffer<*>): Boolean { + return if (other is VirtualBuffer) { + this.size == other.size && this.generator == other.generator + } else { + super.contentEquals(other) + } + } + +} + /** * Convert this buffer to read-only buffer */ diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt index d02cb2c6f..fbbe06b4f 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt @@ -22,7 +22,7 @@ class MatrixTest { @Test fun testTranspose() { - val matrix = MatrixSpace.real(3, 3).one + val matrix = MatrixContext.real(3, 3).one val transposed = matrix.transpose() assertEquals(matrix.context, transposed.context) assertEquals((matrix as StructureMatrix).structure, (transposed as StructureMatrix).structure) diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index 569b9e450..f167f2f8f 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -6,7 +6,7 @@ import kotlin.test.assertEquals class RealLUSolverTest { @Test fun testInvertOne() { - val matrix = MatrixSpace.real(2, 2).one + val matrix = MatrixContext.real(2, 2).one val inverted = RealLUSolver.inverse(matrix) assertEquals(matrix, inverted) }