data for jit modules

This commit is contained in:
Roland Grinis 2021-07-12 17:18:02 +01:00
parent 09923a6c22
commit 1029871047
3 changed files with 54 additions and 3 deletions

View File

@ -299,4 +299,12 @@ class JNoa {
public static native long forwardPass(long jitModuleHandle, long tensorHandle); public static native long forwardPass(long jitModuleHandle, long tensorHandle);
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle); public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
public static native long getModuleParameter(long jitModuleHandle, String name);
public static native void setModuleParameter(long jitModuleHandle, String name, long tensorHandle);
public static native long getModuleBuffer(long jitModuleHandle, String name);
public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle);
} }

View File

@ -142,12 +142,24 @@ protected constructor(protected val scope: NoaScope) :
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
public fun NoaJitModule.forward(parameters: TensorType): TensorType = public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensorHandle)) wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle))
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit = public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle) JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
public fun NoaJitModule.getParameter(name: String): TensorType =
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit =
JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle)
public fun NoaJitModule.getBuffer(name: String): TensorType =
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle)
} }
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>> public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
@ -304,7 +316,6 @@ protected constructor(scope: NoaScope) :
public fun NoaJitModule.train(status: Boolean): Unit = public fun NoaJitModule.train(status: Boolean): Unit =
JNoa.trainMode(this.jitModuleHandle, status) JNoa.trainMode(this.jitModuleHandle, status)
} }
public sealed class NoaDoubleAlgebra public sealed class NoaDoubleAlgebra

View File

@ -1127,6 +1127,38 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
(JNIEnv *, jclass, jlong, jlong); (JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getModuleParameter
* Signature: (JLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleParameter
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setModuleParameter
* Signature: (JLjava/lang/String;J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleParameter
(JNIEnv *, jclass, jlong, jstring, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: getModuleBuffer
* Signature: (JLjava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: setModuleBuffer
* Signature: (JLjava/lang/String;J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer
(JNIEnv *, jclass, jlong, jstring, jlong);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif