From 082317b8b5f8c5c2625f7304bdb52bff83f49952 Mon Sep 17 00:00:00 2001 From: margarita0303 Date: Wed, 13 Jul 2022 19:54:10 +0300 Subject: [PATCH] added tests for svd Golab Kahan --- .../core/TestDoubleLinearOpsAlgebra.kt | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index e025d4b71..3b305b25a 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -175,6 +175,28 @@ internal class TestDoubleLinearOpsTensorAlgebra { assertTrue(tensor.eq(tensorSVD)) } + @Test + fun testSVDGolabKahan() = DoubleTensorAlgebra{ + testSVDGolabKahanFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))) + testSVDGolabKahanFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0))) + val buffer = doubleArrayOf( + 1.000000, 2.000000, 3.000000, + 2.000000, 3.000000, 4.000000, + 3.000000, 4.000000, 5.000000, + 4.000000, 5.000000, 6.000000, + 5.000000, 6.000000, 7.000000 + ) + testSVDGolabKahanFor(fromArray(intArrayOf(5, 3), buffer)) + } + + @Test + fun testBatchedSVDGolabKahan() = DoubleTensorAlgebra{ + val tensor = randomNormal(intArrayOf(2, 5, 3), 0) + val (tensorU, tensorS, tensorV) = tensor.svdGolabKahan() + val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) + assertTrue(tensor.eq(tensorSVD)) + } + @Test fun testBatchedSymEig() = DoubleTensorAlgebra { val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0) @@ -199,3 +221,17 @@ private fun DoubleTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double assertTrue(tensor.eq(tensorSVD, epsilon)) } + +private fun DoubleTensorAlgebra.testSVDGolabKahanFor(tensor: DoubleTensor) { + val svd = tensor.svdGolabKahan() + + val tensorSVD = svd.first + .dot( + diagonalEmbedding(svd.second) + .dot(svd.third.transpose()) + ) + + assertTrue(tensor.eq(tensorSVD)) +} + +