Merge commit 'refs/pull/315/head' of ssh://git.jetbrains.space/mipt-npm/sci/kmath into feature/tensor-algebra
This commit is contained in:
commit
0680a3a1cb
@ -81,7 +81,7 @@ public interface LinearOpsTensorAlgebra<T> :
|
|||||||
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
||||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
||||||
*
|
*
|
||||||
* @return the determinant.
|
* @return triple `(U, S, V)`.
|
||||||
*/
|
*/
|
||||||
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
|
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
|
||||||
|
|
||||||
|
@ -20,7 +20,10 @@ import space.kscience.kmath.tensors.core.luPivotHelper
|
|||||||
import space.kscience.kmath.tensors.core.pivInit
|
import space.kscience.kmath.tensors.core.pivInit
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implementation of common linear algebra operations on double numbers.
|
||||||
|
* Implements the LinearOpsTensorAlgebra<Double> interface.
|
||||||
|
*/
|
||||||
public object DoubleLinearOpsTensorAlgebra :
|
public object DoubleLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double>,
|
LinearOpsTensorAlgebra<Double>,
|
||||||
DoubleTensorAlgebra() {
|
DoubleTensorAlgebra() {
|
||||||
@ -29,12 +32,41 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
|
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the LU factorization of a matrix or batches of matrices `input`.
|
||||||
|
* Returns a tuple containing the LU factorization and pivots of `input`.
|
||||||
|
*
|
||||||
|
* @param epsilon permissible error when comparing the determinant of a matrix with zero
|
||||||
|
* @return pair of `factorization` and `pivots`.
|
||||||
|
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
|
||||||
|
* The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
||||||
computeLU(tensor, epsilon)
|
computeLU(tensor, epsilon)
|
||||||
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
|
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the LU factorization of a matrix or batches of matrices `input`.
|
||||||
|
* Returns a tuple containing the LU factorization and pivots of `input`.
|
||||||
|
* Uses an error of ``1e-9`` when calculating whether a matrix is degenerate.
|
||||||
|
*
|
||||||
|
* @return pair of `factorization` and `pivots`.
|
||||||
|
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
|
||||||
|
* The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
|
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unpacks the data and pivots from a LU factorization of a tensor.
|
||||||
|
* Given a tensor [luTensor], return tensors (P, L, U) satisfying ``P * luTensor = L * U``,
|
||||||
|
* with `P` being a permutation matrix or batch of matrices,
|
||||||
|
* `L` being a lower triangular matrix or batch of matrices,
|
||||||
|
* `U` being an upper triangular matrix or batch of matrices.
|
||||||
|
*
|
||||||
|
* @param luTensor the packed LU factorization data
|
||||||
|
* @param pivotsTensor the packed LU factorization pivots
|
||||||
|
* @return triple of P, L and U tensors
|
||||||
|
*/
|
||||||
public fun luPivot(
|
public fun luPivot(
|
||||||
luTensor: Tensor<Double>,
|
luTensor: Tensor<Double>,
|
||||||
pivotsTensor: Tensor<Int>
|
pivotsTensor: Tensor<Int>
|
||||||
@ -66,6 +98,18 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
return Triple(pTensor, lTensor, uTensor)
|
return Triple(pTensor, lTensor, uTensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* QR decomposition.
|
||||||
|
*
|
||||||
|
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
|
||||||
|
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
|
||||||
|
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
|
||||||
|
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
|
||||||
|
*
|
||||||
|
* @param epsilon permissible error when comparing tensors for equality.
|
||||||
|
* Used when checking the positive definiteness of the input matrix or matrices.
|
||||||
|
* @return pair of Q and R tensors.
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
|
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
checkPositiveDefinite(tensor, epsilon)
|
checkPositiveDefinite(tensor, epsilon)
|
||||||
@ -98,6 +142,18 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
|
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
|
||||||
svd(epsilon = 1e-10)
|
svd(epsilon = 1e-10)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Singular Value Decomposition.
|
||||||
|
*
|
||||||
|
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
|
||||||
|
* The singular value decomposition is represented as a triple `(U, S, V)`,
|
||||||
|
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
|
||||||
|
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
||||||
|
*
|
||||||
|
* @param epsilon permissible error when calculating the dot product of vectors,
|
||||||
|
* i.e. the precision with which the cosine approaches 1 in an iterative algorithm.
|
||||||
|
* @return triple `(U, S, V)`.
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
val size = tensor.linearStructure.dim
|
val size = tensor.linearStructure.dim
|
||||||
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
||||||
@ -125,7 +181,14 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
|
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
|
||||||
symEig(epsilon = 1e-15)
|
symEig(epsilon = 1e-15)
|
||||||
|
|
||||||
//For information: http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
|
/**
|
||||||
|
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
|
||||||
|
* represented by a pair (eigenvalues, eigenvectors).
|
||||||
|
*
|
||||||
|
* @param epsilon permissible error when comparing tensors for equality
|
||||||
|
* and when the cosine approaches 1 in the SVD algorithm.
|
||||||
|
* @return a pair (eigenvalues, eigenvectors)
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
|
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
|
||||||
checkSymmetric(tensor, epsilon)
|
checkSymmetric(tensor, epsilon)
|
||||||
val (u, s, v) = tensor.svd(epsilon)
|
val (u, s, v) = tensor.svd(epsilon)
|
||||||
@ -139,6 +202,13 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
return eig to v
|
return eig to v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the determinant of a square matrix input, or of each square matrix in a batched input
|
||||||
|
* using LU factorization algorithm.
|
||||||
|
*
|
||||||
|
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
|
||||||
|
* @return the determinant.
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
|
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||||
|
|
||||||
checkSquareMatrix(tensor.shape)
|
checkSquareMatrix(tensor.shape)
|
||||||
@ -164,6 +234,15 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
return detTensor
|
return detTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input
|
||||||
|
* using LU factorization algorithm.
|
||||||
|
* Given a square matrix `a`, return the matrix `aInv` satisfying
|
||||||
|
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
|
||||||
|
*
|
||||||
|
* @param epsilon error in the LU algorithm - permissible error when comparing the determinant of a matrix with zero
|
||||||
|
* @return the multiplicative inverse of a matrix.
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
|
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||||
val (luTensor, pivotsTensor) = luFactor(epsilon)
|
val (luTensor, pivotsTensor) = luFactor(epsilon)
|
||||||
val invTensor = luTensor.zeroesLike()
|
val invTensor = luTensor.zeroesLike()
|
||||||
@ -177,6 +256,18 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
return invTensor
|
return invTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LUP decomposition
|
||||||
|
*
|
||||||
|
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
||||||
|
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
|
||||||
|
* with `P` being a permutation matrix or batch of matrices,
|
||||||
|
* `L` being a lower triangular matrix or batch of matrices,
|
||||||
|
* `U` being an upper triangular matrix or batch of matrices.
|
||||||
|
*
|
||||||
|
* @param epsilon permissible error when comparing the determinant of a matrix with zero
|
||||||
|
* @return triple of P, L and U tensors
|
||||||
|
*/
|
||||||
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
val (lu, pivots) = this.luFactor(epsilon)
|
val (lu, pivots) = this.luFactor(epsilon)
|
||||||
return luPivot(lu, pivots)
|
return luPivot(lu, pivots)
|
||||||
|
Loading…
Reference in New Issue
Block a user