This commit is contained in:
Andrei Kislitsyn 2021-05-03 20:42:18 +03:00
parent 8898f908ef
commit 7f8914d8ea
2 changed files with 5 additions and 19 deletions

View File

@ -9,16 +9,12 @@ import space.kscience.kmath.tensors.core.DoubleTensor
/** /**
* Common algebra with statistics methods. Operates on [Tensor]. * Common algebra with statistics methods. Operates on [Tensor].
*
* @param T the type of items closed under division in the tensors.
*/ */
public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> { public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> {
/** /**
* Returns the minimum value of all elements in the input tensor. * Returns the minimum value of all elements in the input tensor.
*
* @return the minimum value of all elements in the input tensor.
*/ */
public fun Tensor<T>.min(): Double public fun Tensor<T>.min(): Double
@ -37,8 +33,6 @@ public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> {
/** /**
* Returns the maximum value of all elements in the input tensor. * Returns the maximum value of all elements in the input tensor.
*
* @return the maximum value of all elements in the input tensor.
*/ */
public fun Tensor<T>.max(): Double public fun Tensor<T>.max(): Double
@ -57,8 +51,6 @@ public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> {
/** /**
* Returns the sum of all elements in the input tensor. * Returns the sum of all elements in the input tensor.
*
* @return the sum of all elements in the input tensor.
*/ */
public fun Tensor<T>.sum(): Double public fun Tensor<T>.sum(): Double
@ -77,8 +69,6 @@ public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> {
/** /**
* Returns the mean of all elements in the input tensor. * Returns the mean of all elements in the input tensor.
*
* @return the mean of all elements in the input tensor.
*/ */
public fun Tensor<T>.mean(): Double public fun Tensor<T>.mean(): Double
@ -97,8 +87,6 @@ public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> {
/** /**
* Returns the standard deviation of all elements in the input tensor. * Returns the standard deviation of all elements in the input tensor.
*
* @return the standard deviation of all elements in the input tensor.
*/ */
public fun Tensor<T>.std(): Double public fun Tensor<T>.std(): Double
@ -117,8 +105,6 @@ public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> {
/** /**
* Returns the variance of all elements in the input tensor. * Returns the variance of all elements in the input tensor.
*
* @return the variance of all elements in the input tensor.
*/ */
public fun Tensor<T>.variance(): Double public fun Tensor<T>.variance(): Double
@ -134,4 +120,5 @@ public interface StatisticTensorAlgebra<T>: TensorAlgebra<T> {
* @return the variance of each row of the input tensor in the given dimension [dim]. * @return the variance of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): DoubleTensor public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): DoubleTensor
}
}

View File

@ -17,9 +17,8 @@ import space.kscience.kmath.tensors.core.algebras.DoubleStatisticTensorAlgebra.v
public object DoubleStatisticTensorAlgebra : StatisticTensorAlgebra<Double>, DoubleTensorAlgebra() { public object DoubleStatisticTensorAlgebra : StatisticTensorAlgebra<Double>, DoubleTensorAlgebra() {
private fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double { private fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
return foldFunction(this.tensor.toDoubleArray()) foldFunction(this.tensor.toDoubleArray())
}
private fun Tensor<Double>.foldDim( private fun Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> Double, foldFunction: (DoubleArray) -> Double,
@ -102,4 +101,4 @@ public object DoubleStatisticTensorAlgebra : StatisticTensorAlgebra<Double>, Dou
keepDim keepDim
) )
} }