data for jit modules
This commit is contained in:
parent
09923a6c22
commit
1029871047
@ -299,4 +299,12 @@ class JNoa {
|
|||||||
public static native long forwardPass(long jitModuleHandle, long tensorHandle);
|
public static native long forwardPass(long jitModuleHandle, long tensorHandle);
|
||||||
|
|
||||||
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
|
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
|
||||||
|
|
||||||
|
public static native long getModuleParameter(long jitModuleHandle, String name);
|
||||||
|
|
||||||
|
public static native void setModuleParameter(long jitModuleHandle, String name, long tensorHandle);
|
||||||
|
|
||||||
|
public static native long getModuleBuffer(long jitModuleHandle, String name);
|
||||||
|
|
||||||
|
public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle);
|
||||||
}
|
}
|
||||||
|
@ -142,12 +142,24 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
|
||||||
|
|
||||||
public fun NoaJitModule.forward(parameters: TensorType): TensorType =
|
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
||||||
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensorHandle))
|
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle))
|
||||||
|
|
||||||
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
||||||
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
|
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle)
|
||||||
|
|
||||||
|
public fun NoaJitModule.getParameter(name: String): TensorType =
|
||||||
|
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
||||||
|
|
||||||
|
public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit =
|
||||||
|
JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle)
|
||||||
|
|
||||||
|
public fun NoaJitModule.getBuffer(name: String): TensorType =
|
||||||
|
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name))
|
||||||
|
|
||||||
|
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
|
||||||
|
JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||||
@ -304,7 +316,6 @@ protected constructor(scope: NoaScope) :
|
|||||||
|
|
||||||
public fun NoaJitModule.train(status: Boolean): Unit =
|
public fun NoaJitModule.train(status: Boolean): Unit =
|
||||||
JNoa.trainMode(this.jitModuleHandle, status)
|
JNoa.trainMode(this.jitModuleHandle, status)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaDoubleAlgebra
|
public sealed class NoaDoubleAlgebra
|
||||||
|
@ -1127,6 +1127,38 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass
|
|||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
|
||||||
(JNIEnv *, jclass, jlong, jlong);
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: getModuleParameter
|
||||||
|
* Signature: (JLjava/lang/String;)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleParameter
|
||||||
|
(JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setModuleParameter
|
||||||
|
* Signature: (JLjava/lang/String;J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleParameter
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: getModuleBuffer
|
||||||
|
* Signature: (JLjava/lang/String;)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer
|
||||||
|
(JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: setModuleBuffer
|
||||||
|
* Signature: (JLjava/lang/String;J)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jlong);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user