v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
Showing only changes of commit bd068b2c14 - Show all commits

View File

@ -163,7 +163,7 @@ class TestDoubleLinearOpsTensorAlgebra {
@Test @Test
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra { fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0) val tensor = randNormal(intArrayOf(2, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd() val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensor.eq(tensorSVD)) assertTrue(tensor.eq(tensorSVD))
@ -171,7 +171,7 @@ class TestDoubleLinearOpsTensorAlgebra {
@Test @Test
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra { fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(shape = intArrayOf(5, 3, 3), 0) val tensor = randNormal(shape = intArrayOf(2, 3, 3), 0)
val tensorSigma = tensor + tensor.transpose() val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig() val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())