This commit is contained in:
Andrei Kislitsyn 2021-05-07 13:00:20 +03:00
parent 1b1a078dea
commit 14ca7cdd31
6 changed files with 16 additions and 17 deletions

View File

@ -23,8 +23,8 @@ public interface LinearOpsTensorAlgebra<T> :
/**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
* Given a square matrix `a`, return the matrix `aInv` satisfying
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
* Given a square matrix `A`, return the matrix `AInv` satisfying
* `A dot AInv = AInv dot A = eye(a.shape[0])`.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
*
* @return the multiplicative inverse of a matrix.
@ -37,7 +37,7 @@ public interface LinearOpsTensorAlgebra<T> :
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
* Each decomposition has the form:
* Given a tensor `input`, return the tensor `L` satisfying ``input = L * L.H``,
* Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`,
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
* which is just a transpose for the case of real-valued input matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
@ -50,7 +50,7 @@ public interface LinearOpsTensorAlgebra<T> :
* QR decomposition.
*
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``,
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
@ -63,7 +63,7 @@ public interface LinearOpsTensorAlgebra<T> :
* LUP decomposition
*
* Computes the LUP decomposition of a matrix or a batch of matrices.
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
* Given a tensor `input`, return tensors (P, L, U) satisfying `P dot input = L dot U`,
* with `P` being a permutation matrix or batch of matrices,
* `L` being a lower triangular matrix or batch of matrices,
* `U` being an upper triangular matrix or batch of matrices.
@ -77,7 +77,8 @@ public interface LinearOpsTensorAlgebra<T> :
*
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
* The singular value decomposition is represented as a triple `(U, S, V)`,
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
* such that `input = U dot diagonalEmbedding(S) dot V.H`,
* where V.H is the conjugate transpose of V.
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
*
@ -94,4 +95,4 @@ public interface LinearOpsTensorAlgebra<T> :
*/
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
}
}

View File

@ -288,7 +288,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the maximum value of all elements in the input tensor.
* Returns the maximum value of all elements in the input tensor.
*/
public fun Tensor<T>.max(): T

View File

@ -343,7 +343,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
val m2 = newOther.shape[newOther.shape.size - 2]
val n = newOther.shape[newOther.shape.size - 1]
check(m1 == m2) {
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)")
"Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)"
}
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
@ -436,9 +436,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
* @param epsilon permissible error when comparing two Double values.
* @return true if two tensors have the same shape and elements, false otherwise.
*/
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean {
return tensor.eq(other) { x, y -> abs(x - y) < epsilon }
}
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
tensor.eq(other) { x, y -> abs(x - y) < epsilon }
/**
* Compares element-wise two tensors.
@ -510,7 +509,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
}
/**
* Build tensor from rows of input tensor
* Builds tensor from rows of input tensor
*
* @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to rows by [indices]

View File

@ -23,7 +23,6 @@ internal fun stridesFromShape(shape: IntArray): IntArray {
current--
}
return res
}
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
@ -55,4 +54,4 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides {
override val linearSize: Int
get() = shape.reduce(Int::times)
}
}

View File

@ -33,4 +33,4 @@ public fun IntTensor.toIntArray(): IntArray {
return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i]
}
}
}

View File

@ -39,4 +39,4 @@ internal val Tensor<Int>.tensor: IntTensor
get() = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().asTensor()
}
}