Add proper equality check for EJML matrices
This commit is contained in:
parent
d0c9d97706
commit
1c7bd05c58
@ -43,7 +43,7 @@ public class BufferMatrix<T : Any>(
|
||||
if (this === other) return true
|
||||
|
||||
return when (other) {
|
||||
is NDStructure<*> -> NDStructure.equals(this, other)
|
||||
is NDStructure<*> -> NDStructure.contentEquals(this, other)
|
||||
else -> false
|
||||
}
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ public interface NDStructure<T> {
|
||||
/**
|
||||
* Indicates whether some [NDStructure] is equal to another one.
|
||||
*/
|
||||
public fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||
public fun contentEquals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean {
|
||||
if (st1 === st2) return true
|
||||
|
||||
// fast comparison of buffers if possible
|
||||
@ -275,7 +275,7 @@ public abstract class NDBuffer<T> : NDStructure<T> {
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to this[it] }
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
|
@ -24,7 +24,7 @@ public class LazyNDStructure<T>(
|
||||
}
|
||||
|
||||
public override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
|
@ -3,6 +3,7 @@ package kscience.kmath.ejml
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.NDStructure
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
@ -15,7 +16,7 @@ import kotlin.reflect.cast
|
||||
* @property origin the underlying [SimpleMatrix].
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public inline class EjmlMatrix(
|
||||
public class EjmlMatrix(
|
||||
public val origin: SimpleMatrix,
|
||||
) : Matrix<Double> {
|
||||
public override val rowNum: Int get() = origin.numRows()
|
||||
@ -49,7 +50,7 @@ public inline class EjmlMatrix(
|
||||
override val q: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) }
|
||||
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) }
|
||||
}
|
||||
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
|
||||
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
|
||||
override val l: Matrix<Double> by lazy {
|
||||
val cholesky =
|
||||
DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) }
|
||||
@ -57,7 +58,7 @@ public inline class EjmlMatrix(
|
||||
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
|
||||
}
|
||||
}
|
||||
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
|
||||
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
|
||||
private val lup by lazy {
|
||||
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) }
|
||||
}
|
||||
@ -73,7 +74,17 @@ public inline class EjmlMatrix(
|
||||
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
|
||||
}
|
||||
else -> null
|
||||
}?.let{type.cast(it)}
|
||||
}?.let { type.cast(it) }
|
||||
|
||||
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is Matrix<*>) return false
|
||||
return NDStructure.contentEquals(this, other)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = origin.hashCode()
|
||||
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user