forked from kscience/kmath
Safer cleaner for symeig
This commit is contained in:
parent
fe8579180d
commit
a692412cff
@ -3,10 +3,8 @@ package space.kscience.kmath.tensors.core
|
|||||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.structures.toList
|
|
||||||
import kotlin.math.abs
|
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
import kotlin.math.sign
|
|
||||||
|
|
||||||
public class DoubleLinearOpsTensorAlgebra :
|
public class DoubleLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||||
@ -125,7 +123,11 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
checkSymmetric(this, epsilon)
|
checkSymmetric(this, epsilon)
|
||||||
val (u, s, v) = this.svd(epsilon)
|
val (u, s, v) = this.svd(epsilon)
|
||||||
val shp = s.shape + intArrayOf(1)
|
val shp = s.shape + intArrayOf(1)
|
||||||
val utv = (u.transpose() dot v).map { if (abs(it) < 0.9) 0.0 else sign(it) }
|
val utv = u.transpose() dot v
|
||||||
|
val n = s.shape.last()
|
||||||
|
for( matrix in utv.matrixSequence())
|
||||||
|
cleanSymHelper(matrix.as2D(),n)
|
||||||
|
|
||||||
val eig = (utv dot s.view(shp)).view(s.shape)
|
val eig = (utv dot s.view(shp)).view(s.shape)
|
||||||
return Pair(eig, v)
|
return Pair(eig, v)
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import space.kscience.kmath.nd.as1D
|
|||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
import kotlin.math.sign
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
|
||||||
@ -294,3 +295,14 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
|
|||||||
matrixV.buffer.array()[matrixV.bufferStart + i] = vBuffer[i]
|
matrixV.buffer.array()[matrixV.bufferStart + i] = vBuffer[i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int): Unit {
|
||||||
|
for (i in 0 until n)
|
||||||
|
for (j in 0 until n) {
|
||||||
|
if (i == j) {
|
||||||
|
matrix[i, j] = sign(matrix[i, j])
|
||||||
|
} else {
|
||||||
|
matrix[i, j] = 0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user