Open epsilon to client to control numerical precision for power methods

This commit is contained in:
Roland Grinis 2021-04-09 09:56:37 +01:00
parent 3f0dff3ce9
commit fe8579180d
5 changed files with 20 additions and 15 deletions

View File

@ -26,6 +26,6 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType> public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
//https://pytorch.org/docs/stable/generated/torch.symeig.html //https://pytorch.org/docs/stable/generated/torch.symeig.html
public fun TensorType.symEig(eigenvectors: Boolean = true): Pair<TensorType, TensorType> public fun TensorType.symEig(): Pair<TensorType, TensorType>
} }

View File

@ -94,8 +94,10 @@ public class DoubleLinearOpsTensorAlgebra :
return qTensor to rTensor return qTensor to rTensor
} }
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
svd(epsilon = 1e-10)
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun DoubleTensor.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = this.linearStructure.dim val size = this.linearStructure.dim
val commonShape = this.shape.sliceArray(0 until size - 2) val commonShape = this.shape.sliceArray(0 until size - 2)
val (n, m) = this.shape.sliceArray(size - 2 until size) val (n, m) = this.shape.sliceArray(size - 2 until size)
@ -110,17 +112,20 @@ public class DoubleLinearOpsTensorAlgebra :
matrix.shape, matrix.shape,
matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + size).toDoubleArray() matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + size).toDoubleArray()
) )
svdHelper(curMatrix, USV, m, n) svdHelper(curMatrix, USV, m, n, epsilon)
} }
return Triple(resU.transpose(), resS, resV.transpose()) return Triple(resU.transpose(), resS, resV.transpose())
} }
override fun DoubleTensor.symEig(): Pair<DoubleTensor, DoubleTensor> =
symEig(epsilon = 1e-15)
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html //http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> { public fun DoubleTensor.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(this) checkSymmetric(this, epsilon)
val (u, s, v) = this.svd() val (u, s, v) = this.svd(epsilon)
val shp = s.shape + intArrayOf(1) val shp = s.shape + intArrayOf(1)
val utv = (u.transpose() dot v).map { if (abs(it) < 0.99) 0.0 else sign(it) } val utv = (u.transpose() dot v).map { if (abs(it) < 0.9) 0.0 else sign(it) }
val eig = (utv dot s.view(shp)).view(s.shape) val eig = (utv dot s.view(shp)).view(s.shape)
return Pair(eig, v) return Pair(eig, v)
} }

View File

@ -58,7 +58,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
} }
} }
internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor): Unit = internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, epsilon: Double = 1e-6): Unit =
check(tensor.eq(tensor.transpose())){ check(tensor.eq(tensor.transpose(), epsilon)) {
"Tensor is not symmetric about the last 2 dimensions" "Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
} }

View File

@ -246,7 +246,7 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svd1d(a: DoubleTensor, epsilon:
internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper( internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
matrix: DoubleTensor, matrix: DoubleTensor,
USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>, USV: Pair<BufferedTensor<Double>, Pair<BufferedTensor<Double>, BufferedTensor<Double>>>,
m: Int, n: Int m: Int, n: Int, epsilon: Double
): Unit { ): Unit {
val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0) val res = ArrayList<Triple<Double, DoubleTensor, DoubleTensor>>(0)
val (matrixU, SV) = USV val (matrixU, SV) = USV
@ -267,12 +267,12 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper(
var u: DoubleTensor var u: DoubleTensor
var norm: Double var norm: Double
if (n > m) { if (n > m) {
v = svd1d(a) v = svd1d(a, epsilon)
u = matrix.dot(v) u = matrix.dot(v)
norm = DoubleAnalyticTensorAlgebra { (u dot u).sqrt().value() } norm = DoubleAnalyticTensorAlgebra { (u dot u).sqrt().value() }
u = u.times(1.0 / norm) u = u.times(1.0 / norm)
} else { } else {
u = svd1d(a) u = svd1d(a, epsilon)
v = matrix.transpose(0, 1).dot(u) v = matrix.transpose(0, 1).dot(u)
norm = DoubleAnalyticTensorAlgebra { (v dot v).sqrt().value() } norm = DoubleAnalyticTensorAlgebra { (v dot v).sqrt().value() }
v = v.times(1.0 / norm) v = v.times(1.0 / norm)

View File

@ -143,7 +143,7 @@ class TestDoubleLinearOpsTensorAlgebra {
val tensorSigma = tensor + tensor.transpose() val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig() val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensorSigma.eq(tensorSigmaCalc, 0.01)) assertTrue(tensorSigma.eq(tensorSigmaCalc))
} }