added tests for svd Golab Kahan
This commit is contained in:
parent
de2f82c878
commit
082317b8b5
@ -175,6 +175,28 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
assertTrue(tensor.eq(tensorSVD))
|
assertTrue(tensor.eq(tensorSVD))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSVDGolabKahan() = DoubleTensorAlgebra{
|
||||||
|
testSVDGolabKahanFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
|
||||||
|
testSVDGolabKahanFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
|
||||||
|
val buffer = doubleArrayOf(
|
||||||
|
1.000000, 2.000000, 3.000000,
|
||||||
|
2.000000, 3.000000, 4.000000,
|
||||||
|
3.000000, 4.000000, 5.000000,
|
||||||
|
4.000000, 5.000000, 6.000000,
|
||||||
|
5.000000, 6.000000, 7.000000
|
||||||
|
)
|
||||||
|
testSVDGolabKahanFor(fromArray(intArrayOf(5, 3), buffer))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSVDGolabKahan() = DoubleTensorAlgebra{
|
||||||
|
val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
|
||||||
|
val (tensorU, tensorS, tensorV) = tensor.svdGolabKahan()
|
||||||
|
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||||
|
assertTrue(tensor.eq(tensorSVD))
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testBatchedSymEig() = DoubleTensorAlgebra {
|
fun testBatchedSymEig() = DoubleTensorAlgebra {
|
||||||
val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
|
val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
|
||||||
@ -199,3 +221,17 @@ private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double
|
|||||||
|
|
||||||
assertTrue(tensor.eq(tensorSVD, epsilon))
|
assertTrue(tensor.eq(tensorSVD, epsilon))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun DoubleTensorAlgebra.testSVDGolabKahanFor(tensor: DoubleTensor) {
|
||||||
|
val svd = tensor.svdGolabKahan()
|
||||||
|
|
||||||
|
val tensorSVD = svd.first
|
||||||
|
.dot(
|
||||||
|
diagonalEmbedding(svd.second)
|
||||||
|
.dot(svd.third.transpose())
|
||||||
|
)
|
||||||
|
|
||||||
|
assertTrue(tensor.eq(tensorSVD))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user