added new svd tests

This commit is contained in:
margarita0303 2022-07-14 14:34:44 +03:00
parent 082317b8b5
commit 684227df63

View File

@ -167,6 +167,18 @@ internal class TestDoubleLinearOpsTensorAlgebra {
testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0))) testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
} }
// @Test
// fun testSVDError() = DoubleTensorAlgebra{
// val buffer = doubleArrayOf(
// 1.000000, 2.000000, 3.000000,
// 2.000000, 3.000000, 4.000000,
// 3.000000, 4.000000, 5.000000,
// 4.000000, 5.000000, 6.000000,
// 5.000000, 6.000000, 7.000000
// )
// testSVDFor(fromArray(intArrayOf(5, 3), buffer))
// }
@Test @Test
fun testBatchedSVD() = DoubleTensorAlgebra { fun testBatchedSVD() = DoubleTensorAlgebra {
val tensor = randomNormal(intArrayOf(2, 5, 3), 0) val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
@ -189,6 +201,28 @@ internal class TestDoubleLinearOpsTensorAlgebra {
testSVDGolabKahanFor(fromArray(intArrayOf(5, 3), buffer)) testSVDGolabKahanFor(fromArray(intArrayOf(5, 3), buffer))
} }
// @Test
// fun testSVDGolabKahanError() = DoubleTensorAlgebra{
// val buffer = doubleArrayOf(
// 1.0, 2.0, 3.0, 2.0, 3.0,
// 4.0, 3.0, 4.0, 5.0, 4.0,
// 5.0, 6.0, 5.0, 6.0, 7.0
// )
// testSVDGolabKahanFor(fromArray(intArrayOf(3, 5), buffer))
// }
@Test
fun testSVDGolabKahanBig() = DoubleTensorAlgebra{
val tensor = DoubleTensorAlgebra.randomNormal(intArrayOf(100, 100, 100), 0)
testSVDGolabKahanFor(tensor)
}
@Test
fun testSVDBig() = DoubleTensorAlgebra{
val tensor = DoubleTensorAlgebra.randomNormal(intArrayOf(100, 100, 100), 0)
testSVDFor(tensor)
}
@Test @Test
fun testBatchedSVDGolabKahan() = DoubleTensorAlgebra{ fun testBatchedSVDGolabKahan() = DoubleTensorAlgebra{
val tensor = randomNormal(intArrayOf(2, 5, 3), 0) val tensor = randomNormal(intArrayOf(2, 5, 3), 0)