argmax fixed
This commit is contained in:
parent
e80e1dcf62
commit
b2b063196d
@ -90,7 +90,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
||||
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
|
||||
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
||||
|
||||
public override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
||||
|
||||
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
||||
|
@ -28,7 +28,8 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
*
|
||||
* @return the value of a scalar tensor.
|
||||
*/
|
||||
public fun Tensor<T>.value(): T
|
||||
public fun Tensor<T>.value(): T = valueOrNull()
|
||||
?: throw IllegalArgumentException("Illegal value call for non scalar tensor")
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is added to this value.
|
||||
|
Loading…
Reference in New Issue
Block a user