argmax fixed
This commit is contained in:
parent
e80e1dcf62
commit
b2b063196d
@ -90,7 +90,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
|||||||
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
|
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
|
||||||
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
||||||
|
|
||||||
public override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
|
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||||
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
||||||
|
|
||||||
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
||||||
|
@ -28,7 +28,8 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
|||||||
*
|
*
|
||||||
* @return the value of a scalar tensor.
|
* @return the value of a scalar tensor.
|
||||||
*/
|
*/
|
||||||
public fun Tensor<T>.value(): T
|
public fun Tensor<T>.value(): T = valueOrNull()
|
||||||
|
?: throw IllegalArgumentException("Illegal value call for non scalar tensor")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Each element of the tensor [other] is added to this value.
|
* Each element of the tensor [other] is added to this value.
|
||||||
|
Loading…
Reference in New Issue
Block a user