argMax to core
This commit is contained in:
parent
623c96e4bb
commit
e80e1dcf62
@ -312,16 +312,4 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
*/
|
*/
|
||||||
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
|
|
||||||
*
|
|
||||||
* If [keepDim] is true, the output tensor is of the same size as
|
|
||||||
* input except in the dimension [dim] where it is of size 1.
|
|
||||||
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
|
||||||
*
|
|
||||||
* @param dim the dimension to reduce.
|
|
||||||
* @param keepDim whether the output tensor has [dim] retained or not.
|
|
||||||
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
|
|
||||||
*/
|
|
||||||
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
|
|
||||||
}
|
}
|
||||||
|
@ -576,7 +576,19 @@ public open class DoubleTensorAlgebra :
|
|||||||
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
|
||||||
|
|
||||||
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
|
|
||||||
|
/**
|
||||||
|
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||||
|
*
|
||||||
|
* If [keepDim] is true, the output tensor is of the same size as
|
||||||
|
* input except in the dimension [dim] where it is of size 1.
|
||||||
|
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
|
||||||
|
*
|
||||||
|
* @param dim the dimension to reduce.
|
||||||
|
* @param keepDim whether the output tensor has [dim] retained or not.
|
||||||
|
* @return the the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||||
|
*/
|
||||||
|
public fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
|
||||||
foldDim({ x ->
|
foldDim({ x ->
|
||||||
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
|
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
|
||||||
}, dim, keepDim)
|
}, dim, keepDim)
|
||||||
|
Loading…
Reference in New Issue
Block a user