diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index 82f63dd40..e0ec49333 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -66,11 +66,10 @@ public class DoubleLinearOpsTensorAlgebra : } - override fun DoubleTensor.cholesky(): DoubleTensor { - checkSymmetric(this) + public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) //TODO("Andrei the det routine has bugs") - //checkPositiveDefinite(this) + //checkPositiveDefinite(this, epsilon) val n = shape.last() val lTensor = zeroesLike() @@ -81,6 +80,8 @@ public class DoubleLinearOpsTensorAlgebra : return lTensor } + override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6) + override fun DoubleTensor.qr(): Pair { checkSquareMatrix(shape) val qTensor = zeroesLike() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index 8dbf9eb81..730b6ed9a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -63,7 +63,9 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps "Tensor is not symmetric about the last 2 dimensions at precision $epsilon" } -internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit { +internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite( + tensor: DoubleTensor, epsilon: Double = 1e-6): Unit { + checkSymmetric(tensor, epsilon) for( mat in tensor.matrixSequence()) check(mat.asTensor().detLU().value() > 0.0){ "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"