diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt index aed4f4bb1..808b2768b 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt @@ -4,11 +4,16 @@ import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.RealMatrix import org.apache.commons.math3.linear.RealVector -inline class CMMatrix(val origin: RealMatrix) : Matrix { +class CMMatrix(val origin: RealMatrix, features: Set? = null) : Matrix { override val rowNum: Int get() = origin.rowDimension override val colNum: Int get() = origin.columnDimension - override val features: Set get() = emptySet() + override val features: Set = features ?: sequence { + if(origin is DiagonalMatrix) yield(DiagonalFeature) + }.toSet() + + override fun suggestFeature(vararg features: MatrixFeature) = + CMMatrix(origin, this.features + features) override fun get(i: Int, j: Int): Double = origin.getEntry(i, j) } @@ -23,7 +28,7 @@ fun Matrix.toCM(): CMMatrix = if (this is CMMatrix) { fun RealMatrix.toMatrix() = CMMatrix(this) -inline class CMVector(val origin: RealVector) : Point { +class CMVector(val origin: RealVector) : Point { override val size: Int get() = origin.dimension override fun get(index: Int): Double = origin.getEntry(index) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 5b65821db..3752f6db5 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -35,6 +35,9 @@ class BufferMatrix( override val shape: IntArray get() = intArrayOf(rowNum, colNum) + override fun suggestFeature(vararg features: MatrixFeature) = + BufferMatrix(rowNum, colNum, buffer, this.features + features) + override fun get(index: IntArray): T = get(index[0], index[1]) override fun get(i: Int, j: Int): T = buffer[i * colNum + j] diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index b2161baf2..85ddb8786 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -8,20 +8,19 @@ import scientifik.kmath.structures.MutableBufferFactory import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.get - class LUPDecomposition>( private val elementContext: Ring, internal val lu: NDStructure, val pivot: IntArray, private val even: Boolean -) : DeterminantFeature { +) : LUPDecompositionFeature, DeterminantFeature { /** * Returns the matrix L of the decomposition. * * L is a lower-triangular matrix with [Ring.one] in diagonal */ - val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> when { j < i -> lu[i, j] j == i -> elementContext.one @@ -35,7 +34,7 @@ class LUPDecomposition>( * * U is an upper-triangular matrix including the diagonal */ - val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> if (j >= i) lu[i, j] else elementContext.zero } @@ -46,7 +45,7 @@ class LUPDecomposition>( * P is a sparse matrix with exactly one element set to [Ring.one] in * each row and each column, all other elements being set to [Ring.zero]. */ - val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j == pivot[i]) elementContext.one else elementContext.zero } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 68a81f077..533e77d61 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -116,6 +116,14 @@ interface Matrix : NDStructure { val features: Set + /** + * Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure. + * + * The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to + * add only those features that are valid. + */ + fun suggestFeature(vararg features: MatrixFeature): Matrix + operator fun get(i: Int, j: Int): T override fun get(index: IntArray): T = get(index[0], index[1]) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt index d9f4e58ca..6b45a14b1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt @@ -12,7 +12,7 @@ interface MatrixFeature object DiagonalFeature : MatrixFeature /** - * Matix with this feature has all zero elements + * Matrix with this feature has all zero elements */ object ZeroFeature : MatrixFeature @@ -21,10 +21,42 @@ object ZeroFeature : MatrixFeature */ object UnitFeature : MatrixFeature +/** + * Inverted matrix feature + */ interface InverseMatrixFeature : MatrixFeature { val inverse: Matrix } +/** + * A determinant container + */ interface DeterminantFeature : MatrixFeature { val determinant: T -} \ No newline at end of file +} + +@Suppress("FunctionName") +fun DeterminantFeature(determinant: T) = object: DeterminantFeature{ + override val determinant: T = determinant +} + +/** + * Lower triangular matrix + */ +object LFeature: MatrixFeature + +/** + * Upper triangular feature + */ +object UFeature: MatrixFeature + +/** + * TODO add documentation + */ +interface LUPDecompositionFeature : MatrixFeature { + val l: Matrix + val u: Matrix + val p: Matrix +} + +//TODO add sparse matrix feature \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt index 98655ad48..1bab52902 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -8,6 +8,9 @@ class VirtualMatrix( ) : Matrix { override fun get(i: Int, j: Int): T = generator(i, j) + override fun suggestFeature(vararg features: MatrixFeature) = + VirtualMatrix(rowNum, colNum, this.features + features, generator) + override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is Matrix<*>) return false diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt index 5c887f343..09e93483d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ShortNDRing.kt @@ -15,6 +15,7 @@ class ShortNDRing(override val shape: IntArray) : override val zero by lazy { produce { ShortRing.zero } } override val one by lazy { produce { ShortRing.one } } + @Suppress("OVERRIDE_BY_INLINE") override inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer = ShortBuffer(ShortArray(size) { initializer(it) }) diff --git a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt index 3b7894d18..343d5b8b9 100644 --- a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt +++ b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt @@ -48,10 +48,25 @@ class KomaMatrixContext(val factory: MatrixFactory(val origin: koma.matrix.Matrix) : Matrix { +class KomaMatrix(val origin: koma.matrix.Matrix, features: Set? = null) : + Matrix { override val rowNum: Int get() = origin.numRows() override val colNum: Int get() = origin.numCols() - override val features: Set get() = emptySet() + + override val features: Set = features ?: setOf( + object : DeterminantFeature { + override val determinant: T get() = origin.det() + }, + object : LUPDecompositionFeature { + private val lup by lazy { origin.LU() } + override val l: Matrix get() = KomaMatrix(lup.second) + override val u: Matrix get() = KomaMatrix(lup.third) + override val p: Matrix get() = KomaMatrix(lup.first) + } + ) + + override fun suggestFeature(vararg features: MatrixFeature): Matrix = + KomaMatrix(this.origin, this.features + features) override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) }