diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index 916361abe..43502e79a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -67,7 +67,7 @@ public class DoubleLinearOpsTensorAlgebra : } override fun DoubleTensor.cholesky(): DoubleTensor { - // todo checks + checkSymmetric(this) checkSquareMatrix(shape) val n = shape.last() @@ -113,7 +113,9 @@ public class DoubleLinearOpsTensorAlgebra : } override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair { - TODO() + checkSymmetric(this) + val svd = this.svd() + TODO("U might have some columns negative to V") } public fun DoubleTensor.detLU(): DoubleTensor { diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index f994324ff..1f70eb300 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -58,3 +58,7 @@ internal inline fun , } } +internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor): Unit = + check(tensor.eq(tensor.transpose())){ + "Tensor is not symmetric about the last 2 dimensions" + } \ No newline at end of file diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index df896bace..cc127b5a7 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -140,9 +140,9 @@ class TestDoubleLinearOpsTensorAlgebra { @Ignore fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra { val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0) - val tensorSigma = tensor + tensor.transpose(1, 2) + val tensorSigma = tensor + tensor.transpose() val (tensorS, tensorV) = tensorSigma.symEig() - val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(1, 2)) + val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) assertTrue(tensorSigma.eq(tensorSigmaCalc)) }