Cholesky with precision in client API
This commit is contained in:
parent
75783bcb03
commit
a2d41d5e73
@ -66,11 +66,10 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
|
||||||
checkSymmetric(this)
|
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
//TODO("Andrei the det routine has bugs")
|
//TODO("Andrei the det routine has bugs")
|
||||||
//checkPositiveDefinite(this)
|
//checkPositiveDefinite(this, epsilon)
|
||||||
|
|
||||||
val n = shape.last()
|
val n = shape.last()
|
||||||
val lTensor = zeroesLike()
|
val lTensor = zeroesLike()
|
||||||
@ -81,6 +80,8 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
return lTensor
|
return lTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6)
|
||||||
|
|
||||||
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
val qTensor = zeroesLike()
|
val qTensor = zeroesLike()
|
||||||
|
@ -63,7 +63,9 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
|
|||||||
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
|
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(
|
||||||
|
tensor: DoubleTensor, epsilon: Double = 1e-6): Unit {
|
||||||
|
checkSymmetric(tensor, epsilon)
|
||||||
for( mat in tensor.matrixSequence())
|
for( mat in tensor.matrixSequence())
|
||||||
check(mat.asTensor().detLU().value() > 0.0){
|
check(mat.asTensor().detLU().value() > 0.0){
|
||||||
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||||
|
Loading…
Reference in New Issue
Block a user