implement svd function and tests for it

This commit is contained in:
AlyaNovikova 2021-04-06 00:06:14 +03:00
parent 3e98240b94
commit 814eab8cde
2 changed files with 108 additions and 1 deletions

View File

@ -3,6 +3,8 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import kotlin.math.abs
import kotlin.math.min
public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
@ -89,8 +91,81 @@ public class DoubleLinearOpsTensorAlgebra :
return qTensor to rTensor
}
internal fun svd1d(a: DoubleTensor, epsilon: Double = 1e-10): DoubleTensor {
val (n, m) = a.shape
var v: DoubleTensor
val b: DoubleTensor
if (n > m) {
b = a.transpose(0, 1).dot(a)
v = DoubleTensor(intArrayOf(m), getRandomNormals(m, 0))
} else {
b = a.dot(a.transpose(0, 1))
v = DoubleTensor(intArrayOf(n), getRandomNormals(n, 0))
}
var lastV: DoubleTensor
while (true) {
lastV = v
v = b.dot(lastV)
val norm = DoubleAnalyticTensorAlgebra { (v dot v).sqrt().value() }
v = v.times(1.0 / norm)
if (abs(v.dot(lastV).value()) > 1 - epsilon) {
return v
}
}
}
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
TODO("ALYA")
val size = this.shape.size
val commonShape = this.shape.sliceArray(0 until size - 2)
val (n, m) = this.shape.sliceArray(size - 2 until size)
val resU = zeros(commonShape + intArrayOf(n, min(n, m)))
val resS = zeros(commonShape + intArrayOf(min(n, m)))
val resV = zeros(commonShape + intArrayOf(min(n, m), m))
for ((matrix, USV) in this.matrixSequence()
.zip(resU.matrixSequence().zip(resS.vectorSequence().zip(resV.matrixSequence())))) {
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
val (matrixU, SV) = USV
val (matrixS, matrixV) = SV
for (k in 0 until min(n, m)) {
var a = matrix.asTensor().copy()
for ((singularValue, u, v) in res.slice(0 until k)) {
val outerProduct = DoubleArray(u.shape[0] * v.shape[0])
for (i in 0 until u.shape[0]) {
for (j in 0 until v.shape[0]) {
outerProduct[i * v.shape[0] + j] = u[i].value() * v[j].value()
}
}
a = a - singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct))
}
var v: DoubleTensor
var u: DoubleTensor
var norm: Double
if (n > m) {
v = svd1d(a)
u = matrix.asTensor().dot(v)
norm = DoubleAnalyticTensorAlgebra { (u dot u).sqrt().value() }
u = u.times(1.0 / norm)
} else {
u = svd1d(a)
v = matrix.asTensor().transpose(0, 1).dot(u)
norm = DoubleAnalyticTensorAlgebra { (v dot v).sqrt().value() }
v = v.times(1.0 / norm)
}
res.add(Triple(norm, u, v))
}
val s = res.map { it.first }.toDoubleArray()
val uBuffer = res.map { it.second }.flatMap { it.buffer.array().toList() }.toDoubleArray()
val vBuffer = res.map { it.third }.flatMap { it.buffer.array().toList() }.toDoubleArray()
uBuffer.copyInto(matrixU.buffer.array())
s.copyInto(matrixS.buffer.array())
vBuffer.copyInto(matrixV.buffer.array())
}
return Triple(resU, resS, resV.transpose(size - 2, size - 1))
}
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {

View File

@ -122,4 +122,36 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) }
}
@Test
fun svd1d() = DoubleLinearOpsTensorAlgebra {
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = svd1d(tensor2)
assertTrue(res.shape contentEquals intArrayOf(2))
assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01}
assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01}
}
@Test
fun svd() = DoubleLinearOpsTensorAlgebra {
val epsilon = 1e-10
fun test_tensor(tensor: DoubleTensor) {
val svd = tensor.svd()
val tensorSVD = svd.first
.dot(
diagonalEmbedding(svd.second, 0, 0, 1)
.dot(svd.third.transpose(0, 1))
)
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
assertTrue { abs(x1 - x2) < epsilon }
}
}
test_tensor(fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)))
test_tensor(fromArray(intArrayOf(2, 2), doubleArrayOf(-1.0, 0.0, 239.0, 238.0)))
}
}