Min max refactor

This commit is contained in:
Roland Grinis 2021-05-05 16:11:46 +01:00
parent d0281871fa
commit 218b81a242
5 changed files with 53 additions and 51 deletions

View File

@ -14,43 +14,6 @@ package space.kscience.kmath.tensors.api
public interface AnalyticTensorAlgebra<T> :
TensorPartialDivisionAlgebra<T> {
/**
* @return the minimum value of all elements in the input tensor.
*/
public fun Tensor<T>.min(): T
/**
* Returns the minimum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the minimum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the maximum value of all elements in the input tensor.
*/
public fun Tensor<T>.max(): T
/**
* Returns the maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the mean of all elements in the input tensor.
*/

View File

@ -269,6 +269,41 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
*/
public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the minimum value of all elements in the input tensor.
*/
public fun Tensor<T>.min(): T
/**
* Returns the minimum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the minimum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/**
* @return the maximum value of all elements in the input tensor.
*/
public fun Tensor<T>.max(): T
/**
* Returns the maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
}

View File

@ -15,16 +15,6 @@ public object DoubleAnalyticTensorAlgebra :
AnalyticTensorAlgebra<Double>,
DoubleTensorAlgebra() {
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =

View File

@ -9,6 +9,8 @@ import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra.fold
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra.foldDim
import space.kscience.kmath.tensors.core.broadcastOuterTensors
import space.kscience.kmath.tensors.core.checkBufferShapeConsistency
import space.kscience.kmath.tensors.core.checkEmptyDoubleBuffer
@ -447,4 +449,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.sum() }, dim, keepDim)
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
}

View File

@ -2,7 +2,7 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra.tan
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertTrue
@ -106,7 +106,7 @@ internal class TestDoubleAnalyticTensorAlgebra {
val tensor2 = DoubleTensor(shape2, buffer2)
@Test
fun testMin() = DoubleAnalyticTensorAlgebra {
fun testMin() = DoubleTensorAlgebra {
assertTrue { tensor2.min() == -3.0 }
assertTrue { tensor2.min(0, true) eq fromArray(
intArrayOf(1, 2),
@ -119,7 +119,7 @@ internal class TestDoubleAnalyticTensorAlgebra {
}
@Test
fun testMax() = DoubleAnalyticTensorAlgebra {
fun testMax() = DoubleTensorAlgebra {
assertTrue { tensor2.max() == 4.0 }
assertTrue { tensor2.max(0, true) eq fromArray(
intArrayOf(1, 2),
@ -132,7 +132,7 @@ internal class TestDoubleAnalyticTensorAlgebra {
}
@Test
fun testSum() = DoubleAnalyticTensorAlgebra {
fun testSum() = DoubleTensorAlgebra {
assertTrue { tensor2.sum() == 4.0 }
assertTrue { tensor2.sum(0, true) eq fromArray(
intArrayOf(1, 2),