Merge commit 'refs/pull/315/head' of ssh://git.jetbrains.space/mipt-npm/sci/kmath into feature/tensor-algebra

This commit is contained in:
Roland Grinis 2021-05-06 09:04:01 +01:00
commit 0680a3a1cb
2 changed files with 94 additions and 3 deletions

View File

@ -81,7 +81,7 @@ public interface LinearOpsTensorAlgebra<T> :
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input. * If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
* *
* @return the determinant. * @return triple `(U, S, V)`.
*/ */
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>> public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>

View File

@ -20,7 +20,10 @@ import space.kscience.kmath.tensors.core.luPivotHelper
import space.kscience.kmath.tensors.core.pivInit import space.kscience.kmath.tensors.core.pivInit
import kotlin.math.min import kotlin.math.min
/**
* Implementation of common linear algebra operations on double numbers.
* Implements the LinearOpsTensorAlgebra<Double> interface.
*/
public object DoubleLinearOpsTensorAlgebra : public object DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double>, LinearOpsTensorAlgebra<Double>,
DoubleTensorAlgebra() { DoubleTensorAlgebra() {
@ -29,12 +32,41 @@ public object DoubleLinearOpsTensorAlgebra :
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9) override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
/**
* Computes the LU factorization of a matrix or batches of matrices `input`.
* Returns a tuple containing the LU factorization and pivots of `input`.
*
* @param epsilon permissible error when comparing the determinant of a matrix with zero
* @return pair of `factorization` and `pivots`.
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> = public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) computeLU(tensor, epsilon)
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
/**
* Computes the LU factorization of a matrix or batches of matrices `input`.
* Returns a tuple containing the LU factorization and pivots of `input`.
* Uses an error of ``1e-9`` when calculating whether a matrix is degenerate.
*
* @return pair of `factorization` and `pivots`.
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9) public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
/**
* Unpacks the data and pivots from a LU factorization of a tensor.
* Given a tensor [luTensor], return tensors (P, L, U) satisfying ``P * luTensor = L * U``,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* @param luTensor the packed LU factorization data
* @param pivotsTensor the packed LU factorization pivots
* @return triple of P, L and U tensors
*/
public fun luPivot( public fun luPivot(
luTensor: Tensor<Double>, luTensor: Tensor<Double>,
pivotsTensor: Tensor<Int> pivotsTensor: Tensor<Int>
@ -66,6 +98,18 @@ public object DoubleLinearOpsTensorAlgebra :
return Triple(pTensor, lTensor, uTensor) return Triple(pTensor, lTensor, uTensor)
} }
/**
* QR decomposition.
*
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
*
* @param epsilon permissible error when comparing tensors for equality.
* Used when checking the positive definiteness of the input matrix or matrices.
* @return pair of Q and R tensors.
*/
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor { public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape) checkSquareMatrix(shape)
checkPositiveDefinite(tensor, epsilon) checkPositiveDefinite(tensor, epsilon)
@ -98,6 +142,18 @@ public object DoubleLinearOpsTensorAlgebra :
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
svd(epsilon = 1e-10) svd(epsilon = 1e-10)
/**
* Singular Value Decomposition.
*
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
* The singular value decomposition is represented as a triple `(U, S, V)`,
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
*
* @param epsilon permissible error when calculating the dot product of vectors,
* i.e. the precision with which the cosine approaches 1 in an iterative algorithm.
* @return triple `(U, S, V)`.
*/
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.linearStructure.dim val size = tensor.linearStructure.dim
val commonShape = tensor.shape.sliceArray(0 until size - 2) val commonShape = tensor.shape.sliceArray(0 until size - 2)
@ -125,7 +181,14 @@ public object DoubleLinearOpsTensorAlgebra :
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
symEig(epsilon = 1e-15) symEig(epsilon = 1e-15)
//For information: http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html /**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
* represented by a pair (eigenvalues, eigenvectors).
*
* @param epsilon permissible error when comparing tensors for equality
* and when the cosine approaches 1 in the SVD algorithm.
* @return a pair (eigenvalues, eigenvectors)
*/
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> { public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon) checkSymmetric(tensor, epsilon)
val (u, s, v) = tensor.svd(epsilon) val (u, s, v) = tensor.svd(epsilon)
@ -139,6 +202,13 @@ public object DoubleLinearOpsTensorAlgebra :
return eig to v return eig to v
} }
/**
* Computes the determinant of a square matrix input, or of each square matrix in a batched input
* using LU factorization algorithm.
*
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
* @return the determinant.
*/
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor { public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
checkSquareMatrix(tensor.shape) checkSquareMatrix(tensor.shape)
@ -164,6 +234,15 @@ public object DoubleLinearOpsTensorAlgebra :
return detTensor return detTensor
} }
/**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input
* using LU factorization algorithm.
* Given a square matrix `a`, return the matrix `aInv` satisfying
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
*
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
* @return the multiplicative inverse of a matrix.
*/
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor { public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = luFactor(epsilon) val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike() val invTensor = luTensor.zeroesLike()
@ -177,6 +256,18 @@ public object DoubleLinearOpsTensorAlgebra :
return invTensor return invTensor
} }
/**
* LUP decomposition
*
* Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
*
* @param epsilon permissible error when comparing the determinant of a matrix with zero
* @return triple of P, L and U tensors
*/
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = this.luFactor(epsilon) val (lu, pivots) = this.luFactor(epsilon)
return luPivot(lu, pivots) return luPivot(lu, pivots)