flatten refactor

This commit is contained in:
Roland Grinis 2021-07-10 17:53:41 +01:00
parent a0b72f519b
commit c7de0bc4ee
3 changed files with 15 additions and 15 deletions

View File

@ -182,10 +182,10 @@ class JNoa {
public static native long unaryMinus(long tensorHandle); public static native long unaryMinus(long tensorHandle);
public static native long absTensor(long tensorHandle);
public static native long transposeTensor(long tensorHandle, int i, int j); public static native long transposeTensor(long tensorHandle, int i, int j);
public static native long absTensor(long tensorHandle);
public static native long expTensor(long tensorHandle); public static native long expTensor(long tensorHandle);
public static native long lnTensor(long tensorHandle); public static native long lnTensor(long tensorHandle);
@ -246,7 +246,7 @@ class JNoa {
public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim); public static native long argMaxTensor(long tensorHandle, int dim, boolean keepDim);
public static native long flattenTensor(long tensorHandle); public static native long flattenTensor(long tensorHandle, int startDim, int endDim);
public static native long matmul(long lhs, long rhs); public static native long matmul(long lhs, long rhs);

View File

@ -124,8 +124,8 @@ protected constructor(protected val scope: NoaScope) :
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): NoaIntTensor = override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): NoaIntTensor =
NoaIntTensor(scope, JNoa.argMaxTensor(tensor.tensorHandle, dim, keepDim)) NoaIntTensor(scope, JNoa.argMaxTensor(tensor.tensorHandle, dim, keepDim))
public fun Tensor<T>.flatten(): TensorType = public fun Tensor<T>.flatten(startDim: Int, endDim: Int): TensorType =
wrap(JNoa.flattenTensor(tensor.tensorHandle)) wrap(JNoa.flattenTensor(tensor.tensorHandle, startDim, endDim))
public fun Tensor<T>.randDiscrete(low: Long, high: Long): TensorType = public fun Tensor<T>.randDiscrete(low: Long, high: Long): TensorType =
wrap(JNoa.randintLike(tensor.tensorHandle, low, high)) wrap(JNoa.randintLike(tensor.tensorHandle, low, high))

View File

@ -655,14 +655,6 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_minusTensorAssign
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_unaryMinus
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: absTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
(JNIEnv *, jclass, jlong);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: transposeTensor * Method: transposeTensor
@ -671,6 +663,14 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_transposeTensor
(JNIEnv *, jclass, jlong, jint, jint); (JNIEnv *, jclass, jlong, jint, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: absTensor
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_absTensor
(JNIEnv *, jclass, jlong);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: expTensor * Method: expTensor
@ -914,10 +914,10 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_argMaxTensor
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: flattenTensor * Method: flattenTensor
* Signature: (J)J * Signature: (JII)J
*/ */
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_flattenTensor JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_flattenTensor
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong, jint, jint);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa