Cholesky with precision in client API

This commit is contained in:
Roland Grinis 2021-04-14 20:30:42 +01:00
parent 75783bcb03
commit a2d41d5e73
2 changed files with 7 additions and 4 deletions

View File

@ -66,11 +66,10 @@ public class DoubleLinearOpsTensorAlgebra :
} }
override fun DoubleTensor.cholesky(): DoubleTensor { public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
checkSymmetric(this)
checkSquareMatrix(shape) checkSquareMatrix(shape)
//TODO("Andrei the det routine has bugs") //TODO("Andrei the det routine has bugs")
//checkPositiveDefinite(this) //checkPositiveDefinite(this, epsilon)
val n = shape.last() val n = shape.last()
val lTensor = zeroesLike() val lTensor = zeroesLike()
@ -81,6 +80,8 @@ public class DoubleLinearOpsTensorAlgebra :
return lTensor return lTensor
} }
override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6)
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> { override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape) checkSquareMatrix(shape)
val qTensor = zeroesLike() val qTensor = zeroesLike()

View File

@ -63,7 +63,9 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon" "Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
} }
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit { internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(
tensor: DoubleTensor, epsilon: Double = 1e-6): Unit {
checkSymmetric(tensor, epsilon)
for( mat in tensor.matrixSequence()) for( mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0){ check(mat.asTensor().detLU().value() > 0.0){
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"