Min max refactor
This commit is contained in:
parent
d0281871fa
commit
218b81a242
@ -14,43 +14,6 @@ package space.kscience.kmath.tensors.api
|
||||
public interface AnalyticTensorAlgebra<T> :
|
||||
TensorPartialDivisionAlgebra<T> {
|
||||
|
||||
|
||||
/**
|
||||
* @return the minimum value of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.min(): T
|
||||
|
||||
/**
|
||||
* Returns the minimum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the minimum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the maximum value of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.max(): T
|
||||
|
||||
/**
|
||||
* Returns the maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the mean of all elements in the input tensor.
|
||||
*/
|
||||
|
@ -269,6 +269,41 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
*/
|
||||
public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the minimum value of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.min(): T
|
||||
|
||||
/**
|
||||
* Returns the minimum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the minimum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* @return the maximum value of all elements in the input tensor.
|
||||
*/
|
||||
public fun Tensor<T>.max(): T
|
||||
|
||||
/**
|
||||
* Returns the maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
|
||||
}
|
||||
|
@ -15,16 +15,6 @@ public object DoubleAnalyticTensorAlgebra :
|
||||
AnalyticTensorAlgebra<Double>,
|
||||
DoubleTensorAlgebra() {
|
||||
|
||||
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
|
||||
|
||||
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
|
||||
|
||||
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
|
||||
|
||||
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
||||
|
||||
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
|
||||
|
||||
override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
|
@ -9,6 +9,8 @@ import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.*
|
||||
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra.fold
|
||||
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra.foldDim
|
||||
import space.kscience.kmath.tensors.core.broadcastOuterTensors
|
||||
import space.kscience.kmath.tensors.core.checkBufferShapeConsistency
|
||||
import space.kscience.kmath.tensors.core.checkEmptyDoubleBuffer
|
||||
@ -447,4 +449,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
||||
|
||||
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.sum() }, dim, keepDim)
|
||||
|
||||
|
||||
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
|
||||
|
||||
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
|
||||
|
||||
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
|
||||
|
||||
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
||||
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.algebras.DoubleAnalyticTensorAlgebra.tan
|
||||
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
|
||||
import kotlin.math.*
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertTrue
|
||||
@ -106,7 +106,7 @@ internal class TestDoubleAnalyticTensorAlgebra {
|
||||
val tensor2 = DoubleTensor(shape2, buffer2)
|
||||
|
||||
@Test
|
||||
fun testMin() = DoubleAnalyticTensorAlgebra {
|
||||
fun testMin() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor2.min() == -3.0 }
|
||||
assertTrue { tensor2.min(0, true) eq fromArray(
|
||||
intArrayOf(1, 2),
|
||||
@ -119,7 +119,7 @@ internal class TestDoubleAnalyticTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMax() = DoubleAnalyticTensorAlgebra {
|
||||
fun testMax() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor2.max() == 4.0 }
|
||||
assertTrue { tensor2.max(0, true) eq fromArray(
|
||||
intArrayOf(1, 2),
|
||||
@ -132,7 +132,7 @@ internal class TestDoubleAnalyticTensorAlgebra {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSum() = DoubleAnalyticTensorAlgebra {
|
||||
fun testSum() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor2.sum() == 4.0 }
|
||||
assertTrue { tensor2.sum(0, true) eq fromArray(
|
||||
intArrayOf(1, 2),
|
||||
|
Loading…
Reference in New Issue
Block a user