From c7669d4fba20ccd856918504d4836d2619961927 Mon Sep 17 00:00:00 2001 From: Andrei Kislitsyn Date: Wed, 14 Apr 2021 23:05:54 +0300 Subject: [PATCH] fix det --- .../core/DoubleLinearOpsTensorAlgebra.kt | 22 +++++++++---- .../kscience/kmath/tensors/core/checks.kt | 10 ++++++ .../kscience/kmath/tensors/core/linutils.kt | 11 +++++-- .../core/TestDoubleLinearOpsAlgebra.kt | 17 ++++++++-- kmath-tensors/src/jvmMain/kotlin/andMain.kt | 32 +++++++++++++++++++ 5 files changed, 81 insertions(+), 11 deletions(-) create mode 100644 kmath-tensors/src/jvmMain/kotlin/andMain.kt diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index 82f63dd40..59678947a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -14,8 +14,7 @@ public class DoubleLinearOpsTensorAlgebra : override fun DoubleTensor.det(): DoubleTensor = detLU() - override fun DoubleTensor.lu(): Pair { - + internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair { checkSquareMatrix(shape) val luTensor = copy() @@ -31,10 +30,22 @@ public class DoubleLinearOpsTensorAlgebra : ) for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence())) - luHelper(lu.as2D(), pivots.as1D(), m) + try { + luHelper(lu.as2D(), pivots.as1D(), m) + } catch (e: RuntimeException) { + if (forDet) { + lu.as2D()[intArrayOf(0, 0)] = 0.0 + } else { + throw IllegalStateException("LUP decomposition can't be performed") + } + } + return Pair(luTensor, pivotsTensor) + } + override fun DoubleTensor.lu(): Pair { + return luForDet(false) } override fun luPivot( @@ -69,8 +80,7 @@ public class DoubleLinearOpsTensorAlgebra : override fun DoubleTensor.cholesky(): DoubleTensor { checkSymmetric(this) checkSquareMatrix(shape) - //TODO("Andrei the det routine has bugs") - //checkPositiveDefinite(this) + checkPositiveDefinite(this) val n = shape.last() val lTensor = zeroesLike() @@ -134,7 +144,7 @@ public class DoubleLinearOpsTensorAlgebra : } public fun DoubleTensor.detLU(): DoubleTensor { - val (luTensor, pivotsTensor) = lu() + val (luTensor, pivotsTensor) = luForDet(forDet = true) val n = shape.size val detTensorShape = IntArray(n - 1) { i -> shape[i] } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index 8dbf9eb81..c61759da2 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -2,6 +2,7 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.tensors.TensorAlgebra import space.kscience.kmath.tensors.TensorStructure +import kotlin.math.abs internal inline fun , @@ -68,4 +69,13 @@ internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: D check(mat.asTensor().detLU().value() > 0.0){ "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" } +} + +internal inline fun DoubleLinearOpsTensorAlgebra.checkNonSingularMatrix(tensor: DoubleTensor): Unit { + for( mat in tensor.matrixSequence()) { + val detTensor = mat.asTensor().detLU() + check(!(detTensor.eq(detTensor.zeroesLike()))){ + "Tensor contains matrices which are singular" + } + } } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt index b3cfc1092..f2c9d8c76 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt @@ -62,10 +62,10 @@ internal inline fun dotHelper( } internal inline fun luHelper(lu: MutableStructure2D, pivots: MutableStructure1D, m: Int) { - for (row in 0 until m) pivots[row] = row + for (row in 0..m) pivots[row] = row for (i in 0 until m) { - var maxVal = -1.0 + var maxVal = 0.0 var maxInd = i for (k in i until m) { @@ -76,7 +76,9 @@ internal inline fun luHelper(lu: MutableStructure2D, pivots: MutableStru } } - //todo check singularity + if (abs(maxVal) < 1e-9) { + throw RuntimeException() + } if (maxInd != i) { @@ -158,6 +160,9 @@ internal inline fun choleskyHelper( internal inline fun luMatrixDet(luTensor: MutableStructure2D, pivotsTensor: MutableStructure1D): Double { val lu = luTensor.as2D() val pivots = pivotsTensor.as1D() + if (lu[0, 0] == 0.0) { + return 0.0 + } val m = lu.shape[0] val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0 return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index 37caf88fe..75ff12355 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -36,6 +36,7 @@ class TestDoubleLinearOpsTensorAlgebra { @Test fun testDet() = DoubleLinearOpsTensorAlgebra { + val expectedValue = 0.019827417 val m = fromArray( intArrayOf(3, 3), doubleArrayOf( 2.1843, 1.4391, -0.4845, @@ -43,8 +44,20 @@ class TestDoubleLinearOpsTensorAlgebra { -0.4845, 0.4055, 0.7519 ) ) - println(m.det().value()) - println(0.0197) //expected value + + assertTrue { abs(m.det().value() - expectedValue) < 1e-5} + } + + @Test + fun testDetSingle() = DoubleLinearOpsTensorAlgebra { + val expectedValue = 48.151623 + val m = fromArray( + intArrayOf(1, 1), doubleArrayOf( + expectedValue + ) + ) + + assertTrue { abs(m.det().value() - expectedValue) < 1e-5} } @Test diff --git a/kmath-tensors/src/jvmMain/kotlin/andMain.kt b/kmath-tensors/src/jvmMain/kotlin/andMain.kt new file mode 100644 index 000000000..68facf467 --- /dev/null +++ b/kmath-tensors/src/jvmMain/kotlin/andMain.kt @@ -0,0 +1,32 @@ +import space.kscience.kmath.nd.* +import space.kscience.kmath.tensors.core.DoubleLinearOpsTensorAlgebra +import space.kscience.kmath.tensors.core.DoubleTensorAlgebra +import space.kscience.kmath.tensors.core.array +import kotlin.math.abs +import kotlin.math.sqrt + + +fun main() { + + DoubleTensorAlgebra { + val tensor = fromArray( + intArrayOf(2, 2, 2), + doubleArrayOf( + 1.0, 3.0, + 1.0, 2.0, + 1.5, 1.0, + 10.0, 2.0 + ) + ) + val tensor2 = fromArray( + intArrayOf(2, 2), + doubleArrayOf( + 0.0, 0.0, + 0.0, 0.0 + ) + ) + DoubleLinearOpsTensorAlgebra { + println(tensor2.det().value()) + } + } +}