argMax to core

This commit is contained in:
Roland Grinis 2021-07-08 12:13:21 +01:00
parent 623c96e4bb
commit e80e1dcf62
2 changed files with 13 additions and 13 deletions

View File

@ -312,16 +312,4 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
*/ */
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
} }

View File

@ -576,7 +576,19 @@ public open class DoubleTensorAlgebra :
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor = override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim) foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
/**
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> foldDim({ x ->
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble() x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
}, dim, keepDim) }, dim, keepDim)