forked from kscience/kmath
SVD test to be fixed
This commit is contained in:
parent
4336788a6b
commit
dcdc22dd9d
@ -40,7 +40,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||||
public fun diagonalEmbedding(
|
public fun diagonalEmbedding(
|
||||||
diagonalEntries: TensorType,
|
diagonalEntries: TensorType,
|
||||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
offset: Int = 0, dim1: Int = 0, dim2: Int = 1
|
||||||
): TensorType
|
): TensorType
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -362,6 +362,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun randNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
|
||||||
|
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
|
||||||
|
|
||||||
|
public fun DoubleTensor.randNormalLike(seed: Long = 0): DoubleTensor =
|
||||||
|
DoubleTensor(this.shape, getRandomNormals(this.shape.reduce(Int::times), seed))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.stat.RandomGenerator
|
|
||||||
import space.kscience.kmath.stat.samplers.BoxMullerNormalizedGaussianSampler
|
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
@ -7,21 +7,6 @@ import kotlin.test.assertTrue
|
|||||||
|
|
||||||
class TestDoubleLinearOpsTensorAlgebra {
|
class TestDoubleLinearOpsTensorAlgebra {
|
||||||
|
|
||||||
private val eps = 1e-5
|
|
||||||
|
|
||||||
private fun Double.epsEqual(other: Double): Boolean {
|
|
||||||
return abs(this - other) < eps
|
|
||||||
}
|
|
||||||
|
|
||||||
fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
|
|
||||||
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
|
|
||||||
if (abs(elem1 - elem2) > eps) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDetLU() = DoubleLinearOpsTensorAlgebra {
|
fun testDetLU() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = fromArray(
|
val tensor = fromArray(
|
||||||
@ -136,8 +121,35 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSVD() = DoubleLinearOpsTensorAlgebra {
|
fun testSVD() = DoubleLinearOpsTensorAlgebra {
|
||||||
val epsilon = 1e-10
|
testSVDFor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
|
||||||
fun test_tensor(tensor: DoubleTensor) {
|
testSVDFor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
||||||
|
val tensor = randNormal(intArrayOf(7, 5, 3), 0)
|
||||||
|
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||||
|
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS,0,1,2) dot tensorV)
|
||||||
|
println(tensor.eq(tensorSVD))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
|
||||||
|
return abs(this - other) < eps
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
|
||||||
|
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
|
||||||
|
if (abs(elem1 - elem2) > eps) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
|
||||||
val svd = tensor.svd()
|
val svd = tensor.svd()
|
||||||
|
|
||||||
val tensorSVD = svd.first
|
val tensorSVD = svd.first
|
||||||
@ -150,8 +162,5 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
assertTrue { abs(x1 - x2) < epsilon }
|
assertTrue { abs(x1 - x2) < epsilon }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
test_tensor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
|
|
||||||
test_tensor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user