forked from kscience/kmath
SVD test to be fixed
This commit is contained in:
parent
4336788a6b
commit
dcdc22dd9d
@ -40,7 +40,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||
public fun diagonalEmbedding(
|
||||
diagonalEntries: TensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||
offset: Int = 0, dim1: Int = 0, dim2: Int = 1
|
||||
): TensorType
|
||||
|
||||
}
|
||||
|
@ -362,6 +362,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
return true
|
||||
}
|
||||
|
||||
public fun randNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
|
||||
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
|
||||
|
||||
public fun DoubleTensor.randNormalLike(seed: Long = 0): DoubleTensor =
|
||||
DoubleTensor(this.shape, getRandomNormals(this.shape.reduce(Int::times), seed))
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.stat.RandomGenerator
|
||||
import space.kscience.kmath.stat.samplers.BoxMullerNormalizedGaussianSampler
|
||||
import space.kscience.kmath.structures.*
|
||||
import kotlin.random.Random
|
||||
import kotlin.math.*
|
||||
|
@ -7,21 +7,6 @@ import kotlin.test.assertTrue
|
||||
|
||||
class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
private val eps = 1e-5
|
||||
|
||||
private fun Double.epsEqual(other: Double): Boolean {
|
||||
return abs(this - other) < eps
|
||||
}
|
||||
|
||||
fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
|
||||
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
|
||||
if (abs(elem1 - elem2) > eps) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDetLU() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor = fromArray(
|
||||
@ -136,22 +121,46 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testSVD() = DoubleLinearOpsTensorAlgebra {
|
||||
val epsilon = 1e-10
|
||||
fun test_tensor(tensor: DoubleTensor) {
|
||||
val svd = tensor.svd()
|
||||
testSVDFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
|
||||
testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
|
||||
}
|
||||
|
||||
val tensorSVD = svd.first
|
||||
.dot(
|
||||
diagonalEmbedding(svd.second, 0, 0, 1)
|
||||
.dot(svd.third.transpose(0, 1))
|
||||
)
|
||||
@Test
|
||||
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor = randNormal(intArrayOf(7, 5, 3), 0)
|
||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS,0,1,2) dot tensorV)
|
||||
println(tensor.eq(tensorSVD))
|
||||
}
|
||||
|
||||
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
|
||||
assertTrue { abs(x1 - x2) < epsilon }
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
|
||||
return abs(this - other) < eps
|
||||
}
|
||||
|
||||
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
|
||||
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
|
||||
if (abs(elem1 - elem2) > eps) {
|
||||
return false
|
||||
}
|
||||
test_tensor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
|
||||
test_tensor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
|
||||
val svd = tensor.svd()
|
||||
|
||||
val tensorSVD = svd.first
|
||||
.dot(
|
||||
diagonalEmbedding(svd.second, 0, 0, 1)
|
||||
.dot(svd.third.transpose(0, 1))
|
||||
)
|
||||
|
||||
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
|
||||
assertTrue { abs(x1 - x2) < epsilon }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user