From c7de0bc4eedd0ac38f18437a038eca51777a8795 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Sat, 10 Jul 2021 17:53:41 +0100 Subject: [PATCH] flatten refactor --- .../java/space/kscience/kmath/noa/JNoa.java | 6 +++--- .../space/kscience/kmath/noa/algebras.kt | 4 ++-- .../resources/space_kscience_kmath_noa_JNoa.h | 20 +++++++++---------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index 4b9aa492a..a5c9ba75b 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -182,10 +182,10 @@ class JNoa { public static native long unaryMinus(long tensorHandle); - public static native long absTensor(long tensorHandle); - public static native long transposeTensor(long tensorHandle, int i, int j); + public static native long absTensor(long tensorHandle); + public static native long expTensor(long tensorHandle); public static native long lnTensor(long tensorHandle); @@ -246,7 +246,7 @@ class JNoa { public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim); - public static native long flattenTensor(long tensorHandle); + public static native long flattenTensor(long tensorHandle, int startDim, int endDim); public static native long matmul(long lhs, long rhs); diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 36de8abf3..ec91436d6 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -124,8 +124,8 @@ protected constructor(protected val scope: NoaScope) : override fun Tensor.argMax(dim: Int, keepDim: Boolean): NoaIntTensor = NoaIntTensor(scope, JNoa.argMaxTensor(tensor.tensorHandle, dim, keepDim)) - public fun Tensor.flatten(): TensorType = - wrap(JNoa.flattenTensor(tensor.tensorHandle)) + public fun Tensor.flatten(startDim: Int, endDim: Int): TensorType = + wrap(JNoa.flattenTensor(tensor.tensorHandle, startDim, endDim)) public fun Tensor.randDiscrete(low: Long, high: Long): TensorType = wrap(JNoa.randintLike(tensor.tensorHandle, low, high)) diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index a9a5866c2..adc2ab5ae 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -655,14 +655,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_minusTensorAssign JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus (JNIEnv *, jclass, jlong); -/* - * Class: space_kscience_kmath_noa_JNoa - * Method: absTensor - * Signature: (J)J - */ -JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor - (JNIEnv *, jclass, jlong); - /* * Class: space_kscience_kmath_noa_JNoa * Method: transposeTensor @@ -671,6 +663,14 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor (JNIEnv *, jclass, jlong, jint, jint); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: absTensor + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor + (JNIEnv *, jclass, jlong); + /* * Class: space_kscience_kmath_noa_JNoa * Method: expTensor @@ -914,10 +914,10 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_argMaxTensor /* * Class: space_kscience_kmath_noa_JNoa * Method: flattenTensor - * Signature: (J)J + * Signature: (JII)J */ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_flattenTensor - (JNIEnv *, jclass, jlong); + (JNIEnv *, jclass, jlong, jint, jint); /* * Class: space_kscience_kmath_noa_JNoa