From b2b063196d036efa7af854e8a78496b4017a8d4a Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 12:30:27 +0100 Subject: [PATCH] argmax fixed --- .../main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 2 +- .../kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index 456f7c2a9..57bc974ac 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -90,7 +90,7 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra public override fun Tensor.view(shape: IntArray): Tensor = ndArray.reshape(shape).wrap() public override fun Tensor.viewAs(other: Tensor): Tensor = view(other.shape) - public override fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor = + public fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor = ndBase.get().argmax(ndArray, keepDim, dim).wrap() public override fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor = ndArray.mean(keepDim, dim).wrap() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 90d6cd686..56aafbc3a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -28,7 +28,8 @@ public interface TensorAlgebra : Algebra> { * * @return the value of a scalar tensor. */ - public fun Tensor.value(): T + public fun Tensor.value(): T = valueOrNull() + ?: throw IllegalArgumentException("Illegal value call for non scalar tensor") /** * Each element of the tensor [other] is added to this value.