This commit is contained in:
Andrei Kislitsyn 2021-05-07 13:00:20 +03:00
parent 1b1a078dea
commit 14ca7cdd31
6 changed files with 16 additions and 17 deletions

View File

@ -23,8 +23,8 @@ public interface LinearOpsTensorAlgebra<T> :
/** /**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input. * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
* Given a square matrix `a`, return the matrix `aInv` satisfying * Given a square matrix `A`, return the matrix `AInv` satisfying
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``. * `A dot AInv = AInv dot A = eye(a.shape[0])`.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
* *
* @return the multiplicative inverse of a matrix. * @return the multiplicative inverse of a matrix.
@ -37,7 +37,7 @@ public interface LinearOpsTensorAlgebra<T> :
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices) * Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices. * positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
* Each decomposition has the form: * Each decomposition has the form:
* Given a tensor `input`, return the tensor `L` satisfying ``input = L * L.H``, * Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`,
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L, * where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
* which is just a transpose for the case of real-valued input matrices. * which is just a transpose for the case of real-valued input matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
@ -50,7 +50,7 @@ public interface LinearOpsTensorAlgebra<T> :
* QR decomposition. * QR decomposition.
* *
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors. * Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``, * Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``,
* with `Q` being an orthogonal matrix or batch of orthogonal matrices * with `Q` being an orthogonal matrix or batch of orthogonal matrices
* and `R` being an upper triangular matrix or batch of upper triangular matrices. * and `R` being an upper triangular matrix or batch of upper triangular matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
@ -63,7 +63,7 @@ public interface LinearOpsTensorAlgebra<T> :
* LUP decomposition * LUP decomposition
* *
* Computes the LUP decomposition of a matrix or a batch of matrices. * Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``, * Given a tensor `input`, return tensors (P, L, U) satisfying `P dot input = L dot U`,
* with `P` being a permutation matrix or batch of matrices, * with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices, * `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices. * `U` being an upper triangular matrix or batch of matrices.
@ -77,7 +77,8 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* Computes the singular value decomposition of either a matrix or batch of matrices `input`. * Computes the singular value decomposition of either a matrix or batch of matrices `input`.
* The singular value decomposition is represented as a triple `(U, S, V)`, * The singular value decomposition is represented as a triple `(U, S, V)`,
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``. * such that `input = U dot diagonalEmbedding(S) dot V.H`,
* where V.H is the conjugate transpose of V.
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input. * If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd * For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
* *
@ -94,4 +95,4 @@ public interface LinearOpsTensorAlgebra<T> :
*/ */
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>> public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
} }

View File

@ -288,7 +288,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* @return the maximum value of all elements in the input tensor. * Returns the maximum value of all elements in the input tensor.
*/ */
public fun Tensor<T>.max(): T public fun Tensor<T>.max(): T

View File

@ -343,7 +343,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
val m2 = newOther.shape[newOther.shape.size - 2] val m2 = newOther.shape[newOther.shape.size - 2]
val n = newOther.shape[newOther.shape.size - 1] val n = newOther.shape[newOther.shape.size - 1]
check(m1 == m2) { check(m1 == m2) {
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)") "Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)"
} }
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last()) val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
@ -436,9 +436,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
* @param epsilon permissible error when comparing two Double values. * @param epsilon permissible error when comparing two Double values.
* @return true if two tensors have the same shape and elements, false otherwise. * @return true if two tensors have the same shape and elements, false otherwise.
*/ */
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean { public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
return tensor.eq(other) { x, y -> abs(x - y) < epsilon } tensor.eq(other) { x, y -> abs(x - y) < epsilon }
}
/** /**
* Compares element-wise two tensors. * Compares element-wise two tensors.
@ -510,7 +509,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
/** /**
* Build tensor from rows of input tensor * Builds tensor from rows of input tensor
* *
* @param indices the [IntArray] of 1-dimensional indices * @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to rows by [indices] * @return tensor with rows corresponding to rows by [indices]

View File

@ -23,7 +23,6 @@ internal fun stridesFromShape(shape: IntArray): IntArray {
current-- current--
} }
return res return res
} }
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray { internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
@ -55,4 +54,4 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides {
override val linearSize: Int override val linearSize: Int
get() = shape.reduce(Int::times) get() = shape.reduce(Int::times)
} }

View File

@ -33,4 +33,4 @@ public fun IntTensor.toIntArray(): IntArray {
return IntArray(numElements) { i -> return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i] mutableBuffer[bufferStart + i]
} }
} }

View File

@ -39,4 +39,4 @@ internal val Tensor<Int>.tensor: IntTensor
get() = when (this) { get() = when (this) {
is IntTensor -> this is IntTensor -> this
else -> this.toBufferedTensor().asTensor() else -> this.toBufferedTensor().asTensor()
} }