argmax fixed

This commit is contained in:
Roland Grinis 2021-07-08 12:30:27 +01:00
parent e80e1dcf62
commit b2b063196d
2 changed files with 3 additions and 2 deletions

View File

@ -90,7 +90,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
public override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()

View File

@ -28,7 +28,8 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
*
* @return the value of a scalar tensor.
*/
public fun Tensor<T>.value(): T
public fun Tensor<T>.value(): T = valueOrNull()
?: throw IllegalArgumentException("Illegal value call for non scalar tensor")
/**
* Each element of the tensor [other] is added to this value.