fixes
This commit is contained in:
parent
1b1a078dea
commit
14ca7cdd31
@ -23,8 +23,8 @@ public interface LinearOpsTensorAlgebra<T> :
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
|
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
|
||||||
* Given a square matrix `a`, return the matrix `aInv` satisfying
|
* Given a square matrix `A`, return the matrix `AInv` satisfying
|
||||||
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
|
* `A dot AInv = AInv dot A = eye(a.shape[0])`.
|
||||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
||||||
*
|
*
|
||||||
* @return the multiplicative inverse of a matrix.
|
* @return the multiplicative inverse of a matrix.
|
||||||
@ -37,7 +37,7 @@ public interface LinearOpsTensorAlgebra<T> :
|
|||||||
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
|
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
|
||||||
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
|
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
|
||||||
* Each decomposition has the form:
|
* Each decomposition has the form:
|
||||||
* Given a tensor `input`, return the tensor `L` satisfying ``input = L * L.H``,
|
* Given a tensor `input`, return the tensor `L` satisfying `input = L dot L.H`,
|
||||||
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
|
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
|
||||||
* which is just a transpose for the case of real-valued input matrices.
|
* which is just a transpose for the case of real-valued input matrices.
|
||||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
|
||||||
@ -50,7 +50,7 @@ public interface LinearOpsTensorAlgebra<T> :
|
|||||||
* QR decomposition.
|
* QR decomposition.
|
||||||
*
|
*
|
||||||
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
|
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a pair `(Q, R)` of tensors.
|
||||||
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
|
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q dot R``,
|
||||||
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
|
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
|
||||||
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
|
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
|
||||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||||
@ -63,7 +63,7 @@ public interface LinearOpsTensorAlgebra<T> :
|
|||||||
* LUP decomposition
|
* LUP decomposition
|
||||||
*
|
*
|
||||||
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
* Computes the LUP decomposition of a matrix or a batch of matrices.
|
||||||
* Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``,
|
* Given a tensor `input`, return tensors (P, L, U) satisfying `P dot input = L dot U`,
|
||||||
* with `P` being a permutation matrix or batch of matrices,
|
* with `P` being a permutation matrix or batch of matrices,
|
||||||
* `L` being a lower triangular matrix or batch of matrices,
|
* `L` being a lower triangular matrix or batch of matrices,
|
||||||
* `U` being an upper triangular matrix or batch of matrices.
|
* `U` being an upper triangular matrix or batch of matrices.
|
||||||
@ -77,7 +77,8 @@ public interface LinearOpsTensorAlgebra<T> :
|
|||||||
*
|
*
|
||||||
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
|
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
|
||||||
* The singular value decomposition is represented as a triple `(U, S, V)`,
|
* The singular value decomposition is represented as a triple `(U, S, V)`,
|
||||||
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
|
* such that `input = U dot diagonalEmbedding(S) dot V.H`,
|
||||||
|
* where V.H is the conjugate transpose of V.
|
||||||
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
||||||
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
||||||
*
|
*
|
||||||
|
@ -288,7 +288,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
|||||||
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
|
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return the maximum value of all elements in the input tensor.
|
* Returns the maximum value of all elements in the input tensor.
|
||||||
*/
|
*/
|
||||||
public fun Tensor<T>.max(): T
|
public fun Tensor<T>.max(): T
|
||||||
|
|
||||||
|
@ -343,7 +343,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
|||||||
val m2 = newOther.shape[newOther.shape.size - 2]
|
val m2 = newOther.shape[newOther.shape.size - 2]
|
||||||
val n = newOther.shape[newOther.shape.size - 1]
|
val n = newOther.shape[newOther.shape.size - 1]
|
||||||
check(m1 == m2) {
|
check(m1 == m2) {
|
||||||
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)")
|
"Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)"
|
||||||
}
|
}
|
||||||
|
|
||||||
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
|
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
|
||||||
@ -436,9 +436,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
|||||||
* @param epsilon permissible error when comparing two Double values.
|
* @param epsilon permissible error when comparing two Double values.
|
||||||
* @return true if two tensors have the same shape and elements, false otherwise.
|
* @return true if two tensors have the same shape and elements, false otherwise.
|
||||||
*/
|
*/
|
||||||
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean {
|
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
|
||||||
return tensor.eq(other) { x, y -> abs(x - y) < epsilon }
|
tensor.eq(other) { x, y -> abs(x - y) < epsilon }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compares element-wise two tensors.
|
* Compares element-wise two tensors.
|
||||||
@ -510,7 +509,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build tensor from rows of input tensor
|
* Builds tensor from rows of input tensor
|
||||||
*
|
*
|
||||||
* @param indices the [IntArray] of 1-dimensional indices
|
* @param indices the [IntArray] of 1-dimensional indices
|
||||||
* @return tensor with rows corresponding to rows by [indices]
|
* @return tensor with rows corresponding to rows by [indices]
|
||||||
|
@ -23,7 +23,6 @@ internal fun stridesFromShape(shape: IntArray): IntArray {
|
|||||||
current--
|
current--
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||||
|
Loading…
Reference in New Issue
Block a user