From 803a88ac2c9f17cb5168fd783a2e0d58459f4f4c Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 21:37:13 +0100 Subject: [PATCH] NoaTensorAlgebra --- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 5 +---- .../main/java/space/kscience/kmath/noa/JNoa.java | 12 ++++++++++++ .../kotlin/space/kscience/kmath/noa/algebras.kt | 16 ++++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index bc58eefc3..6e84371ae 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -28,8 +28,6 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra */ public fun INDArray.wrap(): Nd4jArrayStructure - public fun INDArray.wrapInt(): Nd4jArrayStructure - /** * Unwraps to or acquires [INDArray] from [StructureND]. */ @@ -93,7 +91,7 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra public override fun Tensor.viewAs(other: Tensor): Tensor = view(other.shape) override fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor = - ndBase.get().argmax(ndArray, keepDim, dim).wrapInt() + ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure() public override fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor = ndArray.mean(keepDim, dim).wrap() @@ -146,7 +144,6 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra */ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { public override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() - public override fun INDArray.wrapInt(): Nd4jArrayStructure = asIntStructure() @OptIn(PerformancePitfall::class) public override val StructureND.ndArray: INDArray diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index 20f8c7510..f2d65aa5c 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -191,6 +191,18 @@ class JNoa { public static native long sumDimTensor(long tensorHandle, int dim, boolean keepDim); + public static native long minTensor(long tensorHandle); + + public static native long minDimTensor(long tensorHandle, int dim, boolean keepDim); + + public static native long maxTensor(long tensorHandle); + + public static native long maxDimTensor(long tensorHandle, int dim, boolean keepDim); + + public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim); + + public static native long flattenTensor(long tensorHandle); + public static native long matmul(long lhs, long rhs); public static native void matmulAssign(long lhs, long rhs); diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index a5d69b6bd..3a0ce28df 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -100,6 +100,22 @@ constructor(protected val scope: NoaScope) : TensorAlgebra { override fun Tensor.sum(dim: Int, keepDim: Boolean): TensorType = wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim)) + public fun Tensor.minAll(): TensorType = wrap(JNoa.minTensor(this.cast().tensorHandle)) + override fun Tensor.min(): T = minAll().item() + override fun Tensor.min(dim: Int, keepDim: Boolean): TensorType = + wrap(JNoa.minDimTensor(this.cast().tensorHandle, dim, keepDim)) + + public fun Tensor.maxAll(): TensorType = wrap(JNoa.maxTensor(this.cast().tensorHandle)) + override fun Tensor.max(): T = maxAll().item() + override fun Tensor.max(dim: Int, keepDim: Boolean): TensorType = + wrap(JNoa.maxDimTensor(this.cast().tensorHandle, dim, keepDim)) + + override fun Tensor.argMax(dim: Int, keepDim: Boolean): NoaIntTensor = + NoaIntTensor(scope, JNoa.argMaxTensor(this.cast().tensorHandle, dim, keepDim)) + + public fun Tensor.flatten(): TensorType = + wrap(JNoa.flattenTensor(this.cast().tensorHandle)) + public fun Tensor.copy(): TensorType = wrap(JNoa.copyTensor(this.cast().tensorHandle))