Cholesky decomp tests and checks - det to be fixed

This commit is contained in:
Roland Grinis 2021-04-09 10:53:36 +01:00
parent a692412cff
commit 1e8da7a87b
3 changed files with 23 additions and 2 deletions

View File

@ -67,9 +67,10 @@ public class DoubleLinearOpsTensorAlgebra :
} }
override fun DoubleTensor.cholesky(): DoubleTensor { override fun DoubleTensor.cholesky(): DoubleTensor {
//positive definite check
checkSymmetric(this) checkSymmetric(this)
checkSquareMatrix(shape) checkSquareMatrix(shape)
//TODO("Andrei the det routine has bugs")
//checkPositiveDefinite(this)
val n = shape.last() val n = shape.last()
val lTensor = zeroesLike() val lTensor = zeroesLike()

View File

@ -62,3 +62,10 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
check(tensor.eq(tensor.transpose(), epsilon)) { check(tensor.eq(tensor.transpose(), epsilon)) {
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon" "Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
} }
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
for( mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0){
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
}
}

View File

@ -112,6 +112,19 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue((p dot tensor).eq(l dot u)) assertTrue((p dot tensor).eq(l dot u))
} }
@Test
fun testCholesky() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(intArrayOf(2, 5, 5), 0)
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
)
//checkPositiveDefinite(sigma) sigma must be positive definite
val low = sigma.cholesky()
val sigmChol = low dot low.transpose()
assertTrue(sigma.eq(sigmChol))
}
@Test @Test
fun testSVD1D() = DoubleLinearOpsTensorAlgebra { fun testSVD1D() = DoubleLinearOpsTensorAlgebra {
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))