forked from kscience/kmath
Added equality to matrices
This commit is contained in:
parent
cedb8a816e
commit
271d7886df
@ -152,25 +152,16 @@ interface Matrix<T : Any> : NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> {
|
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j->
|
||||||
return object : Matrix<T> {
|
if (i == j) elementContext.one else elementContext.zero
|
||||||
override val rowNum: Int get() = rows
|
|
||||||
override val colNum: Int get() = columns
|
|
||||||
override val features: Set<MatrixFeature> get() = setOf(DiagonalFeature, UnitFeature)
|
|
||||||
override fun get(i: Int, j: Int): T = if (i == j) elementContext.one else elementContext.zero
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A virtual matrix of zeroes
|
* A virtual matrix of zeroes
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> {
|
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> = VirtualMatrix<T>(rows,columns){i, j->
|
||||||
return object : Matrix<T> {
|
elementContext.zero
|
||||||
override val rowNum: Int get() = rows
|
|
||||||
override val colNum: Int get() = columns
|
|
||||||
override val features: Set<MatrixFeature> get() = setOf(ZeroFeature)
|
|
||||||
override fun get(i: Int, j: Int): T = elementContext.zero
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ class StructureMatrixContext<T : Any, R : Ring<T>>(
|
|||||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class StructureMatrix<T : Any>(
|
class StructureMatrix<T : Any>(
|
||||||
val structure: NDStructure<T>,
|
val structure: NDStructure<T>,
|
||||||
override val features: Set<MatrixFeature> = emptySet()
|
override val features: Set<MatrixFeature> = emptySet()
|
||||||
) : Matrix<T> {
|
) : Matrix<T> {
|
||||||
@ -47,4 +47,21 @@ data class StructureMatrix<T : Any>(
|
|||||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||||
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
return when (other) {
|
||||||
|
is StructureMatrix<*> -> return this.structure == other.structure
|
||||||
|
is Matrix<*> -> elements().all { (index, value) -> value == other[index] }
|
||||||
|
else -> false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = structure.hashCode()
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
@ -7,4 +7,24 @@ class VirtualMatrix<T : Any>(
|
|||||||
val generator: (i: Int, j: Int) -> T
|
val generator: (i: Int, j: Int) -> T
|
||||||
) : Matrix<T> {
|
) : Matrix<T> {
|
||||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
override fun get(i: Int, j: Int): T = generator(i, j)
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other !is Matrix<*>) return false
|
||||||
|
|
||||||
|
if (rowNum != other.rowNum) return false
|
||||||
|
if (colNum != other.colNum) return false
|
||||||
|
|
||||||
|
return elements().all { (index, value) -> value == other[index] }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = rowNum
|
||||||
|
result = 31 * result + colNum
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
result = 31 * result + generator.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
@ -14,7 +14,7 @@ class RealLUSolverTest {
|
|||||||
// @Test
|
// @Test
|
||||||
// fun testInvert() {
|
// fun testInvert() {
|
||||||
// val matrix = realMatrix(2,2){}
|
// val matrix = realMatrix(2,2){}
|
||||||
// val inverted = RealLUSolver.inverse(matrix)
|
// val inverted = LUSolver.real.inverse(matrix)
|
||||||
// assertTrue { Matrix.equals(matrix,inverted) }
|
// assertEquals(matrix, inverted)
|
||||||
// }
|
// }
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user