diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index 527e5d386..ec070b6bd 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -81,7 +81,7 @@ public interface LinearOpsTensorAlgebra : * If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input. * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd * - * @return the determinant. + * @return triple `(U, S, V)`. */ public fun Tensor.svd(): Triple, Tensor, Tensor> diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt index 89345e315..d17dc70fe 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt @@ -20,7 +20,10 @@ import space.kscience.kmath.tensors.core.luPivotHelper import space.kscience.kmath.tensors.core.pivInit import kotlin.math.min - +/** + * Implementation of common linear algebra operations on double numbers. + * Implements the LinearOpsTensorAlgebra interface. + */ public object DoubleLinearOpsTensorAlgebra : LinearOpsTensorAlgebra, DoubleTensorAlgebra() { @@ -29,12 +32,41 @@ public object DoubleLinearOpsTensorAlgebra : override fun Tensor.det(): DoubleTensor = detLU(1e-9) + /** + * Computes the LU factorization of a matrix or batches of matrices `input`. + * Returns a tuple containing the LU factorization and pivots of `input`. + * + * @param epsilon permissible error when comparing the determinant of a matrix with zero + * @return pair of `factorization` and `pivots`. + * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. + * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. + */ public fun Tensor.luFactor(epsilon: Double): Pair = computeLU(tensor, epsilon) ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") + /** + * Computes the LU factorization of a matrix or batches of matrices `input`. + * Returns a tuple containing the LU factorization and pivots of `input`. + * Uses an error of ``1e-9`` when calculating whether a matrix is degenerate. + * + * @return pair of `factorization` and `pivots`. + * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. + * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. + */ public fun Tensor.luFactor(): Pair = luFactor(1e-9) + /** + * Unpacks the data and pivots from a LU factorization of a tensor. + * Given a tensor [luTensor], return tensors (P, L, U) satisfying ``P * luTensor = L * U``, + * with `P` being a permutation matrix or batch of matrices, + * `L` being a lower triangular matrix or batch of matrices, + * `U` being an upper triangular matrix or batch of matrices. + * + * @param luTensor the packed LU factorization data + * @param pivotsTensor the packed LU factorization pivots + * @return triple of P, L and U tensors + */ public fun luPivot( luTensor: Tensor, pivotsTensor: Tensor @@ -66,6 +98,18 @@ public object DoubleLinearOpsTensorAlgebra : return Triple(pTensor, lTensor, uTensor) } + /** + * QR decomposition. + * + * Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors. + * Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``, + * with `Q` being an orthogonal matrix or batch of orthogonal matrices + * and `R` being an upper triangular matrix or batch of upper triangular matrices. + * + * @param epsilon permissible error when comparing tensors for equality. + * Used when checking the positive definiteness of the input matrix or matrices. + * @return pair of Q and R tensors. + */ public fun Tensor.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) checkPositiveDefinite(tensor, epsilon) @@ -98,6 +142,18 @@ public object DoubleLinearOpsTensorAlgebra : override fun Tensor.svd(): Triple = svd(epsilon = 1e-10) + /** + * Singular Value Decomposition. + * + * Computes the singular value decomposition of either a matrix or batch of matrices `input`. + * The singular value decomposition is represented as a triple `(U, S, V)`, + * such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``. + * If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input. + * + * @param epsilon permissible error when calculating the dot product of vectors, + * i.e. the precision with which the cosine approaches 1 in an iterative algorithm. + * @return triple `(U, S, V)`. + */ public fun Tensor.svd(epsilon: Double): Triple { val size = tensor.linearStructure.dim val commonShape = tensor.shape.sliceArray(0 until size - 2) @@ -125,7 +181,14 @@ public object DoubleLinearOpsTensorAlgebra : override fun Tensor.symEig(): Pair = symEig(epsilon = 1e-15) - //For information: http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html + /** + * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, + * represented by a pair (eigenvalues, eigenvectors). + * + * @param epsilon permissible error when comparing tensors for equality + * and when the cosine approaches 1 in the SVD algorithm. + * @return a pair (eigenvalues, eigenvectors) + */ public fun Tensor.symEig(epsilon: Double): Pair { checkSymmetric(tensor, epsilon) val (u, s, v) = tensor.svd(epsilon) @@ -139,6 +202,13 @@ public object DoubleLinearOpsTensorAlgebra : return eig to v } + /** + * Computes the determinant of a square matrix input, or of each square matrix in a batched input + * using LU factorization algorithm. + * + * @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero + * @return the determinant. + */ public fun Tensor.detLU(epsilon: Double = 1e-9): DoubleTensor { checkSquareMatrix(tensor.shape) @@ -164,6 +234,15 @@ public object DoubleLinearOpsTensorAlgebra : return detTensor } + /** + * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input + * using LU factorization algorithm. + * Given a square matrix `a`, return the matrix `aInv` satisfying + * ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``. + * + * @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero + * @return the multiplicative inverse of a matrix. + */ public fun Tensor.invLU(epsilon: Double = 1e-9): DoubleTensor { val (luTensor, pivotsTensor) = luFactor(epsilon) val invTensor = luTensor.zeroesLike() @@ -177,6 +256,18 @@ public object DoubleLinearOpsTensorAlgebra : return invTensor } + /** + * LUP decomposition + * + * Computes the LUP decomposition of a matrix or a batch of matrices. + * Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``, + * with `P` being a permutation matrix or batch of matrices, + * `L` being a lower triangular matrix or batch of matrices, + * `U` being an upper triangular matrix or batch of matrices. + * + * @param epsilon permissible error when comparing the determinant of a matrix with zero + * @return triple of P, L and U tensors + */ public fun Tensor.lu(epsilon: Double = 1e-9): Triple { val (lu, pivots) = this.luFactor(epsilon) return luPivot(lu, pivots)