diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index c26046e37..4d6d14764 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -67,9 +67,10 @@ public class DoubleLinearOpsTensorAlgebra : } override fun DoubleTensor.cholesky(): DoubleTensor { - //positive definite check checkSymmetric(this) checkSquareMatrix(shape) + //TODO("Andrei the det routine has bugs") + //checkPositiveDefinite(this) val n = shape.last() val lTensor = zeroesLike() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index 1dde2ea56..8dbf9eb81 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -61,4 +61,11 @@ internal inline fun , internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, epsilon: Double = 1e-6): Unit = check(tensor.eq(tensor.transpose(), epsilon)) { "Tensor is not symmetric about the last 2 dimensions at precision $epsilon" - } \ No newline at end of file + } + +internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit { + for( mat in tensor.matrixSequence()) + check(mat.asTensor().detLU().value() > 0.0){ + "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" + } +} \ No newline at end of file diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index 0997a9b86..2914b216b 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -112,6 +112,19 @@ class TestDoubleLinearOpsTensorAlgebra { assertTrue((p dot tensor).eq(l dot u)) } + @Test + fun testCholesky() = DoubleLinearOpsTensorAlgebra { + val tensor = randNormal(intArrayOf(2, 5, 5), 0) + val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding( + fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 }) + ) + //checkPositiveDefinite(sigma) sigma must be positive definite + val low = sigma.cholesky() + val sigmChol = low dot low.transpose() + assertTrue(sigma.eq(sigmChol)) + + } + @Test fun testSVD1D() = DoubleLinearOpsTensorAlgebra { val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))