data for jit modules
This commit is contained in:
parent
09923a6c22
commit
1029871047
@ -299,4 +299,12 @@ class JNoa {
|
||||
public static native long forwardPass(long jitModuleHandle, long tensorHandle);
|
||||
|
||||
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
|
||||
|
||||
public static native long getModuleParameter(long jitModuleHandle, String name);
|
||||
|
||||
public static native void setModuleParameter(long jitModuleHandle, String name, long tensorHandle);
|
||||
|
||||
public static native long getModuleBuffer(long jitModuleHandle, String name);
|
||||
|
||||
public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle);
|
||||
}
|
||||
|
@ -142,12 +142,24 @@ protected constructor(protected val scope: NoaScope) :
|
||||
|
||||
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
||||
|
||||
public fun NoaJitModule.forward(parameters: TensorType): TensorType =
|
||||
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensorHandle))
|
||||
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
||||
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle))
|
||||
|
||||
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
||||
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
|
||||
|
||||
public fun NoaJitModule.getParameter(name: String): TensorType =
|
||||
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
||||
|
||||
public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit =
|
||||
JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle)
|
||||
|
||||
public fun NoaJitModule.getBuffer(name: String): TensorType =
|
||||
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
||||
|
||||
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
|
||||
JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle)
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||
@ -304,7 +316,6 @@ protected constructor(scope: NoaScope) :
|
||||
|
||||
public fun NoaJitModule.train(status: Boolean): Unit =
|
||||
JNoa.trainMode(this.jitModuleHandle, status)
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaDoubleAlgebra
|
||||
|
@ -1127,6 +1127,38 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: getModuleParameter
|
||||
* Signature: (JLjava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleParameter
|
||||
(JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setModuleParameter
|
||||
* Signature: (JLjava/lang/String;J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleParameter
|
||||
(JNIEnv *, jclass, jlong, jstring, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: getModuleBuffer
|
||||
* Signature: (JLjava/lang/String;)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer
|
||||
(JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: setModuleBuffer
|
||||
* Signature: (JLjava/lang/String;J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer
|
||||
(JNIEnv *, jclass, jlong, jstring, jlong);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user