Cholesky with precision in client API
This commit is contained in:
parent
75783bcb03
commit
a2d41d5e73
@ -66,11 +66,10 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||
checkSymmetric(this)
|
||||
public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
|
||||
checkSquareMatrix(shape)
|
||||
//TODO("Andrei the det routine has bugs")
|
||||
//checkPositiveDefinite(this)
|
||||
//checkPositiveDefinite(this, epsilon)
|
||||
|
||||
val n = shape.last()
|
||||
val lTensor = zeroesLike()
|
||||
@ -81,6 +80,8 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
return lTensor
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6)
|
||||
|
||||
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(shape)
|
||||
val qTensor = zeroesLike()
|
||||
|
@ -63,7 +63,9 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
|
||||
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||
}
|
||||
|
||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
|
||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(
|
||||
tensor: DoubleTensor, epsilon: Double = 1e-6): Unit {
|
||||
checkSymmetric(tensor, epsilon)
|
||||
for( mat in tensor.matrixSequence())
|
||||
check(mat.asTensor().detLU().value() > 0.0){
|
||||
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||
|
Loading…
Reference in New Issue
Block a user