diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 8cd9a6b8f..d9a14864b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -152,25 +152,16 @@ interface Matrix : NDStructure { /** * Diagonal matrix of ones. The matrix is virtual no actual matrix is created */ -fun > MatrixContext.one(rows: Int, columns: Int): Matrix { - return object : Matrix { - override val rowNum: Int get() = rows - override val colNum: Int get() = columns - override val features: Set get() = setOf(DiagonalFeature, UnitFeature) - override fun get(i: Int, j: Int): T = if (i == j) elementContext.one else elementContext.zero - } +fun > MatrixContext.one(rows: Int, columns: Int): Matrix = VirtualMatrix(rows,columns){i, j-> + if (i == j) elementContext.one else elementContext.zero } + /** * A virtual matrix of zeroes */ -fun > MatrixContext.zero(rows: Int, columns: Int): Matrix { - return object : Matrix { - override val rowNum: Int get() = rows - override val colNum: Int get() = columns - override val features: Set get() = setOf(ZeroFeature) - override fun get(i: Int, j: Int): T = elementContext.zero - } +fun > MatrixContext.zero(rows: Int, columns: Int): Matrix = VirtualMatrix(rows,columns){i, j-> + elementContext.zero } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt index 1109297be..aa061fae2 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/StructureMatrix.kt @@ -23,7 +23,7 @@ class StructureMatrixContext>( override fun point(size: Int, initializer: (Int) -> T): Point = bufferFactory(size, initializer) } -data class StructureMatrix( +class StructureMatrix( val structure: NDStructure, override val features: Set = emptySet() ) : Matrix { @@ -47,4 +47,21 @@ data class StructureMatrix( override fun get(i: Int, j: Int): T = structure[i, j] override fun elements(): Sequence> = structure.elements() + + override fun equals(other: Any?): Boolean { + if (this === other) return true + return when (other) { + is StructureMatrix<*> -> return this.structure == other.structure + is Matrix<*> -> elements().all { (index, value) -> value == other[index] } + else -> false + } + } + + override fun hashCode(): Int { + var result = structure.hashCode() + result = 31 * result + features.hashCode() + return result + } + + } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt index 59c89b197..cb7b0a08e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -7,4 +7,24 @@ class VirtualMatrix( val generator: (i: Int, j: Int) -> T ) : Matrix { override fun get(i: Int, j: Int): T = generator(i, j) + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is Matrix<*>) return false + + if (rowNum != other.rowNum) return false + if (colNum != other.colNum) return false + + return elements().all { (index, value) -> value == other[index] } + } + + override fun hashCode(): Int { + var result = rowNum + result = 31 * result + colNum + result = 31 * result + features.hashCode() + result = 31 * result + generator.hashCode() + return result + } + + } \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index daf749e8a..54a308761 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -14,7 +14,7 @@ class RealLUSolverTest { // @Test // fun testInvert() { // val matrix = realMatrix(2,2){} -// val inverted = RealLUSolver.inverse(matrix) -// assertTrue { Matrix.equals(matrix,inverted) } +// val inverted = LUSolver.real.inverse(matrix) +// assertEquals(matrix, inverted) // } } \ No newline at end of file