Safer cleaner for symeig

This commit is contained in:
Roland Grinis 2021-04-09 10:08:55 +01:00
parent fe8579180d
commit a692412cff
2 changed files with 18 additions and 4 deletions

View File

@ -3,10 +3,8 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.toList
import kotlin.math.abs
import kotlin.math.min import kotlin.math.min
import kotlin.math.sign
public class DoubleLinearOpsTensorAlgebra : public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>, LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
@ -125,7 +123,11 @@ public class DoubleLinearOpsTensorAlgebra :
checkSymmetric(this, epsilon) checkSymmetric(this, epsilon)
val (u, s, v) = this.svd(epsilon) val (u, s, v) = this.svd(epsilon)
val shp = s.shape + intArrayOf(1) val shp = s.shape + intArrayOf(1)
val utv = (u.transpose() dot v).map { if (abs(it) < 0.9) 0.0 else sign(it) } val utv = u.transpose() dot v
val n = s.shape.last()
for( matrix in utv.matrixSequence())
cleanSymHelper(matrix.as2D(),n)
val eig = (utv dot s.view(shp)).view(s.shape) val eig = (utv dot s.view(shp)).view(s.shape)
return Pair(eig, v) return Pair(eig, v)
} }

View File

@ -6,6 +6,7 @@ import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.min import kotlin.math.min
import kotlin.math.sign
import kotlin.math.sqrt import kotlin.math.sqrt
@ -294,3 +295,14 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
matrixV.buffer.array()[matrixV.bufferStart + i] = vBuffer[i] matrixV.buffer.array()[matrixV.bufferStart + i] = vBuffer[i]
} }
} }
internal inline fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int): Unit {
for (i in 0 until n)
for (j in 0 until n) {
if (i == j) {
matrix[i, j] = sign(matrix[i, j])
} else {
matrix[i, j] = 0.0
}
}
}