SVD test to be fixed

This commit is contained in:
Roland Grinis 2021-04-06 11:04:00 +01:00
parent 4336788a6b
commit dcdc22dd9d
4 changed files with 44 additions and 31 deletions

View File

@ -40,7 +40,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding(
diagonalEntries: TensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
offset: Int = 0, dim1: Int = 0, dim2: Int = 1
): TensorType
}

View File

@ -362,6 +362,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return true
}
public fun randNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
public fun DoubleTensor.randNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(this.shape, getRandomNormals(this.shape.reduce(Int::times), seed))
}

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.samplers.BoxMullerNormalizedGaussianSampler
import space.kscience.kmath.structures.*
import kotlin.random.Random
import kotlin.math.*

View File

@ -7,21 +7,6 @@ import kotlin.test.assertTrue
class TestDoubleLinearOpsTensorAlgebra {
private val eps = 1e-5
private fun Double.epsEqual(other: Double): Boolean {
return abs(this - other) < eps
}
fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
if (abs(elem1 - elem2) > eps) {
return false
}
}
return true
}
@Test
fun testDetLU() = DoubleLinearOpsTensorAlgebra {
val tensor = fromArray(
@ -136,8 +121,35 @@ class TestDoubleLinearOpsTensorAlgebra {
@Test
fun testSVD() = DoubleLinearOpsTensorAlgebra {
val epsilon = 1e-10
fun test_tensor(tensor: DoubleTensor) {
testSVDFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
}
@Test
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(intArrayOf(7, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS,0,1,2) dot tensorV)
println(tensor.eq(tensorSVD))
}
}
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
return abs(this - other) < eps
}
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
if (abs(elem1 - elem2) > eps) {
return false
}
}
return true
}
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
val svd = tensor.svd()
val tensorSVD = svd.first
@ -150,8 +162,5 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue { abs(x1 - x2) < epsilon }
}
}
test_tensor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
test_tensor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
}
}