argMax to core
This commit is contained in:
parent
623c96e4bb
commit
e80e1dcf62
@ -312,16 +312,4 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
*/
|
||||
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
}
|
||||
|
@ -576,7 +576,19 @@ public open class DoubleTensorAlgebra :
|
||||
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
||||
|
||||
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
|
||||
/**
|
||||
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*
|
||||
* If [keepDim] is true, the output tensor is of the same size as
|
||||
* input except in the dimension [dim] where it is of size 1.
|
||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||
*
|
||||
* @param dim the dimension to reduce.
|
||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||
foldDim({ x ->
|
||||
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
|
||||
}, dim, keepDim)
|
||||
|
Loading…
Reference in New Issue
Block a user