Merge pull request #333 from mipt-npm/fix/tensor-cov-api

remove cov from tensors API
This commit is contained in:
Alexander Nozik 2021-05-14 09:10:57 +03:00 committed by GitHub
commit a86e8eb164
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 11 deletions

View File

@ -67,16 +67,6 @@ public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
*/
public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the covariance matrix M of given vectors.
*
* M[i, j] contains covariance of i-th and j-th given vectors
*
* @param tensors the [List] of 1-dimensional tensors with same shape
* @return the covariance matrix
*/
public fun cov(tensors: List<Tensor<T>>): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html
public fun Tensor<T>.exp(): Tensor<T>

View File

@ -629,7 +629,15 @@ public open class DoubleTensorAlgebra :
return ((x - x.mean()) * (y - y.mean())).mean() * n / (n - 1)
}
override fun cov(tensors: List<Tensor<Double>>): DoubleTensor {
/**
* Returns the covariance matrix M of given vectors.
*
* M[i, j] contains covariance of i-th and j-th given vectors
*
* @param tensors the [List] of 1-dimensional tensors with same shape
* @return the covariance matrix
*/
public fun cov(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val n = tensors.size
val m = tensors[0].shape[0]