From a692412cff00241031258e37e1f7c2ad151518f4 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Fri, 9 Apr 2021 10:08:55 +0100 Subject: [PATCH] Safer cleaner for symeig --- .../tensors/core/DoubleLinearOpsTensorAlgebra.kt | 10 ++++++---- .../space/kscience/kmath/tensors/core/linutils.kt | 12 ++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index 95b668917..c26046e37 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -3,10 +3,8 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.tensors.LinearOpsTensorAlgebra import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D -import space.kscience.kmath.structures.toList -import kotlin.math.abs import kotlin.math.min -import kotlin.math.sign + public class DoubleLinearOpsTensorAlgebra : LinearOpsTensorAlgebra, @@ -125,7 +123,11 @@ public class DoubleLinearOpsTensorAlgebra : checkSymmetric(this, epsilon) val (u, s, v) = this.svd(epsilon) val shp = s.shape + intArrayOf(1) - val utv = (u.transpose() dot v).map { if (abs(it) < 0.9) 0.0 else sign(it) } + val utv = u.transpose() dot v + val n = s.shape.last() + for( matrix in utv.matrixSequence()) + cleanSymHelper(matrix.as2D(),n) + val eig = (utv dot s.view(shp)).view(s.shape) return Pair(eig, v) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt index 685c16a1b..b3cfc1092 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt @@ -6,6 +6,7 @@ import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D import kotlin.math.abs import kotlin.math.min +import kotlin.math.sign import kotlin.math.sqrt @@ -294,3 +295,14 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper( matrixV.buffer.array()[matrixV.bufferStart + i] = vBuffer[i] } } + +internal inline fun cleanSymHelper(matrix: MutableStructure2D, n: Int): Unit { + for (i in 0 until n) + for (j in 0 until n) { + if (i == j) { + matrix[i, j] = sign(matrix[i, j]) + } else { + matrix[i, j] = 0.0 + } + } +}