v0.3.0-dev-9 #324
@ -3,10 +3,8 @@ package space.kscience.kmath.tensors.core
|
||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.structures.toList
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.min
|
||||
import kotlin.math.sign
|
||||
|
||||
|
||||
public class DoubleLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||
@ -125,7 +123,11 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
checkSymmetric(this, epsilon)
|
||||
val (u, s, v) = this.svd(epsilon)
|
||||
val shp = s.shape + intArrayOf(1)
|
||||
val utv = (u.transpose() dot v).map { if (abs(it) < 0.9) 0.0 else sign(it) }
|
||||
val utv = u.transpose() dot v
|
||||
val n = s.shape.last()
|
||||
for( matrix in utv.matrixSequence())
|
||||
cleanSymHelper(matrix.as2D(),n)
|
||||
|
||||
val eig = (utv dot s.view(shp)).view(s.shape)
|
||||
return Pair(eig, v)
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.min
|
||||
import kotlin.math.sign
|
||||
import kotlin.math.sqrt
|
||||
|
||||
|
||||
@ -294,3 +295,14 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
|
||||
matrixV.buffer.array()[matrixV.bufferStart + i] = vBuffer[i]
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun cleanSymHelper(matrix: MutableStructure2D<Double>, n: Int): Unit {
|
||||
for (i in 0 until n)
|
||||
for (j in 0 until n) {
|
||||
if (i == j) {
|
||||
matrix[i, j] = sign(matrix[i, j])
|
||||
} else {
|
||||
matrix[i, j] = 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user