added tests for svd Golab Kahan

This commit is contained in:
margarita0303 2022-07-13 19:54:10 +03:00
parent de2f82c878
commit 082317b8b5

View File

@ -175,6 +175,28 @@ internal class TestDoubleLinearOpsTensorAlgebra {
assertTrue(tensor.eq(tensorSVD))
}
@Test
fun testSVDGolabKahan() = DoubleTensorAlgebra{
testSVDGolabKahanFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
testSVDGolabKahanFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
val buffer = doubleArrayOf(
1.000000, 2.000000, 3.000000,
2.000000, 3.000000, 4.000000,
3.000000, 4.000000, 5.000000,
4.000000, 5.000000, 6.000000,
5.000000, 6.000000, 7.000000
)
testSVDGolabKahanFor(fromArray(intArrayOf(5, 3), buffer))
}
@Test
fun testBatchedSVDGolabKahan() = DoubleTensorAlgebra{
val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svdGolabKahan()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensor.eq(tensorSVD))
}
@Test
fun testBatchedSymEig() = DoubleTensorAlgebra {
val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
@ -199,3 +221,17 @@ private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double
assertTrue(tensor.eq(tensorSVD, epsilon))
}
private fun DoubleTensorAlgebra.testSVDGolabKahanFor(tensor: DoubleTensor) {
val svd = tensor.svdGolabKahan()
val tensorSVD = svd.first
.dot(
diagonalEmbedding(svd.second)
.dot(svd.third.transpose())
)
assertTrue(tensor.eq(tensorSVD))
}