SVD test to be fixed

This commit is contained in:
Roland Grinis 2021-04-06 11:04:00 +01:00
parent 4336788a6b
commit dcdc22dd9d
4 changed files with 44 additions and 31 deletions

View File

@ -40,7 +40,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html //https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding( public fun diagonalEmbedding(
diagonalEntries: TensorType, diagonalEntries: TensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1 offset: Int = 0, dim1: Int = 0, dim2: Int = 1
): TensorType ): TensorType
} }

View File

@ -362,6 +362,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return true return true
} }
public fun randNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
public fun DoubleTensor.randNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(this.shape, getRandomNormals(this.shape.reduce(Int::times), seed))
} }

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.samplers.BoxMullerNormalizedGaussianSampler
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
import kotlin.random.Random import kotlin.random.Random
import kotlin.math.* import kotlin.math.*

View File

@ -7,21 +7,6 @@ import kotlin.test.assertTrue
class TestDoubleLinearOpsTensorAlgebra { class TestDoubleLinearOpsTensorAlgebra {
private val eps = 1e-5
private fun Double.epsEqual(other: Double): Boolean {
return abs(this - other) < eps
}
fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
if (abs(elem1 - elem2) > eps) {
return false
}
}
return true
}
@Test @Test
fun testDetLU() = DoubleLinearOpsTensorAlgebra { fun testDetLU() = DoubleLinearOpsTensorAlgebra {
val tensor = fromArray( val tensor = fromArray(
@ -136,22 +121,46 @@ class TestDoubleLinearOpsTensorAlgebra {
@Test @Test
fun testSVD() = DoubleLinearOpsTensorAlgebra { fun testSVD() = DoubleLinearOpsTensorAlgebra {
val epsilon = 1e-10 testSVDFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
fun test_tensor(tensor: DoubleTensor) { testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
val svd = tensor.svd() }
val tensorSVD = svd.first @Test
.dot( fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
diagonalEmbedding(svd.second, 0, 0, 1) val tensor = randNormal(intArrayOf(7, 5, 3), 0)
.dot(svd.third.transpose(0, 1)) val (tensorU, tensorS, tensorV) = tensor.svd()
) val tensorSVD = tensorU dot (diagonalEmbedding(tensorS,0,1,2) dot tensorV)
println(tensor.eq(tensorSVD))
}
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
assertTrue { abs(x1 - x2) < epsilon } }
}
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
return abs(this - other) < eps
}
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
if (abs(elem1 - elem2) > eps) {
return false
} }
test_tensor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))) }
test_tensor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0))) return true
}
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
val svd = tensor.svd()
val tensorSVD = svd.first
.dot(
diagonalEmbedding(svd.second, 0, 0, 1)
.dot(svd.third.transpose(0, 1))
)
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
assertTrue { abs(x1 - x2) < epsilon }
} }
} }