diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index 82f63dd40..027fcf5d0 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -14,8 +14,7 @@ public class DoubleLinearOpsTensorAlgebra : override fun DoubleTensor.det(): DoubleTensor = detLU() - override fun DoubleTensor.lu(): Pair { - + internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair { checkSquareMatrix(shape) val luTensor = copy() @@ -31,10 +30,22 @@ public class DoubleLinearOpsTensorAlgebra : ) for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence())) - luHelper(lu.as2D(), pivots.as1D(), m) + try { + luHelper(lu.as2D(), pivots.as1D(), m) + } catch (e: RuntimeException) { + if (forDet) { + lu.as2D()[intArrayOf(0, 0)] = 0.0 + } else { + throw IllegalStateException("LUP decomposition can't be performed") + } + } + return Pair(luTensor, pivotsTensor) + } + override fun DoubleTensor.lu(): Pair { + return luForDet(false) } override fun luPivot( @@ -66,11 +77,10 @@ public class DoubleLinearOpsTensorAlgebra : } - override fun DoubleTensor.cholesky(): DoubleTensor { - checkSymmetric(this) + public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) - //TODO("Andrei the det routine has bugs") - //checkPositiveDefinite(this) + checkPositiveDefinite(this) + //checkPositiveDefinite(this, epsilon) val n = shape.last() val lTensor = zeroesLike() @@ -81,6 +91,8 @@ public class DoubleLinearOpsTensorAlgebra : return lTensor } + override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6) + override fun DoubleTensor.qr(): Pair { checkSquareMatrix(shape) val qTensor = zeroesLike() @@ -134,7 +146,7 @@ public class DoubleLinearOpsTensorAlgebra : } public fun DoubleTensor.detLU(): DoubleTensor { - val (luTensor, pivotsTensor) = lu() + val (luTensor, pivotsTensor) = luForDet(forDet = true) val n = shape.size val detTensorShape = IntArray(n - 1) { i -> shape[i] } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index 8dbf9eb81..45a71d1bc 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -2,6 +2,7 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.tensors.TensorAlgebra import space.kscience.kmath.tensors.TensorStructure +import kotlin.math.abs internal inline fun , @@ -63,9 +64,20 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps "Tensor is not symmetric about the last 2 dimensions at precision $epsilon" } -internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit { +internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite( + tensor: DoubleTensor, epsilon: Double = 1e-6): Unit { + checkSymmetric(tensor, epsilon) for( mat in tensor.matrixSequence()) check(mat.asTensor().detLU().value() > 0.0){ "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" } +} + +internal inline fun DoubleLinearOpsTensorAlgebra.checkNonSingularMatrix(tensor: DoubleTensor): Unit { + for( mat in tensor.matrixSequence()) { + val detTensor = mat.asTensor().detLU() + check(!(detTensor.eq(detTensor.zeroesLike()))){ + "Tensor contains matrices which are singular" + } + } } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt index b3cfc1092..f2c9d8c76 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt @@ -62,10 +62,10 @@ internal inline fun dotHelper( } internal inline fun luHelper(lu: MutableStructure2D, pivots: MutableStructure1D, m: Int) { - for (row in 0 until m) pivots[row] = row + for (row in 0..m) pivots[row] = row for (i in 0 until m) { - var maxVal = -1.0 + var maxVal = 0.0 var maxInd = i for (k in i until m) { @@ -76,7 +76,9 @@ internal inline fun luHelper(lu: MutableStructure2D, pivots: MutableStru } } - //todo check singularity + if (abs(maxVal) < 1e-9) { + throw RuntimeException() + } if (maxInd != i) { @@ -158,6 +160,9 @@ internal inline fun choleskyHelper( internal inline fun luMatrixDet(luTensor: MutableStructure2D, pivotsTensor: MutableStructure1D): Double { val lu = luTensor.as2D() val pivots = pivotsTensor.as1D() + if (lu[0, 0] == 0.0) { + return 0.0 + } val m = lu.shape[0] val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0 return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index 37caf88fe..75ff12355 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -36,6 +36,7 @@ class TestDoubleLinearOpsTensorAlgebra { @Test fun testDet() = DoubleLinearOpsTensorAlgebra { + val expectedValue = 0.019827417 val m = fromArray( intArrayOf(3, 3), doubleArrayOf( 2.1843, 1.4391, -0.4845, @@ -43,8 +44,20 @@ class TestDoubleLinearOpsTensorAlgebra { -0.4845, 0.4055, 0.7519 ) ) - println(m.det().value()) - println(0.0197) //expected value + + assertTrue { abs(m.det().value() - expectedValue) < 1e-5} + } + + @Test + fun testDetSingle() = DoubleLinearOpsTensorAlgebra { + val expectedValue = 48.151623 + val m = fromArray( + intArrayOf(1, 1), doubleArrayOf( + expectedValue + ) + ) + + assertTrue { abs(m.det().value() - expectedValue) < 1e-5} } @Test