KMP library for tensors #300
@ -3,8 +3,10 @@ package space.kscience.kmath.tensors.core
|
||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.structures.toList
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.min
|
||||
import kotlin.math.sign
|
||||
|
||||
public class DoubleLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||
@ -113,11 +115,14 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
return Triple(resU.transpose(), resS, resV.transpose())
|
||||
}
|
||||
|
||||
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
|
||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
||||
checkSymmetric(this)
|
||||
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
|
||||
//see the last point
|
||||
TODO("maybe use SVD")
|
||||
val (u, s, v) = this.svd()
|
||||
val shp = s.shape + intArrayOf(1)
|
||||
val utv = (u.transpose() dot v).map { if (abs(it) < 0.99) 0.0 else sign(it) }
|
||||
val eig = (utv dot s.view(shp)).view(s.shape)
|
||||
return Pair(eig, v)
|
||||
}
|
||||
|
||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||
|
@ -1,5 +1,6 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.structures.toList
|
||||
import kotlin.math.abs
|
||||
import kotlin.test.Ignore
|
||||
import kotlin.test.Test
|
||||
@ -137,13 +138,12 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
||||
val tensor = randNormal(shape = intArrayOf(5, 3, 3), 0)
|
||||
val tensorSigma = tensor + tensor.transpose()
|
||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
||||
assertTrue(tensorSigma.eq(tensorSigmaCalc, 0.01))
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user