v0.3.0-dev-9 #324
@ -163,7 +163,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0)
|
val tensor = randNormal(intArrayOf(2, 5, 3), 0)
|
||||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||||
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||||
assertTrue(tensor.eq(tensorSVD))
|
assertTrue(tensor.eq(tensorSVD))
|
||||||
@ -171,7 +171,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = randNormal(shape = intArrayOf(5, 3, 3), 0)
|
val tensor = randNormal(shape = intArrayOf(2, 3, 3), 0)
|
||||||
val tensorSigma = tensor + tensor.transpose()
|
val tensorSigma = tensor + tensor.transpose()
|
||||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||||
|
Loading…
Reference in New Issue
Block a user